Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise.
function create_dataset(; dataset_size=1000, sequence_length=50)
# Create the spirals
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [
reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
d in data[1:(dataset_size ÷ 2)]
]
anticlockwise_spirals = [
reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
d in data[((dataset_size ÷ 2) + 1):end]
]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
return x_data, labels
end
function get_dataloaders(; dataset_size=1000, sequence_length=50)
x_data, labels = create_dataset(; dataset_size, sequence_length)
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
# Create DataLoaders
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(
collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
),
# Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
)
end
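To make the batch layout concrete, here is a quick shape check (a hypothetical REPL session; the exact batches depend on the shuffle):
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one binary label per sequence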
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
We won't define the layers from scratch but rather compose the built-in Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
return SpiralClassifier(
LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
)
end
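Because we passed the field names to AbstractLuxContainerLayer, Lux.setup produces one parameter and one state entry per field (a small sketch):
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)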
We could use the built-in Lux blocks, namely Recurrence(LSTMCell(in_dims => hidden_dims)), instead of writing the recurrence ourselves; a rough sketch follows. But let's still write it out by hand for the sake of illustration.
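A sketch of that built-in alternative (by default, Recurrence returns only the last time step's output, so the Dense output would still need a vec to match the manual version below):
model_builtin = Chain(
    Recurrence(LSTMCell(2 => 8)),
    Dense(8 => 1, sigmoid),
)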
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
# First we will have to run the sequence through the LSTM Cell
# The first call to LSTM Cell will create the initial hidden state
# See that the parameters and states are automatically populated into a field called
# `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
# and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
# Now that we have the hidden state and memory in `carry` we will pass the input and
# `carry` jointly
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
# After running through the sequence we will pass the output through the classifier
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
# Finally remember to create the updated state
st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
return vec(y), st
end
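As a quick smoke test of the forward pass (continuing the hypothetical model, ps and st from the sketch above):
x = rand(Float32, 2, 50, 4)  # features × sequence_length × batch
ŷ, st_new = model(x, ps, st)
size(ŷ)  # (4,): one probability per sequence in the batch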
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
@return vec(classifier(y))
end
end
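Construction and invocation work exactly as for the explicit version (a sketch, reusing the hypothetical x from above):
model_compact = SpiralClassifierCompact(2, 8, 1)
ps_c, st_c = Lux.setup(Random.default_rng(), model_compact)
ŷ_c, _ = model_compact(x, ps_c, st_c)  # same (vector, state) return convention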
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
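As a quick sanity check of the loss and accuracy helpers (hypothetical values):
ŷ_demo = [0.1f0, 0.9f0, 0.8f0, 0.3f0]
y_demo = [0.0f0, 1.0f0, 1.0f0, 0.0f0]
lossfn(ŷ_demo, y_demo)   # small, since every prediction is on the correct side of 0.5
accuracy(ŷ_demo, y_demo) # 1.0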
Training the Model
function main(model_type)
dev = reactant_device()
cdev = cpu_device()
# Get the dataloaders
train_loader, val_loader = dev(get_dataloaders())
# Create the model
model = model_type(2, 8, 1)
ps, st = dev(Lux.setup(Random.default_rng(), model))
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
model_compiled = if dev isa ReactantDevice
Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile model(first(train_loader)[1], ps, Lux.testmode(st))
end
else
model
end
ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
for epoch in 1:25
# Train the model
total_loss = 0.0f0
total_samples = 0
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
ad, lossfn, (x, y), train_state
)
total_loss += loss * length(y)
total_samples += length(y)
end
@printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)
# Validate the model
total_acc = 0.0f0
total_loss = 0.0f0
total_samples = 0
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model_compiled(x, train_state.parameters, st_)
ŷ, y = cdev(ŷ), cdev(y)
total_acc += accuracy(ŷ, y) * length(y)
total_loss += lossfn(ŷ, y) * length(y)
total_samples += length(y)
end
@printf(
"Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
total_loss / total_samples,
total_acc / total_samples
)
end
return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.56466
Validation: Loss 0.46252 Accuracy 1.00000
Epoch [ 2]: Loss 0.41123
Validation: Loss 0.36907 Accuracy 1.00000
Epoch [ 3]: Loss 0.33550
Validation: Loss 0.28746 Accuracy 1.00000
Epoch [ 4]: Loss 0.23725
Validation: Loss 0.17010 Accuracy 1.00000
Epoch [ 5]: Loss 0.14055
Validation: Loss 0.10802 Accuracy 1.00000
Epoch [ 6]: Loss 0.09527
Validation: Loss 0.07789 Accuracy 1.00000
Epoch [ 7]: Loss 0.06806
Validation: Loss 0.05487 Accuracy 1.00000
Epoch [ 8]: Loss 0.04801
Validation: Loss 0.03822 Accuracy 1.00000
Epoch [ 9]: Loss 0.03439
Validation: Loss 0.02883 Accuracy 1.00000
Epoch [ 10]: Loss 0.02655
Validation: Loss 0.02308 Accuracy 1.00000
Epoch [ 11]: Loss 0.02132
Validation: Loss 0.01895 Accuracy 1.00000
Epoch [ 12]: Loss 0.01792
Validation: Loss 0.01626 Accuracy 1.00000
Epoch [ 13]: Loss 0.01549
Validation: Loss 0.01425 Accuracy 1.00000
Epoch [ 14]: Loss 0.01370
Validation: Loss 0.01270 Accuracy 1.00000
Epoch [ 15]: Loss 0.01226
Validation: Loss 0.01149 Accuracy 1.00000
Epoch [ 16]: Loss 0.01112
Validation: Loss 0.01048 Accuracy 1.00000
Epoch [ 17]: Loss 0.01020
Validation: Loss 0.00964 Accuracy 1.00000
Epoch [ 18]: Loss 0.00939
Validation: Loss 0.00892 Accuracy 1.00000
Epoch [ 19]: Loss 0.00872
Validation: Loss 0.00830 Accuracy 1.00000
Epoch [ 20]: Loss 0.00812
Validation: Loss 0.00775 Accuracy 1.00000
Epoch [ 21]: Loss 0.00758
Validation: Loss 0.00726 Accuracy 1.00000
Epoch [ 22]: Loss 0.00712
Validation: Loss 0.00683 Accuracy 1.00000
Epoch [ 23]: Loss 0.00670
Validation: Loss 0.00643 Accuracy 1.00000
Epoch [ 24]: Loss 0.00632
Validation: Loss 0.00608 Accuracy 1.00000
Epoch [ 25]: Loss 0.00598
Validation: Loss 0.00575 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.60442
Validation: Loss 0.54005 Accuracy 0.51562
Epoch [ 2]: Loss 0.51477
Validation: Loss 0.46139 Accuracy 1.00000
Epoch [ 3]: Loss 0.43498
Validation: Loss 0.38172 Accuracy 1.00000
Epoch [ 4]: Loss 0.34525
Validation: Loss 0.27904 Accuracy 1.00000
Epoch [ 5]: Loss 0.23912
Validation: Loss 0.17893 Accuracy 1.00000
Epoch [ 6]: Loss 0.14934
Validation: Loss 0.11233 Accuracy 1.00000
Epoch [ 7]: Loss 0.09825
Validation: Loss 0.08111 Accuracy 1.00000
Epoch [ 8]: Loss 0.07412
Validation: Loss 0.06303 Accuracy 1.00000
Epoch [ 9]: Loss 0.05821
Validation: Loss 0.05079 Accuracy 1.00000
Epoch [ 10]: Loss 0.04781
Validation: Loss 0.04243 Accuracy 1.00000
Epoch [ 11]: Loss 0.04056
Validation: Loss 0.03627 Accuracy 1.00000
Epoch [ 12]: Loss 0.03465
Validation: Loss 0.03139 Accuracy 1.00000
Epoch [ 13]: Loss 0.03023
Validation: Loss 0.02747 Accuracy 1.00000
Epoch [ 14]: Loss 0.02644
Validation: Loss 0.02427 Accuracy 1.00000
Epoch [ 15]: Loss 0.02327
Validation: Loss 0.02157 Accuracy 1.00000
Epoch [ 16]: Loss 0.02095
Validation: Loss 0.01928 Accuracy 1.00000
Epoch [ 17]: Loss 0.01876
Validation: Loss 0.01735 Accuracy 1.00000
Epoch [ 18]: Loss 0.01683
Validation: Loss 0.01566 Accuracy 1.00000
Epoch [ 19]: Loss 0.01535
Validation: Loss 0.01420 Accuracy 1.00000
Epoch [ 20]: Loss 0.01382
Validation: Loss 0.01293 Accuracy 1.00000
Epoch [ 21]: Loss 0.01260
Validation: Loss 0.01183 Accuracy 1.00000
Epoch [ 22]: Loss 0.01165
Validation: Loss 0.01086 Accuracy 1.00000
Epoch [ 23]: Loss 0.01067
Validation: Loss 0.01001 Accuracy 1.00000
Epoch [ 24]: Loss 0.00986
Validation: Loss 0.00926 Accuracy 1.00000
Epoch [ 25]: Loss 0.00912
Validation: Loss 0.00859 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
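To run inference with the loaded parameters, reconstruct the model with the same hyperparameters used during training (a sketch; x_sample is a placeholder input):
model = SpiralClassifier(2, 8, 1)
x_sample = rand(Float32, 2, 50, 1)
ŷ, _ = model(x_sample, ps_trained, Lux.testmode(st_trained))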
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.