Skip to content

Getting Started

Installation

Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. You can enter the package manager by pressing `]` in the REPL and then typing `add Lux`. Alternatively, you can run:

julia
# Install Lux.jl through the package manager (equivalent to typing `add Lux` in Pkg mode).
import Pkg
Pkg.add("Lux")

Update to v1

If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.

Quickstart

Pre-Requisites

If they are not already installed, install Optimisers and Zygote first: `Pkg.add(["Optimisers", "Zygote"])`.

julia
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

We take randomness very seriously

julia
# Seeding
# Take the default (task-local) RNG and seed it so that every later `rand(rng, ...)`
# call and `Lux.setup(rng, ...)` initialization is reproducible.
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()

Build the model

julia
# Construct the layer
# A Chain of Dense(128 => 256, tanh) followed by a nested Chain of
# Dense(256 => 1, tanh) and Dense(1 => 10). The model object itself holds no
# parameters; those are created separately with `Lux.setup`.
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
Chain(
    layer_1 = Dense(128 => 256, tanh),  # 33_024 parameters
    layer_2 = Chain(
        layer_1 = Dense(256 => 1, tanh),  # 257 parameters
        layer_2 = Dense(1 => 10),       # 20 parameters
    ),
)         # Total: 33_301 parameters,
          #        plus 0 states.

Models don't hold parameters and states, so we need to initialize them explicitly. From there on, we could use the standard AD and Optimisers APIs directly; here, however, we show how to use Lux's Training API, which provides a uniform API over all supported AD systems.

julia
# Get the device determined by Lux
# NOTE(review): presumably a GPU device when one of the optional GPU packages is
# loaded and functional, otherwise a CPU fallback — confirm against the Lux docs.
dev = gpu_device()

# Parameter and State Variables
# `Lux.setup` initializes both from `rng`; piping through `dev` moves them to the device.
ps, st = Lux.setup(rng, model) |> dev

# Dummy Input
# 128 features (the first Dense layer's input size) × a batch of 2 samples.
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
# `Lux.apply` returns the model output together with the (possibly updated) state.
y, st = Lux.apply(model, x, ps, st)

# Gradients
## First construct a TrainState
## It bundles the model, its parameters and states, and an Adam optimizer (lr 0.0001).
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
## The data argument is an `(input, target)` tuple; the random 10×2 target matches the
## model's 10-dimensional output for a batch of 2. Zygote is the AD backend here.
gs, loss, stats, train_state = Lux.Training.compute_gradients(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)

## Optimization
## Apply the gradients `gs` to the parameters held by `train_state`.
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
# `single_train_step!` returns the same `(gs, loss, stats, train_state)` tuple as
# `compute_gradients`, with the gradients already applied.
gs, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)
((layer_1 = (weight = Float32[0.0017983615 0.006062332 … 0.0053392933 0.0056276177; 0.0011292367 0.0041270256 … 0.003585879 0.0038155357; … ; -0.0008762945 -0.0031371699 … -0.0027350332 -0.0029033197; 0.0011154839 0.002197485 … 0.0021741025 0.0021157824], bias = Float32[0.006656272, 0.004425203, 0.0028994146, -0.0116051175, 0.0031301186, 0.0037318026, 0.0136483535, 0.013969757, -0.015173428, -0.005173992  …  -0.0018621369, -0.0015270555, -0.007873881, -0.0076395273, -0.0022123815, 0.0039605754, 0.0034407252, -0.0045406874, -0.003383829, 0.0029306945]), layer_2 = (layer_1 = (weight = Float32[0.04993449 0.03202845 … -0.059382 0.07701616], bias = Float32[0.08797912]), layer_2 = (weight = Float32[-0.094527975; -0.11476975; … ; -0.016841749; -0.0698748;;], bias = Float32[-0.21608135, -0.26255828, -0.23534852, -0.21524015, -0.055711076, -0.20314303, -0.1895644, 0.03666526, -0.03937737, -0.15905891]))), 0.8455785f0, NamedTuple(), Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, 
bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}(nothing, nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}((layer_1 = Dense(128 => 256, tanh), layer_2 = Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(256 => 1, tanh), layer_2 = Dense(1 => 10)), nothing)), nothing), (layer_1 = (weight = Float32[-0.22542597 0.22379348 … 0.1997513 -0.018708104; -0.023026714 0.15451026 … -0.065325744 0.18120264; … ; 0.038037397 -0.07125516 … -0.03306083 0.039138064; -0.18810266 -0.09693537 … -0.18102062 0.019230088], bias = Float32[0.030937059, -0.060276944, 0.084569596, 0.00040024254, -0.065509446, -0.08527214, -0.026523968, 0.06347208, 0.042247728, 0.027705256  …  -0.06052852, 0.03504307, -0.028244259, 0.06788022, 0.0027464977, -0.06942153, 0.0064240773, 0.0141069945, -0.029283267, 0.01174226]), layer_2 = (layer_1 = (weight = Float32[0.12008221 0.06026435 … -0.070576 0.1577647], bias = Float32[0.026844418]), layer_2 = (weight = Float32[0.5345728; -0.28288874; … ; -0.32983455; -0.45298168;;], bias = Float32[-0.59751064, -0.7033041, -0.8457602, -0.53789175, -0.31473723, 0.17461234, -0.82945836, 0.67841595, 0.35837248, -0.14941788]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), Adam(0.0001, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), 
(Float32[0.000926728 0.000860063 … 0.00110328 0.000908301; 0.000480834 0.000574605 … 0.000665883 0.000584197; … ; -0.000391039 -0.000438617 … -0.000520651 -0.000449867; 0.00106235 0.000365587 … 0.000813131 0.000495484], Float32[7.20343f-8 4.46976f-8 … 6.84867f-8 4.63952f-8; 1.79691f-8 2.02649f-8 … 2.45046f-8 1.96227f-8; … ; 1.21215f-8 1.17657f-8 … 1.50136f-8 1.15681f-8; 1.12738f-7 7.45199f-9 … 4.8495f-8 1.44173f-8], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00169459, 0.000977637, 0.00103866, -0.00234933, 0.000659175, 0.000868318, 0.00303222, 0.00271383, -0.00326585, -0.0014993  …  -0.000480712, -0.000501535, -0.00174489, -0.00160158, -0.000470662, 0.00127967, 0.000618911, -0.00103705, -0.000773079, 0.00146704], Float32[1.74884f-7, 5.48983f-8, 7.75433f-8, 3.08981f-7, 2.45763f-8, 4.41623f-8, 5.29156f-7, 4.09021f-7, 6.07287f-7, 1.45678f-7  …  1.4164f-8, 1.73391f-8, 1.7507f-7, 1.44894f-7, 1.25673f-8, 1.1198f-7, 2.11545f-8, 6.25338f-8, 3.4755f-8, 1.78565f-7], (0.729, 0.997003)))), layer_2 = (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00443555 0.00163654 … -0.0124978 0.0123434], Float32[2.53181f-6 1.32838f-6 … 8.83289f-6 8.58873f-6], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.0191175], Float32[2.08743f-5], (0.729, 0.997003)))), layer_2 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0172084; -0.0213176; … ; -0.00376332; -0.0116419;;], Float32[1.63537f-5; 2.51152f-5; … ; 8.16783f-7; 7.55419f-6;;], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0365001, -0.045083, -0.0507623, -0.0390298, -0.0242259, -0.0404982, -0.0358925, 0.0114351, -0.00803444, -0.0248332], Float32[7.40417f-5, 0.000112652, 0.000146818, 8.41229f-5, 4.60234f-5, 9.15105f-5, 7.13093f-5, 8.78741f-6, 3.62043f-6, 3.51285f-5], (0.729, 0.997003)))))), 2))

Defining Custom Layers

We can train our model using the above code, but let's go ahead and see how to use Reactant. Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after applying further optimizations). It is currently the recommended way to train large models in Lux. For more details on using Reactant, see the manual.

julia
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing

# Select the Reactant (XLA-compiled) device; data piped through `dev` is moved onto it.
dev = reactant_device()
(::ReactantDevice{Missing, Missing}) (generic function with 1 method)

We will define a custom MLP using the `@compact` macro. The macro takes a list of parameters, layers, and states, and a function defining the forward pass of the neural network.

julia
# Input feature size, output size, and number of hidden Dense(32 => 32) layers.
n_in = 1
n_out = 1
nlayers = 3

# The keyword arguments declare what the model holds: an input layer `w1`, a collection
# of hidden layers `w2`, an output layer `w3`, and the activation `act`. The `do` block
# is the forward pass: the hidden layers are applied in order with `act` after each,
# and `@return` yields the model output.
model = @compact(
    w1=Dense(n_in => 32),
    w2=[Dense(32 => 32) for i in 1:nlayers],
    w3=Dense(32 => n_out),
    act=relu
) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    @return out
end
@compact(
    w1 = Dense(1 => 32),                # 64 parameters
    w2 = NamedTuple(
        1 = Dense(32 => 32),            # 1_056 parameters
        2 = Dense(32 => 32),            # 1_056 parameters
        3 = Dense(32 => 32),            # 1_056 parameters
    ),
    w3 = Dense(32 => 1),                # 33 parameters
    act = relu,
) do x 
    embed = act(w1(x))
    for w = w2
        embed = act(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 3_265 parameters,
          #        plus 1 states.

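We can initialize the model and train it with essentially the same code as before! The main differences are the Reactant device, `@jit` for running the forward pass, and `AutoEnzyme()` instead of `AutoZygote()` as the AD backend.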

julia
# Re-seed the RNG so the initialization and dummy input below are reproducible.
rng = Random.default_rng()
Random.seed!(rng, 0)

# Initialize parameters and states, then move them to the Reactant device.
ps, st = Lux.setup(rng, model) |> dev

# A random batch of 32 samples with `n_in` features each.
x = rand(rng, Float32, n_in, 32) |> dev

# `@jit` compiles the forward pass with Reactant and runs it.
@jit model(x, ps, st)  # 1×32 Matrix and updated state as output.

# Toy regression data: y = 2x - x^3 on a grid from -2 to 2 with step 0.1,
# reshaped into a single-row matrix (1 feature × N samples).
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)

"""
    train_model!(model, ps, st, x_data, y_data)

Train `model` for 1000 Adam steps (learning rate `0.001f0`) on the mean squared error
between the model's prediction for `x_data` and `y_data`, using Enzyme for the gradients.
The loss is printed at iteration 1, then every 100 iterations (101, 201, …), and at the
final iteration.

Returns `(model, ps, st)`, where `ps` and `st` are the trained parameters and states
taken from the final `TrainState`.
"""
function train_model!(model, ps, st, x_data, y_data)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

    for iter in 1:1000
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoEnzyme(), MSELoss(),
            (x_data, y_data), train_state
        )
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end

    # Return what the final TrainState holds, not the arguments. `single_train_step!`
    # hands back a new `train_state` every iteration and nothing here rebinds `ps`/`st`,
    # so returning them would give callers the pre-training values (at least for states).
    return model, train_state.parameters, train_state.states
end

# Train the model on the toy data; the loss is printed periodically during training.
train_model!(model, ps, st, x_data, y_data)
2025-01-08 22:23:45.480635: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 17191936659798093842
Iteration: 0001 	 Loss: 2.08073235
Iteration: 0101 	 Loss: 0.142574623
Iteration: 0201 	 Loss: 0.0051055951
Iteration: 0301 	 Loss: 0.00118357129
Iteration: 0401 	 Loss: 0.000504208321
Iteration: 0501 	 Loss: 0.000281832268
Iteration: 0601 	 Loss: 0.000203011135
Iteration: 0701 	 Loss: 0.000126347542
Iteration: 0801 	 Loss: 0.00201115524
Iteration: 0901 	 Loss: 9.70276451e-05
Iteration: 1000 	 Loss: 7.81012277e-05

Training with Optimization.jl

If you are coming from the SciML ecosystem and want to use Optimization.jl, please refer to the Optimization.jl Tutorial.

Additional Packages

LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.

You can install any of these packages via `import Pkg; Pkg.add("<package name>")`.

XLA (CPU/GPU/TPU) Support

Lux.jl supports XLA compilation for CPU, GPU, and TPU using Reactant.jl.

GPU Support

GPU Support for Lux.jl requires loading additional packages: