Getting Started
Installation
Install Julia v1.10 or above. Lux.jl is available through the Julia package manager: press ]
in the REPL to enter Pkg mode and then type add Lux
. Alternatively, you can also install it with
import Pkg
Pkg.add("Lux")
Update to v1
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
Quickstart
Pre-Requisites
You need to install Optimisers
and Zygote
if not done already.
Pkg.add(["Optimisers", "Zygote"])
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
We take randomness very seriously
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
Build the model
# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
Chain(
layer_1 = Dense(128 => 256, tanh), # 33_024 parameters
layer_2 = Chain(
layer_1 = Dense(256 => 1, tanh), # 257 parameters
layer_2 = Dense(1 => 10), # 20 parameters
),
) # Total: 33_301 parameters,
# plus 0 states.
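As an aside, the parameter and state counts printed above can also be queried programmatically. A minimal sketch, assuming the model constructed above:
# Count the trainable parameters and state entries of the model
Lux.parameterlength(model) # 33_301
Lux.statelength(model) # 0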
Models don't hold parameters and states, so we need to initialize them separately. From there on, we can use the standard AD and Optimisers APIs directly (a sketch of that workflow appears after the training output below). However, here we will show how to use Lux's Training API, which provides a uniform interface over all supported AD systems.
# Get the device determined by Lux
dev = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev
# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(
AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(
AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
((layer_1 = (weight = Float32[0.0017983615 0.006062332 … 0.0053392933 0.0056276177; 0.0011292367 0.0041270256 … 0.003585879 0.0038155357; … ; -0.0008762945 -0.0031371699 … -0.0027350332 -0.0029033197; 0.0011154839 0.002197485 … 0.0021741025 0.0021157824], bias = Float32[0.006656272, 0.004425203, 0.0028994146, -0.0116051175, 0.0031301186, 0.0037318026, 0.0136483535, 0.013969757, -0.015173428, -0.005173992 … -0.0018621369, -0.0015270555, -0.007873881, -0.0076395273, -0.0022123815, 0.0039605754, 0.0034407252, -0.0045406874, -0.003383829, 0.0029306945]), layer_2 = (layer_1 = (weight = Float32[0.04993449 0.03202845 … -0.059382 0.07701616], bias = Float32[0.08797912]), layer_2 = (weight = Float32[-0.094527975; -0.11476975; … ; -0.016841749; -0.0698748;;], bias = Float32[-0.21608135, -0.26255828, -0.23534852, -0.21524015, -0.055711076, -0.20314303, -0.1895644, 0.03666526, -0.03937737, -0.15905891]))), 0.8455785f0, NamedTuple(), Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}(nothing, nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}((layer_1 = Dense(128 => 256, tanh), layer_2 = Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(256 => 1, tanh), layer_2 = Dense(1 => 10)), nothing)), nothing), (layer_1 = (weight = Float32[-0.22542597 0.22379348 … 0.1997513 -0.018708104; -0.023026714 0.15451026 … -0.065325744 0.18120264; … ; 0.038037397 -0.07125516 … -0.03306083 0.039138064; -0.18810266 -0.09693537 … -0.18102062 0.019230088], bias = Float32[0.030937059, -0.060276944, 0.084569596, 0.00040024254, -0.065509446, 
-0.08527214, -0.026523968, 0.06347208, 0.042247728, 0.027705256 … -0.06052852, 0.03504307, -0.028244259, 0.06788022, 0.0027464977, -0.06942153, 0.0064240773, 0.0141069945, -0.029283267, 0.01174226]), layer_2 = (layer_1 = (weight = Float32[0.12008221 0.06026435 … -0.070576 0.1577647], bias = Float32[0.026844418]), layer_2 = (weight = Float32[0.5345728; -0.28288874; … ; -0.32983455; -0.45298168;;], bias = Float32[-0.59751064, -0.7033041, -0.8457602, -0.53789175, -0.31473723, 0.17461234, -0.82945836, 0.67841595, 0.35837248, -0.14941788]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.000926728 0.000860063 … 0.00110328 0.000908301; 0.000480834 0.000574605 … 0.000665883 0.000584197; … ; -0.000391039 -0.000438617 … -0.000520651 -0.000449867; 0.00106235 0.000365587 … 0.000813131 0.000495484], Float32[7.20343f-8 4.46976f-8 … 6.84867f-8 4.63952f-8; 1.79691f-8 2.02649f-8 … 2.45046f-8 1.96227f-8; … ; 1.21215f-8 1.17657f-8 … 1.50136f-8 1.15681f-8; 1.12738f-7 7.45199f-9 … 4.8495f-8 1.44173f-8], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.00169459, 0.000977637, 0.00103866, -0.00234933, 0.000659175, 0.000868318, 0.00303222, 0.00271383, -0.00326585, -0.0014993 … -0.000480712, -0.000501535, -0.00174489, -0.00160158, -0.000470662, 0.00127967, 0.000618911, -0.00103705, -0.000773079, 0.00146704], Float32[1.74884f-7, 5.48983f-8, 7.75433f-8, 3.08981f-7, 2.45763f-8, 4.41623f-8, 5.29156f-7, 4.09021f-7, 6.07287f-7, 1.45678f-7 … 1.4164f-8, 1.73391f-8, 1.7507f-7, 1.44894f-7, 1.25673f-8, 1.1198f-7, 2.11545f-8, 6.25338f-8, 3.4755f-8, 1.78565f-7], (0.729, 0.997003)))), layer_2 = (layer_1 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.00443555 0.00163654 … -0.0124978 0.0123434], Float32[2.53181f-6 1.32838f-6 … 8.83289f-6 8.58873f-6], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0191175], Float32[2.08743f-5], (0.729, 0.997003)))), layer_2 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[-0.0172084; -0.0213176; … ; -0.00376332; -0.0116419;;], Float32[1.63537f-5; 2.51152f-5; … ; 8.16783f-7; 7.55419f-6;;], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[-0.0365001, -0.045083, -0.0507623, -0.0390298, -0.0242259, -0.0404982, -0.0358925, 0.0114351, -0.00803444, -0.0248332], Float32[7.40417f-5, 0.000112652, 0.000146818, 8.41229f-5, 4.60234f-5, 9.15105f-5, 7.13093f-5, 8.78741f-6, 3.62043f-6, 3.51285f-5], (0.729, 0.997003)))))), 2))
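For reference, here is a minimal sketch of the manual workflow mentioned earlier, using Zygote and Optimisers directly instead of the Training API. It reuses model, ps, st, x, rng, and dev from above; y_target and manual_loss are introduced here purely for illustration.
# Dummy target
y_target = dev(rand(rng, Float32, 10, 2))
# Set up the optimizer state for the parameters
opt_state = Optimisers.setup(Adam(0.0001f0), ps)
# Mean squared error loss as a function of the parameters only
function manual_loss(ps_)
    ŷ, _ = Lux.apply(model, x, ps_, st)
    return sum(abs2, ŷ .- y_target) / length(ŷ)
end
# Gradient w.r.t. the parameters via Zygote
grads = only(Zygote.gradient(manual_loss, ps))
# Apply the parameter update
opt_state, ps = Optimisers.update(opt_state, ps, grads)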
Defining Custom Layers
We can train our model using the above code, but let's go ahead and see how to use Reactant. Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after running fancy optimizations). It is currently the recommended way to train large models in Lux. For more details on using Reactant, see the manual.
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing
dev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)
We will define a custom MLP using the @compact
macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network.
n_in = 1
n_out = 1
nlayers = 3
model = @compact(
w1=Dense(n_in => 32),
w2=[Dense(32 => 32) for i in 1:nlayers],
w3=Dense(32 => n_out),
act=relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
end
out = w3(embed)
@return out
end
@compact(
w1 = Dense(1 => 32), # 64 parameters
w2 = NamedTuple(
1 = Dense(32 => 32), # 1_056 parameters
2 = Dense(32 => 32), # 1_056 parameters
3 = Dense(32 => 32), # 1_056 parameters
),
w3 = Dense(32 => 1), # 33 parameters
act = relu,
) do x
embed = act(w1(x))
for w = w2
embed = act(w(embed))
end
out = w3(embed)
return out
end # Total: 3_265 parameters,
# plus 1 states.
We can initialize the model and train it with the same code as before!
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model) |> dev
x = rand(rng, Float32, n_in, 32) |> dev
@jit model(x, ps, st) # 1×32 Matrix and updated state as output.
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)
function train_model!(model, ps, st, x_data, y_data)
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
for iter in 1:1000
_, loss, _, train_state = Lux.Training.single_train_step!(
AutoEnzyme(), MSELoss(),
(x_data, y_data), train_state
)
if iter % 100 == 1 || iter == 1000
@printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
end
end
return train_state # the trained parameters and states live in train_state.parameters and train_state.states
end
train_model!(model, ps, st, x_data, y_data)
Iteration: 0001 Loss: 2.08073235
Iteration: 0101 Loss: 0.142443746
Iteration: 0201 Loss: 0.00498719513
Iteration: 0301 Loss: 0.00116170431
Iteration: 0401 Loss: 0.000497820089
Iteration: 0501 Loss: 0.00027494994
Iteration: 0601 Loss: 0.000180605697
Iteration: 0701 Loss: 0.000201430521
Iteration: 0801 Loss: 0.000379697973
Iteration: 0901 Loss: 0.000106757798
Iteration: 1000 Loss: 0.000231371887
Training with Optimization.jl
If you are coming from the SciML ecosystem and want to use Optimization.jl, please refer to the Optimization.jl Tutorial.
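For orientation, the workflow roughly looks like the following. This is only an illustrative sketch, not the tutorial itself: it assumes Optimization.jl, OptimizationOptimisers.jl, ComponentArrays.jl, and Zygote.jl are installed, and the exact setup in the linked tutorial may differ.
using Lux, Random, Optimisers, Zygote, Optimization, OptimizationOptimisers, ComponentArrays
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(Dense(1 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps) # Optimization.jl works with a flat parameter container
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
# Optimization.jl expects an objective of the form loss(u, p)
function objective(ps, _)
    ŷ, _ = model(x_data, ps, st)
    return sum(abs2, ŷ .- y_data) / length(y_data)
end
optf = OptimizationFunction(objective, Optimization.AutoZygote())
prob = OptimizationProblem(optf, ps)
sol = solve(prob, Optimisers.Adam(0.01f0); maxiters = 500) # sol.u contains the optimized parameters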
Additional Packages
LuxDL
hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.
You can install all of those packages via import Pkg; Pkg.add("<package name>").
XLA (CPU/GPU/TPU) Support
Lux.jl supports XLA compilation for CPU, GPU, and TPU using Reactant.jl.
GPU Support
GPU support for Lux.jl requires loading additional packages (see the sketch after this list):
LuxCUDA.jl for CUDA support.
AMDGPU.jl for AMDGPU support.
Metal.jl for Apple Metal support.
oneAPI.jl for oneAPI support.
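As a quick illustration, loading one of these packages is all that is needed for gpu_device to pick up the corresponding backend. A minimal sketch for CUDA, assuming a functional CUDA installation (the other backends work analogously by loading their package instead):
using Lux, Random
using LuxCUDA # loading the trigger package enables CUDA support
rng = Random.default_rng()
model = Chain(Dense(128 => 256, tanh), Dense(256 => 10))
ps, st = Lux.setup(rng, model)
dev = gpu_device() # returns a CUDA device once LuxCUDA is loaded and functional
ps, st = (ps, st) |> dev # move parameters and states to the GPU
x = rand(rng, Float32, 128, 2) |> dev # move data to the GPU
y, st = Lux.apply(model, x, ps, st)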