Introduction
Installation
Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. You can enter it by pressing ] in the REPL and then typing add Lux. Alternatively, you can also do
import Pkg
Pkg.add("Lux")
Update to v1
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
Quickstart
Pre-Requisites
You need to install Optimisers and Zygote if not done already: Pkg.add(["Optimisers", "Zygote"])
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
We take randomness very seriously
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
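Seeding matters because every later call that draws from rng (parameter initialization, dummy inputs) then becomes reproducible. A minimal sketch using only the Random standard library; the names first_draw, second_draw, and local_rng are illustrative and not part of Lux:

```julia
using Random

# Seed the task-local default RNG, draw, then reseed and draw again.
rng = Random.default_rng()
Random.seed!(rng, 0)
first_draw = rand(rng, Float32, 3)

Random.seed!(rng, 0)
second_draw = rand(rng, Float32, 3)

# The same seed reproduces the same stream of numbers.
same_stream = first_draw == second_draw

# A standalone generator (here Xoshiro) can be passed around explicitly,
# leaving the task-local default RNG untouched.
local_rng = Xoshiro(0)
third_draw = rand(local_rng, Float32, 3)
```

Passing an explicit generator such as local_rng is often the easier option when several parts of a program need independent, reproducible streams.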
Build the model
# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
Chain(
    layer_1 = Dense(128 => 256, tanh),  # 33_024 parameters
    layer_2 = Chain(
        layer_1 = Dense(256 => 1, tanh),  # 257 parameters
        layer_2 = Dense(1 => 10),       # 20 parameters
    ),
)         # Total: 33_301 parameters,
          #        plus 0 states.
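The parameter counts in the printout can be checked by hand: a dense layer with a bias stores one weight per input/output pair plus one bias per output. A small sketch in plain Julia; dense_params is a hypothetical helper written for this illustration, not a Lux function:

```julia
# Parameters of a dense layer with bias: a weight matrix of size (n_out, n_in)
# plus a bias vector of length n_out.
dense_params(n_in, n_out) = n_in * n_out + n_out

# The three Dense layers of the model above, in order.
layer_counts = [dense_params(128, 256), dense_params(256, 1), dense_params(1, 10)]
total = sum(layer_counts)

println(layer_counts)  # per-layer counts
println(total)         # total parameter count
```

The three values match the per-layer comments in the printout (33_024, 257, and 20), and their sum matches the reported total of 33_301.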
Models don't hold parameters and states, so initialize them first. From there on, we can just use our standard AD and Optimisers API. However, here we will show how to use Lux's Training API, which provides a uniform API over all supported AD systems.
# Get the device determined by Lux
dev = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev
# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)
## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)
((layer_1 = (weight = Float32[0.0017983615 0.006062332 … 0.0053392933 0.0056276177; 0.0011292367 0.0041270256 … 0.003585879 0.0038155357; … ; -0.0008762945 -0.0031371699 … -0.0027350332 -0.0029033197; 0.0011154839 0.002197485 … 0.0021741025 0.0021157824], bias = Float32[0.006656272, 0.004425203, 0.0028994146, -0.0116051175, 0.0031301186, 0.0037318026, 0.0136483535, 0.013969757, -0.015173428, -0.005173992 … -0.0018621369, -0.0015270555, -0.007873881, -0.0076395273, -0.0022123815, 0.0039605754, 0.0034407252, -0.0045406874, -0.003383829, 0.0029306945]), layer_2 = (layer_1 = (weight = Float32[0.04993449 0.03202845 … -0.059382 0.07701616], bias = Float32[0.08797912]), layer_2 = (weight = Float32[-0.094527975; -0.11476975; … ; -0.016841749; -0.0698748;;], bias = Float32[-0.21608135, -0.26255828, -0.23534852, -0.21524015, -0.055711076, -0.20314303, -0.1895644, 0.03666526, -0.03937737, -0.15905891]))), 0.8455785f0, NamedTuple(), Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, 
layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}(nothing, nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}((layer_1 = Dense(128 => 256, tanh), layer_2 = Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(256 => 1, tanh), layer_2 = Dense(1 => 10)), nothing)), nothing), (layer_1 = (weight = Float32[-0.22542597 0.22379348 … 0.1997513 -0.018708104; -0.023026714 0.15451026 … -0.065325744 0.18120264; … ; 0.038037397 -0.07125516 … -0.03306083 0.039138064; -0.18810266 -0.09693537 … -0.18102062 0.019230088], bias = Float32[0.030937059, -0.060276944, 0.084569596, 0.00040024254, -0.065509446, -0.08527214, -0.026523968, 0.06347208, 0.042247728, 0.027705256 … -0.06052852, 0.03504307, -0.028244259, 0.06788022, 0.0027464977, -0.06942153, 0.0064240773, 0.0141069945, -0.029283267, 0.01174226]), layer_2 = (layer_1 = (weight = Float32[0.12008221 0.06026435 … -0.070576 0.1577647], bias = Float32[0.026844418]), layer_2 = (weight = Float32[0.5345728; -0.28288874; … ; -0.32983455; -0.45298168;;], bias = Float32[-0.59751064, 
-0.7033041, -0.8457602, -0.53789175, -0.31473723, 0.17461234, -0.82945836, 0.67841595, 0.35837248, -0.14941788]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.000926728 0.000860063 … 0.00110328 0.000908301; 0.000480834 0.000574605 … 0.000665883 0.000584197; … ; -0.000391039 -0.000438617 … -0.000520651 -0.000449867; 0.00106235 0.000365587 … 0.000813131 0.000495484], Float32[7.20343f-8 4.46976f-8 … 6.84867f-8 4.63952f-8; 1.79691f-8 2.02649f-8 … 2.45046f-8 1.96227f-8; … ; 1.21215f-8 1.17657f-8 … 1.50136f-8 1.15681f-8; 1.12738f-7 7.45199f-9 … 4.8495f-8 1.44173f-8], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.00169459, 0.000977637, 0.00103866, -0.00234933, 0.000659175, 0.000868318, 0.00303222, 0.00271383, -0.00326585, -0.0014993 … -0.000480712, -0.000501535, -0.00174489, -0.00160158, -0.000470662, 0.00127967, 0.000618911, -0.00103705, -0.000773079, 0.00146704], Float32[1.74884f-7, 5.48983f-8, 7.75433f-8, 3.08981f-7, 2.45763f-8, 4.41623f-8, 5.29156f-7, 4.09021f-7, 6.07287f-7, 1.45678f-7 … 1.4164f-8, 1.73391f-8, 1.7507f-7, 1.44894f-7, 1.25673f-8, 1.1198f-7, 2.11545f-8, 6.25338f-8, 3.4755f-8, 1.78565f-7], (0.729, 0.997003)))), layer_2 = (layer_1 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.00443555 0.00163654 … -0.0124978 0.0123434], Float32[2.53181f-6 1.32838f-6 … 8.83289f-6 8.58873f-6], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0191175], Float32[2.08743f-5], (0.729, 0.997003)))), layer_2 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[-0.0172084; -0.0213176; … ; -0.00376332; -0.0116419;;], Float32[1.63537f-5; 2.51152f-5; … ; 8.16783f-7; 7.55419f-6;;], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 
0.999), epsilon=1.0e-8), (Float32[-0.0365001, -0.045083, -0.0507623, -0.0390298, -0.0242259, -0.0404982, -0.0358925, 0.0114351, -0.00803444, -0.0248332], Float32[7.40417f-5, 0.000112652, 0.000146818, 8.41229f-5, 4.60234f-5, 9.15105f-5, 7.13093f-5, 8.78741f-6, 3.62043f-6, 3.51285f-5], (0.729, 0.997003)))))), 2))
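The returned tuple is (gradients, loss, stats, train_state), and the last field printed above is the step counter of the TrainState. The sketch below repeats the combined step in a loop on the CPU, assuming Zygote is installed. It uses a fixed target batch so successive losses are comparable, which is a choice of this sketch; run_steps and nsteps are hypothetical names.

```julia
using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
ps, st = Lux.setup(rng, model)   # kept on the CPU in this sketch

# Fixed dummy batch: 2 samples with 128 inputs and 10 targets each.
x = rand(rng, Float32, 128, 2)
y = rand(rng, Float32, 10, 2)

# Hypothetical helper: run a few combined gradient + update steps
# and record the loss reported by each one.
function run_steps(model, ps, st, x, y; nsteps=5)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
    losses = Float32[]
    for _ in 1:nsteps
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoZygote(), MSELoss(), (x, y), train_state
        )
        push!(losses, loss)
    end
    return losses, train_state
end

losses, train_state = run_steps(model, ps, st, x, y)
println(losses)
println(train_state.step)   # number of optimizer updates applied so far
```

Keeping the loop inside a function avoids global-scope reassignment of train_state and mirrors how the state is threaded from one step to the next.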
Defining Custom Layers
We can train our model using the above code, but let's go ahead and see how to use Reactant. Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after running fancy optimizations). It is the currently recommended way to train large models in Lux. For more details on using Reactant, see the manual.
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing
dev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)
We will define a custom MLP using the @compact macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network.
n_in = 1
n_out = 1
nlayers = 3
model = @compact(
    w1=Dense(n_in => 32),
    w2=[Dense(32 => 32) for i in 1:nlayers],
    w3=Dense(32 => n_out),
    act=relu
) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    @return out
end
@compact(
    w1 = Dense(1 => 32),                # 64 parameters
    w2 = NamedTuple(
        1 = Dense(32 => 32),            # 1_056 parameters
        2 = Dense(32 => 32),            # 1_056 parameters
        3 = Dense(32 => 32),            # 1_056 parameters
    ),
    w3 = Dense(32 => 1),                # 33 parameters
    act = relu,
) do x
    embed = act(w1(x))
    for w = w2
        embed = act(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 3_265 parameters,
          #        plus 1 states.
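A model built with @compact is an ordinary Lux layer, so it can also be set up and called without Reactant. The following CPU-only sketch checks the output shape and the parameter count. It broadcasts the activation explicitly with act.(…), a choice of this sketch that does not depend on how the activation handles array inputs; expected is an illustrative name.

```julia
using Lux, Random

n_in, n_out, nlayers = 1, 1, 3

model = @compact(
    w1=Dense(n_in => 32),
    w2=[Dense(32 => 32) for i in 1:nlayers],
    w3=Dense(32 => n_out),
    act=relu
) do x
    embed = act.(w1(x))          # explicit broadcast of the activation
    for w in w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    @return out
end

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)   # parameters and states live outside the model

x = rand(rng, Float32, n_in, 32) # batch of 32 one-dimensional inputs
y, st_new = model(x, ps, st)     # forward pass returns output and updated state

# Count parameters from the initialized NamedTuple and compare with the
# hand-computed sum of the dense layers.
n_params = Lux.parameterlength(ps)
expected = (n_in * 32 + 32) + nlayers * (32 * 32 + 32) + (32 * n_out + n_out)
println(size(y))
println(n_params)
```

The hand-computed value corresponds to the 64 + 3 × 1_056 + 33 breakdown shown in the printout above.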
We can initialize the model and train it with the same code as before!
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model) |> dev
x = rand(rng, Float32, n_in, 32) |> dev
@jit model(x, ps, st) # 1×32 Matrix and updated state as output.
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)
function train_model!(model, ps, st, x_data, y_data)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
    for iter in 1:1000
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoEnzyme(), MSELoss(),
            (x_data, y_data), train_state
        )
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end
    # Return the trained parameters and states held by the final TrainState
    return model, train_state.parameters, train_state.states
end
train_model!(model, ps, st, x_data, y_data)
Iteration: 0001 Loss: 2.08073235
Iteration: 0101 Loss: 0.142574623
Iteration: 0201 Loss: 0.0051055951
Iteration: 0301 Loss: 0.00118357129
Iteration: 0401 Loss: 0.000504208321
Iteration: 0501 Loss: 0.000281832268
Iteration: 0601 Loss: 0.000203011135
Iteration: 0701 Loss: 0.000126347542
Iteration: 0801 Loss: 0.00201115524
Iteration: 0901 Loss: 9.70276451e-05
Iteration: 1000 Loss: 7.81012277e-05
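To put the loss values in context, the training data samples the cubic y = 2x − x³ on [-2, 2], and the loss is a mean squared error. The plain-Julia sketch below rebuilds the data and computes the error of a trivial all-zeros prediction as a baseline. It assumes MSELoss uses its default mean aggregation; target, mse, and baseline are illustrative names, not Lux API.

```julia
using Statistics

# Target function sampled by the training data above.
target(x) = 2x - x^3

x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)  # 1 × N row of inputs
y_data = target.(x_data)

# Mean squared error between a prediction and the targets.
mse(ŷ, y) = mean(abs2, ŷ .- y)

# Error of predicting zero everywhere, as a reference point for the
# losses printed during training.
baseline = mse(zero(y_data), y_data)
println(size(x_data))
println(baseline)
```

A trained model should end up well below this baseline, which gives a rough scale for reading the iteration log.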
Training with Optimization.jl
If you are coming from the SciML ecosystem and want to use Optimization.jl, please refer to the Optimization.jl Tutorial.
Additional Packages
LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.
You can install all those packages via import Pkg; Pkg.add(<package name>).
XLA (CPU/GPU/TPU) Support
Lux.jl supports XLA compilation for CPU, GPU, and TPU using Reactant.jl.
GPU Support
GPU Support for Lux.jl requires loading additional packages:
LuxCUDA.jl for CUDA support.
AMDGPU.jl for AMDGPU support.
Metal.jl for Apple Metal support.
oneAPI.jl for oneAPI support.
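Once one of these packages is loaded, data and parameters are moved with device objects, as in the quickstart. A minimal sketch of the round trip, assuming that gpu_device() falls back to the CPU (with a warning) when no functional GPU backend is loaded; x_dev and x_back are illustrative names:

```julia
using Lux
# using LuxCUDA   # optional: load a backend package to get a real GPU device

# Select a device: a GPU if a functional backend is loaded, otherwise
# (by assumption) a CPU fallback.
dev = gpu_device()
cdev = cpu_device()

x = rand(Float32, 4, 2)
x_dev = x |> dev        # move to the selected device
x_back = x_dev |> cdev  # bring the data back to host memory

println(typeof(x_back))
```

Writing code against dev and cdev rather than a specific backend keeps the same script usable on machines with and without a GPU.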