Getting Started
Installation
Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. You can enter the package manager by pressing `]` in the REPL and then typing `add Lux`. Alternatively, you can also do:
# Install Lux.jl programmatically (equivalent to typing `add Lux` in Pkg mode)
import Pkg
Pkg.add(; name="Lux")
Update to v1
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
Quickstart
Pre-Requisites
You need to install `Optimisers` and `Zygote` if you have not done so already, e.g. with `Pkg.add(["Optimisers", "Zygote"])`.
# Optimisers provides the `Adam` optimiser and Zygote the AD backend (`AutoZygote()`) used below
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other optional packages for GPU support
We take randomness very seriously, so we seed the random number generator explicitly:
# Seeding: `Random.seed!` returns the RNG it seeded, so bind it directly
rng = Random.seed!(Random.default_rng(), 0)
Random.TaskLocalRNG()
Build the model
# Construct the layer: a Dense layer followed by a nested two-layer Chain
model = Chain(
    Dense(128 => 256, tanh),
    Chain(Dense(256 => 1, tanh), Dense(1 => 10)),
)
Chain(
layer_1 = Dense(128 => 256, tanh), # 33_024 parameters
layer_2 = Chain(
layer_1 = Dense(256 => 1, tanh), # 257 parameters
layer_2 = Dense(1 => 10), # 20 parameters
),
) # Total: 33_301 parameters,
# plus 0 states.
Models don't hold parameters and states, so we need to initialize them. From there on, we can use the standard AD and Optimisers APIs. Here, however, we show how to use Lux's Training API, which provides a uniform API over all supported AD systems.
# Get the device determined by Lux
dev = gpu_device()

# Parameter and State Variables: create them with `Lux.setup`, then move them to `dev`
ps, st = dev(Lux.setup(rng, model))

# Dummy Input: 128 features × a batch of 2, on the same device as the parameters
x = dev(rand(rng, Float32, 128, 2))

# Run the model; it returns the output together with the updated state
y, st = Lux.apply(model, x, ps, st)

# Gradients
## A TrainState bundles the model, parameters, states and the optimiser (Adam here)
train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))

## `compute_gradients` evaluates the MSE loss against a random 10×2 target and
## differentiates it with the Zygote backend
gs, loss, stats, train_state = Training.compute_gradients(
    AutoZygote(), MSELoss(), (x, dev(rand(rng, Float32, 10, 2))), train_state)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), MSELoss(), (x, dev(rand(rng, Float32, 10, 2))), train_state)
