
Getting Started

Installation

Install Julia v1.10 or above. Lux.jl is available through the Julia package manager: press ] in the REPL to enter package-manager mode and type add Lux. Alternatively, you can install it programmatically:

julia
import Pkg
Pkg.add("Lux")

Update to v1

If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.

Quickstart

Pre-Requisites

You need to install Optimisers and Zygote if you have not done so already: Pkg.add(["Optimisers", "Zygote"])

julia
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

We take randomness very seriously

julia
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
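
Passing the rng explicitly is what makes parameter initialization reproducible. A minimal sketch (not part of the original example; the small Dense layer here is only for illustration):

julia
# Hypothetical illustration: re-seeding before Lux.setup reproduces the same parameters
Random.seed!(rng, 0)
ps1, _ = Lux.setup(rng, Dense(2 => 3))
Random.seed!(rng, 0)
ps2, _ = Lux.setup(rng, Dense(2 => 3))
ps1.weight == ps2.weight  # true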

Build the model

julia
# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
Chain(
    layer_1 = Dense(128 => 256, tanh),  # 33_024 parameters
    layer_2 = Chain(
        layer_1 = Dense(256 => 1, tanh),  # 257 parameters
        layer_2 = Dense(1 => 10),       # 20 parameters
    ),
)         # Total: 33_301 parameters,
          #        plus 0 states.
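
The parameter and state counts printed above can also be queried programmatically; a small sketch using Lux's utility functions:

julia
Lux.parameterlength(model)  # 33_301
Lux.statelength(model)      # 0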

Models don't hold parameters and states, so we need to initialize them. From there on, we can just use our standard AD and Optimisers APIs directly (a sketch of that route follows the example below). However, here we will show how to use Lux's Training API, which provides a uniform interface over all supported AD systems.

julia
# Get the device determined by Lux
dev = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev

# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)
((layer_1 = (weight = Float32[0.0017983615 0.006062332 … 0.0053392933 0.0056276177; 0.0011292367 0.0041270256 … 0.003585879 0.0038155357; … ; -0.0008762945 -0.0031371699 … -0.0027350332 -0.0029033197; 0.0011154839 0.002197485 … 0.0021741025 0.0021157824], bias = Float32[0.006656272, 0.004425203, 0.0028994146, -0.0116051175, 0.0031301186, 0.0037318026, 0.0136483535, 0.013969757, -0.015173428, -0.005173992  …  -0.0018621369, -0.0015270555, -0.007873881, -0.0076395273, -0.0022123815, 0.0039605754, 0.0034407252, -0.0045406874, -0.003383829, 0.0029306945]), layer_2 = (layer_1 = (weight = Float32[0.04993449 0.03202845 … -0.059382 0.07701616], bias = Float32[0.08797912]), layer_2 = (weight = Float32[-0.094527975; -0.11476975; … ; -0.016841749; -0.0698748;;], bias = Float32[-0.21608135, -0.26255828, -0.23534852, -0.21524015, -0.055711076, -0.20314303, -0.1895644, 0.03666526, -0.03937737, -0.15905891]))), 0.8455785f0, NamedTuple(), Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}(nothing, nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}((layer_1 = Dense(128 => 256, tanh), layer_2 = Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(256 => 1, tanh), layer_2 = Dense(1 => 10)), nothing)), nothing), (layer_1 = (weight = Float32[-0.22542597 0.22379348 … 0.1997513 -0.018708104; -0.023026714 0.15451026 … -0.065325744 0.18120264; … ; 0.038037397 -0.07125516 … -0.03306083 0.039138064; -0.18810266 -0.09693537 … -0.18102062 0.019230088], bias = Float32[0.030937059, -0.060276944, 0.084569596, 0.00040024254, -0.065509446, -0.08527214, -0.026523968, 0.06347208, 0.042247728, 0.027705256  …  -0.06052852, 0.03504307, -0.028244259, 0.06788022, 0.0027464977, -0.06942153, 0.0064240773, 0.0141069945, -0.029283267, 0.01174226]), layer_2 = (layer_1 = (weight = Float32[0.12008221 0.06026435 … -0.070576 0.1577647], bias = 
Float32[0.026844418]), layer_2 = (weight = Float32[0.5345728; -0.28288874; … ; -0.32983455; -0.45298168;;], bias = Float32[-0.59751064, -0.7033041, -0.8457602, -0.53789175, -0.31473723, 0.17461234, -0.82945836, 0.67841595, 0.35837248, -0.14941788]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), Adam(0.0001, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.000926728 0.000860063 … 0.00110328 0.000908301; 0.000480834 0.000574605 … 0.000665883 0.000584197; … ; -0.000391039 -0.000438617 … -0.000520651 -0.000449867; 0.00106235 0.000365587 … 0.000813131 0.000495484], Float32[7.20343f-8 4.46976f-8 … 6.84867f-8 4.63952f-8; 1.79691f-8 2.02649f-8 … 2.45046f-8 1.96227f-8; … ; 1.21215f-8 1.17657f-8 … 1.50136f-8 1.15681f-8; 1.12738f-7 7.45199f-9 … 4.8495f-8 1.44173f-8], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00169459, 0.000977637, 0.00103866, -0.00234933, 0.000659175, 0.000868318, 0.00303222, 0.00271383, -0.00326585, -0.0014993  …  -0.000480712, -0.000501535, -0.00174489, -0.00160158, -0.000470662, 0.00127967, 0.000618911, -0.00103705, -0.000773079, 0.00146704], Float32[1.74884f-7, 5.48983f-8, 7.75433f-8, 3.08981f-7, 2.45763f-8, 4.41623f-8, 5.29156f-7, 4.09021f-7, 6.07287f-7, 1.45678f-7  …  1.4164f-8, 1.73391f-8, 1.7507f-7, 1.44894f-7, 1.25673f-8, 1.1198f-7, 2.11545f-8, 6.25338f-8, 3.4755f-8, 1.78565f-7], (0.729, 0.997003)))), layer_2 = (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00443555 0.00163654 … -0.0124978 0.0123434], Float32[2.53181f-6 1.32838f-6 … 8.83289f-6 8.58873f-6], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.0191175], Float32[2.08743f-5], (0.729, 0.997003)))), layer_2 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0172084; -0.0213176; … ; -0.00376332; -0.0116419;;], Float32[1.63537f-5; 2.51152f-5; … ; 8.16783f-7; 7.55419f-6;;], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0365001, -0.045083, -0.0507623, -0.0390298, -0.0242259, -0.0404982, -0.0358925, 0.0114351, -0.00803444, -0.0248332], Float32[7.40417f-5, 0.000112652, 0.000146818, 8.41229f-5, 4.60234f-5, 9.15105f-5, 7.13093f-5, 8.78741f-6, 3.62043f-6, 3.51285f-5], (0.729, 0.997003)))))), 2))
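
As mentioned above, the Training API is optional: the same update can be written directly against Zygote and Optimisers. The snippet below is a minimal sketch of that route (it assumes the model, ps, st, x, rng, and dev defined earlier; the hand-written loss closure and variable names are illustrative, not part of Lux's API):

julia
# Direct route: compute gradients with Zygote and update with Optimisers
y_target = dev(rand(rng, Float32, 10, 2))

# MSE written by hand; Lux.apply returns (output, updated_state)
loss_fn(p) = sum(abs2, first(Lux.apply(model, x, p, st)) .- y_target) / length(y_target)

opt_state = Optimisers.setup(Adam(0.0001f0), ps)
grads = only(Zygote.gradient(loss_fn, ps))
opt_state, ps = Optimisers.update(opt_state, ps, grads)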

Defining Custom Layers

We can train our model using the above code, but let's go ahead and see how to use Reactant. Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after running fancy optimizations). It is the currently recommended way to train large models in Lux. For more details on using Reactant, see the manual.

julia
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing

dev = reactant_device()
(::ReactantDevice{Missing, Missing}) (generic function with 1 method)
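
To give a feel for what Reactant does before we build the model, here is a minimal sketch of tracing and compiling a plain Julia function (the function sumabs2 and the array x_ra are illustrative names; @compile and Reactant.to_rarray come from Reactant, and @jit, used later, compiles and runs in one step):

julia
# Compile a simple function through Reactant (MLIR -> XLA) and run it
sumabs2(x) = sum(abs2, x)
x_ra = Reactant.to_rarray(rand(Float32, 4, 4))
compiled_sumabs2 = @compile sumabs2(x_ra)
compiled_sumabs2(x_ra)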

We will define a custom MLP using the @compact macro. The macro takes a list of parameters, layers, and states, together with a function defining the forward pass of the neural network.

julia
n_in = 1
n_out = 1
nlayers = 3

model = @compact(
    w1=Dense(n_in => 32),
    w2=[Dense(32 => 32) for i in 1:nlayers],
    w3=Dense(32 => n_out),
    act=relu
) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    @return out
end
@compact(
    w1 = Dense(1 => 32),                # 64 parameters
    w2 = NamedTuple(
        1 = Dense(32 => 32),            # 1_056 parameters
        2 = Dense(32 => 32),            # 1_056 parameters
        3 = Dense(32 => 32),            # 1_056 parameters
    ),
    w3 = Dense(32 => 1),                # 33 parameters
    act = relu,
) do x 
    embed = act(w1(x))
    for w = w2
        embed = act(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 3_265 parameters,
          #        plus 1 states.

We can initialize the model and train it with the same code as before!

julia
rng = Random.default_rng()
Random.seed!(rng, 0)

ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, n_in, 32) |> dev

@jit model(x, ps, st)  # 1×32 Matrix and updated state as output.

x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)

function train_model!(model, ps, st, x_data, y_data)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

    for iter in 1:1000
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoEnzyme(), MSELoss(),
            (x_data, y_data), train_state
        )
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end

    # The trained parameters and states live in the TrainState
    return train_state.model, train_state.parameters, train_state.states
end

train_model!(model, ps, st, x_data, y_data)
Iteration: 0001 	 Loss: 2.08086824
Iteration: 0101 	 Loss: 0.135109991
Iteration: 0201 	 Loss: 0.00448962208
Iteration: 0301 	 Loss: 0.00111342408
Iteration: 0401 	 Loss: 0.000457020855
Iteration: 0501 	 Loss: 0.000461334654
Iteration: 0601 	 Loss: 0.000210383674
Iteration: 0701 	 Loss: 0.000209954742
Iteration: 0801 	 Loss: 0.000158460374
Iteration: 0901 	 Loss: 9.63655111e-05
Iteration: 1000 	 Loss: 0.000503897725
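
For completeness, a minimal sketch of using the trained parameters for prediction. It captures the values returned by train_model! (re-running the training above) and evaluates the compiled model on the training inputs; ps_trained, st_trained, and y_pred are illustrative names:

julia
_, ps_trained, st_trained = train_model!(model, ps, st, x_data, y_data)
y_pred, _ = @jit model(x_data, ps_trained, st_trained)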

Training with Optimization.jl

If you are coming from the SciML ecosystem and want to use Optimization.jl, please refer to the Optimization.jl Tutorial.

Additional Packages

LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.

You can install all of these packages via import Pkg; Pkg.add("<package name>").

XLA (CPU/GPU/TPU) Support

Lux.jl supports XLA compilation for CPU, GPU, and TPU using Reactant.jl.

GPU Support

GPU support for Lux.jl requires loading an additional backend package: LuxCUDA.jl for NVIDIA GPUs, AMDGPU.jl for AMD GPUs, Metal.jl for Apple GPUs, or oneAPI.jl for Intel GPUs.
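
For example, on a machine with an NVIDIA GPU (a minimal sketch; the other backends work the same way after loading the corresponding package):

julia
using Lux, LuxCUDA

dev = gpu_device()               # selects the loaded, functional GPU backend (falls back to CPU otherwise)
x = rand(Float32, 128, 2) |> dev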