Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
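For example, you can add it from the Julia REPL (a one-time setup step per project):
using Pkg
Pkg.add("Zygote")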
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Precompiling JLD2...
4297.8 ms ✓ FileIO
32260.2 ms ✓ JLD2
2 dependencies successfully precompiled in 37 seconds. 30 already precompiled.
Precompiling MLUtils...
445.6 ms ✓ DelimitedFiles
841.4 ms ✓ BangBang
589.9 ms ✓ ContextVariablesX
1193.0 ms ✓ SimpleTraits
787.6 ms ✓ InverseFunctions → InverseFunctionsTestExt
663.9 ms ✓ Accessors → TestExt
682.7 ms ✓ BangBang → BangBangStaticArraysExt
506.9 ms ✓ BangBang → BangBangChainRulesCoreExt
1398.3 ms ✓ SplittablesBase
509.5 ms ✓ BangBang → BangBangTablesExt
599.8 ms ✓ FLoopsBase
1052.9 ms ✓ MicroCollections
1071.6 ms ✓ MLCore
2708.3 ms ✓ Transducers
674.2 ms ✓ Transducers → TransducersAdaptExt
5380.0 ms ✓ FLoops
5958.9 ms ✓ MLUtils
17 dependencies successfully precompiled in 18 seconds. 83 already precompiled.
Precompiling MLDataDevicesSparseArraysExt...
718.1 ms ✓ MLDataDevices → MLDataDevicesSparseArraysExt
1 dependency successfully precompiled in 1 seconds. 18 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
1720.8 ms ✓ MLDataDevices → MLDataDevicesMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 104 already precompiled.
Precompiling LuxMLUtilsExt...
2349.8 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 3 seconds. 170 already precompiled.
Precompiling Reactant...
599.6 ms ✓ ReactantCore
639.7 ms ✓ LLVMOpenMP_jll
937.7 ms ✓ CUDA_Driver_jll
2249.3 ms ✓ Reactant_jll
219755.0 ms ✓ Enzyme
5864.0 ms ✓ Enzyme → EnzymeGPUArraysCoreExt
77146.1 ms ✓ Reactant
7 dependencies successfully precompiled in 303 seconds. 57 already precompiled.
Precompiling LuxLibEnzymeExt...
6077.1 ms ✓ Enzyme → EnzymeSpecialFunctionsExt
8324.4 ms ✓ Enzyme → EnzymeStaticArraysExt
1315.3 ms ✓ LuxLib → LuxLibEnzymeExt
11395.0 ms ✓ Enzyme → EnzymeChainRulesCoreExt
6089.9 ms ✓ Enzyme → EnzymeLogExpFunctionsExt
5 dependencies successfully precompiled in 13 seconds. 128 already precompiled.
Precompiling LuxEnzymeExt...
6904.0 ms ✓ Lux → LuxEnzymeExt
1 dependency successfully precompiled in 7 seconds. 148 already precompiled.
Precompiling LuxCoreReactantExt...
12365.0 ms ✓ LuxCore → LuxCoreReactantExt
1 dependency successfully precompiled in 13 seconds. 69 already precompiled.
Precompiling MLDataDevicesReactantExt...
12808.9 ms ✓ MLDataDevices → MLDataDevicesReactantExt
1 dependency successfully precompiled in 13 seconds. 66 already precompiled.
Precompiling LuxLibReactantExt...
12681.4 ms ✓ Reactant → ReactantStatisticsExt
13285.7 ms ✓ Reactant → ReactantNNlibExt
13317.6 ms ✓ LuxLib → LuxLibReactantExt
12721.6 ms ✓ Reactant → ReactantKernelAbstractionsExt
12625.0 ms ✓ Reactant → ReactantArrayInterfaceExt
12737.0 ms ✓ Reactant → ReactantSpecialFunctionsExt
6 dependencies successfully precompiled in 26 seconds. 143 already precompiled.
Precompiling WeightInitializersReactantExt...
12780.0 ms ✓ WeightInitializers → WeightInitializersReactantExt
1 dependency successfully precompiled in 13 seconds. 80 already precompiled.
Precompiling LuxReactantExt...
10361.9 ms ✓ Lux → LuxReactantExt
1 dependency successfully precompiled in 11 seconds. 166 already precompiled.
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and from this data create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size = 1000, sequence_length = 50)
# Create the spirals
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [
reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
for d in data[1:(dataset_size ÷ 2)]
]
anticlockwise_spirals = [
reshape(
d[1][:, (sequence_length + 1):end], :, sequence_length, 1
)
for d in data[((dataset_size ÷ 2) + 1):end]
]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3))
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, shuffle = true)
# Create DataLoaders
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(
collect.((x_train, y_train)); batchsize = 128, shuffle = true, partial = false
),
# Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false, partial = false),
)
end
get_dataloaders (generic function with 1 method)
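As a quick sanity check (a hypothetical snippet, not part of the tutorial's training code), we can pull one batch and confirm the shapes match the 2 × seq_len × batch_size layout described above:
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)   # each batch is a (data, labels) tuple
size(x)                      # (2, 50, 128): features × sequence_length × batch_size
size(y)                      # (128,): one binary label per sequence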
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at the Container Layer section of the Lux documentation.
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
We won't define the model from scratch; instead, we will use the built-in Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
return SpiralClassifier(
LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
)
end
Main.var"##230".SpiralClassifier
We can use default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following; a sketch of that alternative is shown below. But let's still build it by hand for the sake of this tutorial.
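For reference, here is a rough sketch of that built-in alternative (the name SpiralClassifierBuiltin is ours, not part of the tutorial; with Recurrence's default settings only the last time step is returned, so the Dense output would be 1 × batch_size and downstream code would need a vec):
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
    )
end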
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple
) where {T}
# First we will have to run the sequence through the LSTM Cell
# The first call to LSTM Cell will create the initial hidden state
# See that the parameters and states are automatically populated into a field called
# `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
# and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
# Now that we have the hidden state and memory in `carry` we will pass the input and
# `carry` jointly
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
# After running through the sequence we will pass the output through the classifier
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
# Finally remember to create the updated state
st = merge(st, (classifier = st_classifier, lstm_cell = st_lstm))
return vec(y), st
end
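To see the plumbing in action, here is a quick forward-pass sketch on dummy data (not part of the training loop; the shapes mirror the dataloader output):
rng = Random.default_rng()
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 50, 4)  # 2 features × 50 time steps × 4 sequences
ŷ, st_ = model(x, ps, st)
size(ŷ)                           # (4,): one probability per sequence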
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
@return vec(classifier(y))
end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred = ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
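As a small worked example with made-up predictions, lossfn and accuracy behave as follows:
ŷ = [0.9f0, 0.4f0, 0.2f0, 0.7f0]
y = [1.0f0, 0.0f0, 0.0f0, 1.0f0]
lossfn(ŷ, y)    # mean binary cross-entropy, ≈ 0.3
accuracy(ŷ, y)  # 1.0: every prediction falls on the correct side of 0.5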
Training the Model
function main(model_type)
dev = reactant_device()
cdev = cpu_device()
# Get the dataloaders
train_loader, val_loader = get_dataloaders() |> dev
# Create the model
model = model_type(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
model_compiled = if dev isa ReactantDevice
@compile model(first(train_loader)[1], ps, Lux.testmode(st))
else
model
end
ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
for epoch in 1:25
# Train the model
total_loss = 0.0f0
total_samples = 0
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
ad, lossfn, (x, y), train_state
)
total_loss += loss * length(y)
total_samples += length(y)
end
@printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)
# Validate the model
total_acc = 0.0f0
total_loss = 0.0f0
total_samples = 0
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model_compiled(x, train_state.parameters, st_)
ŷ, y = cdev(ŷ), cdev(y)
total_acc += accuracy(ŷ, y) * length(y)
total_loss += lossfn(ŷ, y) * length(y)
total_samples += length(y)
end
@printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (total_acc / total_samples)
end
return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-03-07 23:46:55.628278: I external/xla/xla/service/service.cc:152] XLA service 0x9ae9d50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-07 23:46:55.628438: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741391215.629255 3638329 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741391215.629320 3638329 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741391215.629364 3638329 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741391215.644784 3638329 cuda_dnn.cc:529] Loaded cuDNN version 90400
Epoch [ 1]: Loss 0.65806
Validation: Loss 0.58377 Accuracy 1.00000
Epoch [ 2]: Loss 0.53889
Validation: Loss 0.48049 Accuracy 1.00000
Epoch [ 3]: Loss 0.44765
Validation: Loss 0.39908 Accuracy 1.00000
Epoch [ 4]: Loss 0.37392
Validation: Loss 0.33268 Accuracy 1.00000
Epoch [ 5]: Loss 0.31215
Validation: Loss 0.27409 Accuracy 1.00000
Epoch [ 6]: Loss 0.25143
Validation: Loss 0.20997 Accuracy 1.00000
Epoch [ 7]: Loss 0.18228
Validation: Loss 0.15024 Accuracy 1.00000
Epoch [ 8]: Loss 0.13617
Validation: Loss 0.11492 Accuracy 1.00000
Epoch [ 9]: Loss 0.10375
Validation: Loss 0.08718 Accuracy 1.00000
Epoch [ 10]: Loss 0.07819
Validation: Loss 0.06623 Accuracy 1.00000
Epoch [ 11]: Loss 0.05933
Validation: Loss 0.05051 Accuracy 1.00000
Epoch [ 12]: Loss 0.04508
Validation: Loss 0.03884 Accuracy 1.00000
Epoch [ 13]: Loss 0.03473
Validation: Loss 0.03074 Accuracy 1.00000
Epoch [ 14]: Loss 0.02794
Validation: Loss 0.02558 Accuracy 1.00000
Epoch [ 15]: Loss 0.02355
Validation: Loss 0.02204 Accuracy 1.00000
Epoch [ 16]: Loss 0.02055
Validation: Loss 0.01939 Accuracy 1.00000
Epoch [ 17]: Loss 0.01814
Validation: Loss 0.01732 Accuracy 1.00000
Epoch [ 18]: Loss 0.01625
Validation: Loss 0.01565 Accuracy 1.00000
Epoch [ 19]: Loss 0.01473
Validation: Loss 0.01429 Accuracy 1.00000
Epoch [ 20]: Loss 0.01348
Validation: Loss 0.01314 Accuracy 1.00000
Epoch [ 21]: Loss 0.01241
Validation: Loss 0.01217 Accuracy 1.00000
Epoch [ 22]: Loss 0.01141
Validation: Loss 0.01133 Accuracy 1.00000
Epoch [ 23]: Loss 0.01076
Validation: Loss 0.01060 Accuracy 1.00000
Epoch [ 24]: Loss 0.01004
Validation: Loss 0.00996 Accuracy 1.00000
Epoch [ 25]: Loss 0.00941
Validation: Loss 0.00939 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.56069
Validation: Loss 0.50002 Accuracy 0.50781
Epoch [ 2]: Loss 0.45935
Validation: Loss 0.40520 Accuracy 0.50781
Epoch [ 3]: Loss 0.39022
Validation: Loss 0.34996 Accuracy 1.00000
Epoch [ 4]: Loss 0.32765
Validation: Loss 0.28566 Accuracy 1.00000
Epoch [ 5]: Loss 0.26878
Validation: Loss 0.22812 Accuracy 1.00000
Epoch [ 6]: Loss 0.20919
Validation: Loss 0.17706 Accuracy 1.00000
Epoch [ 7]: Loss 0.15947
Validation: Loss 0.13204 Accuracy 1.00000
Epoch [ 8]: Loss 0.11734
Validation: Loss 0.09225 Accuracy 1.00000
Epoch [ 9]: Loss 0.08175
Validation: Loss 0.06725 Accuracy 1.00000
Epoch [ 10]: Loss 0.06103
Validation: Loss 0.05031 Accuracy 1.00000
Epoch [ 11]: Loss 0.04458
Validation: Loss 0.03722 Accuracy 1.00000
Epoch [ 12]: Loss 0.03371
Validation: Loss 0.02848 Accuracy 1.00000
Epoch [ 13]: Loss 0.02671
Validation: Loss 0.02289 Accuracy 1.00000
Epoch [ 14]: Loss 0.02178
Validation: Loss 0.01941 Accuracy 1.00000
Epoch [ 15]: Loss 0.01880
Validation: Loss 0.01707 Accuracy 1.00000
Epoch [ 16]: Loss 0.01657
Validation: Loss 0.01535 Accuracy 1.00000
Epoch [ 17]: Loss 0.01505
Validation: Loss 0.01396 Accuracy 1.00000
Epoch [ 18]: Loss 0.01353
Validation: Loss 0.01277 Accuracy 1.00000
Epoch [ 19]: Loss 0.01258
Validation: Loss 0.01171 Accuracy 1.00000
Epoch [ 20]: Loss 0.01151
Validation: Loss 0.01066 Accuracy 1.00000
Epoch [ 21]: Loss 0.01035
Validation: Loss 0.00954 Accuracy 1.00000
Epoch [ 22]: Loss 0.00910
Validation: Loss 0.00826 Accuracy 1.00000
Epoch [ 23]: Loss 0.00791
Validation: Loss 0.00705 Accuracy 1.00000
Epoch [ 24]: Loss 0.00676
Validation: Loss 0.00615 Accuracy 1.00000
Epoch [ 25]: Loss 0.00591
Validation: Loss 0.00544 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
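With the parameters and states restored, we can rebuild the (parameter-free) model struct and run inference. A sketch, assuming a batch x shaped like the training data:
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 4)  # hypothetical batch of 4 sequences
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
ŷ .> 0.5f0                   # hard class predictions (true ⇒ anticlockwise, per the labels above)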
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.