((layer_1 = (weight = Float32[0.00414443 0.002795561 … 0.004169056 0.0031376465; 0.0021067895 0.0018510974 … 0.0024323424 0.0019732846; … ; -0.0017369931 -0.0014175161 … -0.001926295 -0.001531324; 0.005014103 0.0012843239 … 0.0035166647 0.0019504727], bias = Float32[0.006772918, 0.0037725866, 0.004418927, -0.008870332, 0.002497137, 0.003410134, 0.011739187, 0.010097131, -0.01252219, -0.006184276 … -0.0019153776, -0.0021203351, -0.006736364, -0.00610999, -0.0017900262, 0.0053856988, 0.0022308973, -0.004047676, -0.0030270987, 0.0065859808]), layer_2 = (layer_1 = (weight = Float32[0.00957886 3.79486f-5 … -0.04735685 0.04226011], bias = Float32[0.07344248]), layer_2 = (weight = Float32[-0.098861896; -0.097854555; … ; -0.022434529; -0.07412191;;], bias = Float32[-0.22638096, -0.22459084, -0.2545935, -0.24099024, -0.07057327, -0.17239988, -0.17449908, 0.077043094, -0.05008901, -0.16909766]))), 0.87143785f0, NamedTuple(), Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, layer_2::@NamedTuple{weight::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, 
Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}}}(nothing, nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}((layer_1 = Dense(128 => 256, tanh), layer_2 = Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(256 => 1, tanh), layer_2 = Dense(1 => 10)), nothing)), nothing), (layer_1 = (weight = Float32[-0.22543387 0.2237958 … 0.19975941 -0.018701272; -0.023031702 0.15451166 … -0.065317236 0.181208; … ; 0.03804328 -0.07125676 … -0.03306928 0.039132368; -0.18811704 -0.09692798 … -0.18102339 0.019236464], bias = Float32[0.03094203, -0.06026962, 0.08456781, 0.00039204443, -0.06550143, -0.08526564, -0.026516732, 0.063480526, 0.04224006, 0.027702922 … -0.06053336, 0.0350433, -0.028251555, 0.067872316, 0.0027386106, -0.06942138, 0.006432486, 0.014100174, -0.029290024, 0.011734666]), layer_2 = (layer_1 = (weight = Float32[0.11984776 0.060364496 … -0.07058401 
0.1577715], bias = Float32[0.026851978]), layer_2 = (weight = Float32[0.5345716; -0.28289196; … ; -0.32984197; -0.45298263;;], bias = Float32[-0.5975106, -0.70330536, -0.84576, -0.5378918, -0.3147238, 0.1746128, -0.82945824, 0.6784163, 0.35836592, -0.14941669]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), Adam(0.0001, (0.9, 0.999), 1.0e-8), (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00139126 0.00077649 … 0.00128162 0.000910803; 0.000699265 0.000510198 … 0.000731462 0.000563284; … ; -0.000574894 -0.000391076 … -0.000580705 -0.000438133; 0.00170922 0.000374875 … 0.00115296 0.000609521], Float32[1.34855f-7 3.8271f-8 … 1.09599f-7 5.38068f-8; 3.38798f-8 1.64605f-8 … 3.53144f-8 2.04108f-8; … ; 2.28682f-8 9.67595f-9 … 2.22846f-8 1.23626f-8; 2.05058f-7 9.13988f-9 … 9.15542f-8 2.49913f-8], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00214947, 0.00117228, 0.00145586, -0.00270753, 0.00077338, 0.00106574, 0.00364828, 0.00306169, -0.0038745, -0.00198768 … -0.000612726, -0.000689046, -0.00209504, -0.00187748, -0.000554918, 0.00175152, 0.000667028, -0.00126273, -0.000942026, 0.00219681], Float32[3.13168f-7, 9.21857f-8, 1.46327f-7, 4.87431f-7, 4.00564f-8, 7.64058f-8, 8.929f-7, 6.21252f-7, 1.00488f-6, 2.69472f-7 … 2.55476f-8, 3.25587f-8, 2.94557f-7, 2.35152f-7, 2.06325f-8, 2.10458f-7, 2.92832f-8, 1.07167f-7, 5.9572f-8, 3.35188f-7], (0.729, 0.997003)))), layer_2 = (layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.0018058 -0.000936726 … -0.0146587 0.0123139], Float32[1.80424f-7 1.09098f-7 … 1.43866f-5 9.85344f-6], (0.729, 0.997003))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.0227737], Float32[3.47551f-5], (0.729, 0.997003)))), layer_2 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0207926; -0.0231775; … ; -0.00706411; -0.0158726;;], Float32[2.44439f-5; 3.16946f-5; … ; 3.36937f-6; 1.4322f-5;;], (0.729, 0.997003))), 
bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0435721, -0.0480501, -0.0540537, -0.0420834, -0.0105465, -0.0302464, -0.0315817, 0.0244214, -0.0142814, -0.0330944], Float32[0.000105296, 0.00013121, 0.000165657, 9.79656f-5, 6.48203f-6, 5.05851f-5, 5.50798f-5, 4.04019f-5, 1.31128f-5, 6.08993f-5], (0.729, 0.997003)))))), 2))
Defining Custom Layers
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other optional packages for GPU support
using Printf # For pretty printing
# Device selected by Lux; used below to move parameters and data
dev = gpu_device()
(::CUDADevice{Nothing}) (generic function with 4 methods)
We will define a custom MLP using the `@compact` macro. The macro takes a list of parameters, layers, and states, and a function defining the forward pass of the neural network.
# MLP sizes: input width, output width, and number of hidden 32 => 32 Dense layers
n_in, n_out, nlayers = 1, 1, 3
model = @compact(
    w1 = Dense(n_in => 32),                      # input projection
    w2 = [Dense(32 => 32) for _ in 1:nlayers],   # hidden layers
    w3 = Dense(32 => n_out),                     # output projection
    act = relu
) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    @return out
end
@compact(
w1 = Dense(1 => 32), # 64 parameters
w2 = NamedTuple(
1 = Dense(32 => 32), # 1_056 parameters
2 = Dense(32 => 32), # 1_056 parameters
3 = Dense(32 => 32), # 1_056 parameters
),
w3 = Dense(32 => 1), # 33 parameters
act = relu,
) do x
embed = act(w1(x))
for w = w2
embed = act(w(embed))
end
out = w3(embed)
return out
end # Total: 3_265 parameters,
# plus 1 states.
We can initialize the model and train it with the same code as before!
# `Random.seed!` returns the RNG it seeded
rng = Random.seed!(Random.default_rng(), 0)
ps, st = dev(Lux.setup(Xoshiro(0), model))

