Getting Started
Installation
Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. You can enter it by pressing ] in the REPL and then typing add Lux. Alternatively, you can also do
import Pkg
Pkg.add("Lux")
Update to v1
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
Quickstart
Pre-Requisites
You need to install Optimisers and Zygote if not done already. Pkg.add(["Optimisers", "Zygote"])
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
We take randomness very seriously
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
Build the model
# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
Chain(
layer_1 = Dense(128 => 256, tanh), # 33_024 parameters
layer_2 = Chain(
layer_1 = Dense(256 => 1, tanh), # 257 parameters
layer_2 = Dense(1 => 10), # 20 parameters
),
) # Total: 33_301 parameters,
# plus 0 states.
Models don't hold parameters and states, so initialize them. From there on, we just use our standard AD and Optimisers API.
# Get the device determined by Lux
dev = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> dev
# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
## Pullback API to capture change in state
(l, st_), pb = pullback(Lux.apply, model, x, ps, st)
gs = pb((one.(l), nothing))[3]
# Optimization
st_opt = Optimisers.setup(Adam(0.0001f0), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs) # or Optimisers.update!(st_opt, ps, gs)((layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00244247 0.00182434 … 0.0025857 0.00200468; 0.00124785 0.00121261 … 0.00152527 0.001271; … ; -0.00101863 -0.000926631 … -0.00119906 -0.000981905; 0.00289425 0.000812688 … 0.00208184 0.00118862], Float32[5.96558f-7 3.32818f-7 … 6.68577f-7 4.01871f-7; 1.55711f-7 1.47039f-7 … 2.32642f-7 1.61543f-7; … ; 1.03759f-7 8.58633f-8 … 1.43772f-7 9.64124f-8; 8.37656f-7 6.60454f-8 … 4.33399f-7 1.41281f-7], (0.81, 0.998001))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00412708, 0.00232358, 0.00265797, -0.00548918, 0.00155959, 0.00208427, 0.00722279, 0.00629379, -0.00774017, -0.00372379 … -0.00118024, -0.0012661, -0.00415068, -0.00376983, -0.00111527, 0.00323512, 0.00140803, -0.00248214, -0.00184828, 0.00385626], Float32[1.70326f-6, 5.39895f-7, 7.06473f-7, 3.01307f-6, 2.43228f-7, 4.34414f-7, 5.2168f-6, 3.96113f-6, 5.99094f-6, 1.38664f-6 … 1.39294f-7, 1.60299f-7, 1.7228f-6, 1.42114f-6, 1.24381f-7, 1.04659f-6, 1.98252f-7, 6.16092f-7, 3.41611f-7, 1.48706f-6], (0.81, 0.998001)))), layer_2 = (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00768371 0.00155563 … -0.0295524 0.0273372], Float32[5.90385f-6 2.41994f-7 … 8.73331f-5 7.47312f-5], (0.81, 0.998001))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.0453981], Float32[0.000206096], (0.81, 0.998001)))), layer_2 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.104277; 0.104277; … ; 0.104277; 0.104277;;], Float32[0.00108735; 0.00108735; … ; 0.00108735; 0.00108735;;], (0.81, 0.998001))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], Float32[0.00399995, 0.00399995, 0.00399995, 0.00399995, 0.00399995, 0.00399995, 0.00399995, 0.00399995, 0.00399995, 0.00399995], (0.81, 0.998001)))))), (layer_1 = (weight = Float32[-0.22534472 0.2238892 … 0.1998505 
-0.018608876; -0.022942306 0.15460524 … -0.065225646 0.18130077; … ; 0.037953824 -0.07135031 … -0.033160813 0.039039645; -0.18802822 -0.09683571 … -0.18093373 0.019327192], bias = Float32[0.031032413, -0.060178764, 0.08465736, 0.00030078756, -0.0654105, -0.08517491, -0.02642588, 0.06357193, 0.04214911, 0.02761282 … -0.060623564, 0.034953445, -0.02834239, 0.06778121, 0.002647703, -0.06933154, 0.00652421, 0.014009408, -0.029380847, 0.011823955]), layer_2 = (layer_1 = (weight = Float32[0.1199478 0.060097758 … -0.07067495 0.15786381], bias = Float32[0.02694288]), layer_2 = (weight = Float32[0.5342726; -0.28318882; … ; -0.3301325; -0.4532813;;], bias = Float32[-0.5978105, -0.7036041, -0.84605885, -0.5381919, -0.31502125, 0.17431273, -0.8297584, 0.67850673, 0.35807315, -0.14971648]))))
Defining Custom Layers
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
using Printf # For pretty printing
We will define a custom MLP using the @compact macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network.
n_in = 1
n_out = 1
nlayers = 3
model = @compact(w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i in 1:nlayers],
w3=Dense(128, n_out),
act=relu) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
end
out = w3(embed)
@return out
end
@compact(
w1 = Dense(1 => 128), # 256 parameters
w2 = NamedTuple(
1 = Dense(128 => 128), # 16_512 parameters
2 = Dense(128 => 128), # 16_512 parameters
3 = Dense(128 => 128), # 16_512 parameters
),
w3 = Dense(128 => 1), # 129 parameters
act = relu,
) do x
embed = act(w1(x))
for w = w2
embed = act(w(embed))
end
out = w3(embed)
return out
end # Total: 49_921 parameters,
# plus 1 states.
We can initialize the model and train it with the same code as before!
ps, st = Lux.setup(Xoshiro(0), model)
model(randn(n_in, 32), ps, st) # 1×32 Matrix as output.
x_data = collect(-2.0f0:0.1f0:2.0f0)'
y_data = 2 .* x_data .- x_data .^ 3
st_opt = Optimisers.setup(Adam(), ps)
for epoch in 1:1000
global st # Put this in a function in real use-cases
(loss, st), pb = Zygote.pullback(ps) do p
y, st_ = model(x_data, p, st)
return sum(abs2, y .- y_data), st_
end
gs = only(pb((one(loss), nothing)))
epoch % 100 == 1 && @printf "Epoch: %04d \t Loss: %10.9g\n" epoch loss
Optimisers.update!(st_opt, ps, gs)
end
Epoch: 0001 	 Loss: 82.5918655
Epoch: 0101 Loss: 0.0203782618
Epoch: 0201 Loss: 0.141528264
Epoch: 0301 Loss: 0.0223009642
Epoch: 0401 Loss: 0.00288237352
Epoch: 0501 Loss: 0.00413147826
Epoch: 0601 Loss: 0.0117701376
Epoch: 0701 Loss: 0.00478190044
Epoch: 0801 Loss: 0.00680404948
Epoch: 0901 	 Loss: 0.00144632021
Additional Packages
LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.
You can install any of those packages via import Pkg; Pkg.add("<package name>").
GPU Support
GPU Support for Lux.jl requires loading additional packages:
LuxCUDA.jl for CUDA support.
AMDGPU.jl for AMDGPU support.
Metal.jl for Apple Metal support.
oneAPI.jl for oneAPI support.