x = dev(rand(rng, Float32, n_in, 32))
model(x, ps, st) # 1×32 Matrix and updated state as output.

# Training data: y = 2x - x^3 sampled on -2:0.1:2, laid out as a 1×N row
x_data = dev(reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :))
y_data = @. 2 * x_data - x_data^3
"""
    train_model!(model, ps, st, x_data, y_data)

Train `model` for 1000 iterations with `Adam(0.001f0)`, minimising the mean-squared
error between the model output on `x_data` and `y_data` (gradients via `AutoZygote()`).
The loss is printed on iterations 1, 101, 201, ..., 901 and on the final iteration.

Returns `(model, ps, st)` taken from the final `TrainState`, i.e. the trained
parameters and the updated states (not the `ps`/`st` objects that were passed in).
"""
function train_model!(model, ps, st, x_data, y_data)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
    for iter in 1:1000
        # Each step returns a new TrainState, which carries the updated parameters/states.
        _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
            (x_data, y_data), train_state)
        # Log the first iteration, every 100th after it, and the last one.
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end
    # Return the trained state instead of the caller's untouched `ps`/`st`.
    return train_state.model, train_state.parameters, train_state.states
end
# Run the training loop; it prints the loss periodically. The returned (model, ps, st) tuple is not captured here.
train_model!(model, ps, st, x_data, y_data)
Iteration: 0001 Loss: 2.08085155
Iteration: 0101 Loss: 0.131583616
Iteration: 0201 Loss: 0.00390525744
Iteration: 0301 Loss: 0.000871082244
Iteration: 0401 Loss: 0.000430405606
Iteration: 0501 Loss: 0.00144559005
Iteration: 0601 Loss: 0.00188280677
Iteration: 0701 Loss: 0.000128166575
Iteration: 0801 Loss: 0.000106401763
Iteration: 0901 Loss: 9.37113291e-05
Iteration: 1000 Loss: 7.96537279e-05
Training with Optimization.jl
If you are coming from the SciML ecosystem and want to use Optimization.jl, please refer to the Optimization.jl Tutorial.
Additional Packages
LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.
You can install any of these packages via `import Pkg; Pkg.add(<package name>)`.
GPU Support
GPU support for Lux.jl requires loading additional packages:
- LuxCUDA.jl for CUDA support.
- AMDGPU.jl for AMDGPU support.
- Metal.jl for Apple Metal support.
- oneAPI.jl for oneAPI support.