Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
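As a quick sanity check (a minimal sketch, not part of the original tutorial), we can confirm that the train loader yields batches of the shape described above:

# Sketch: inspect one batch from the train loader
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128) — 2 × seq_len × batch_size
size(y)  # (128,)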
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head. We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
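As a small illustrative sketch (not in the original tutorial), we can see the automatic population in action: the field names we declared become the keys of the NamedTuples returned by Lux.setup.

# Sketch: the declared field names become the parameter/state keys
m = SpiralClassifier(LSTMCell(2 => 8), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(Random.default_rng(), m)
keys(ps)  # (:lstm_cell, :classifier)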
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
We could use the stock Lux blocks, i.e. Recurrence(LSTMCell(in_dims => hidden_dims)), instead of defining the forward pass ourselves. But let's still do it for the sake of this tutorial.
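For reference, here is a sketch of that stock-blocks version (Recurrence returns the last output by default, and plain functions such as vec are wrapped automatically by Chain):

# Sketch: the same model assembled from stock Lux blocks
model_alt = Chain(
    Recurrence(LSTMCell(2 => 8)),  # run the cell over the sequence, keep the last output
    Dense(8 => 1, sigmoid),
    vec,
)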
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
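A quick illustrative smoke test of this forward pass (not in the original tutorial):

# Sketch: run the model once on random data to check shapes
rng = Random.default_rng()
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 2, 50, 4)  # features × seq_len × batch
ŷ, st_new = model(x, ps, st)
size(ŷ)  # (4,) — one probability per sequence, thanks to `vec`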
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
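As a hedged sketch, the compact model is set up and called exactly like the explicit struct version:

# Sketch: the compact model has the same interface as the explicit struct
m2 = SpiralClassifierCompact(2, 8, 1)
ps2, st2 = Lux.setup(Random.default_rng(), m2)
ŷ2, _ = m2(rand(Float32, 2, 50, 4), ps2, st2)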
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
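If you do want the numerically stable variant, a minimal sketch of the change (assuming the logits keyword of Lux's BinaryCrossEntropyLoss) is to train on raw logits and apply sigmoid only when probabilities are needed:

# Sketch: loss computed from logits; drop `sigmoid` from the Dense head
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))
# classifier = Dense(hidden_dims => out_dims)  # no sigmoid during training
# ŷ = sigmoid.(logits)                         # sigmoid only for evaluation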
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            # `single_train_step!` returns the gradients, the loss, any stats, and the
            # updated training state; we only need the loss and the state here.
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-08-02 15:23:34.133985: I external/xla/xla/service/service.cc:163] XLA service 0x11a3e710 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-02 15:23:34.134024: I external/xla/xla/service/service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754148214.134984 4072956 se_gpu_pjrt_client.cc:1380] Using BFC allocator.
I0000 00:00:1754148214.135080 4072956 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754148214.135140 4072956 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-02 15:23:34.148950: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
Epoch [ 1]: Loss 0.58661
Validation: Loss 0.51742 Accuracy 1.00000
Epoch [ 2]: Loss 0.48195
Validation: Loss 0.42456 Accuracy 1.00000
Epoch [ 3]: Loss 0.39707
Validation: Loss 0.34845 Accuracy 1.00000
Epoch [ 4]: Loss 0.32691
Validation: Loss 0.28922 Accuracy 1.00000
Epoch [ 5]: Loss 0.27600
Validation: Loss 0.24548 Accuracy 1.00000
Epoch [ 6]: Loss 0.23739
Validation: Loss 0.21099 Accuracy 1.00000
Epoch [ 7]: Loss 0.20368
Validation: Loss 0.18194 Accuracy 1.00000
Epoch [ 8]: Loss 0.17671
Validation: Loss 0.15745 Accuracy 1.00000
Epoch [ 9]: Loss 0.15375
Validation: Loss 0.13762 Accuracy 1.00000
Epoch [ 10]: Loss 0.13480
Validation: Loss 0.12163 Accuracy 1.00000
Epoch [ 11]: Loss 0.11921
Validation: Loss 0.10834 Accuracy 1.00000
Epoch [ 12]: Loss 0.10704
Validation: Loss 0.09710 Accuracy 1.00000
Epoch [ 13]: Loss 0.09640
Validation: Loss 0.08742 Accuracy 1.00000
Epoch [ 14]: Loss 0.08546
Validation: Loss 0.07893 Accuracy 1.00000
Epoch [ 15]: Loss 0.07784
Validation: Loss 0.07129 Accuracy 1.00000
Epoch [ 16]: Loss 0.07010
Validation: Loss 0.06411 Accuracy 1.00000
Epoch [ 17]: Loss 0.06301
Validation: Loss 0.05720 Accuracy 1.00000
Epoch [ 18]: Loss 0.05583
Validation: Loss 0.04998 Accuracy 1.00000
Epoch [ 19]: Loss 0.04787
Validation: Loss 0.04198 Accuracy 1.00000
Epoch [ 20]: Loss 0.03942
Validation: Loss 0.03286 Accuracy 1.00000
Epoch [ 21]: Loss 0.02954
Validation: Loss 0.02403 Accuracy 1.00000
Epoch [ 22]: Loss 0.02133
Validation: Loss 0.01748 Accuracy 1.00000
Epoch [ 23]: Loss 0.01583
Validation: Loss 0.01382 Accuracy 1.00000
Epoch [ 24]: Loss 0.01301
Validation: Loss 0.01208 Accuracy 1.00000
Epoch [ 25]: Loss 0.01156
Validation: Loss 0.01090 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.49864
Validation: Loss 0.41509 Accuracy 1.00000
Epoch [ 2]: Loss 0.36967
Validation: Loss 0.31113 Accuracy 1.00000
Epoch [ 3]: Loss 0.27462
Validation: Loss 0.22924 Accuracy 1.00000
Epoch [ 4]: Loss 0.20196
Validation: Loss 0.16798 Accuracy 1.00000
Epoch [ 5]: Loss 0.14889
Validation: Loss 0.12616 Accuracy 1.00000
Epoch [ 6]: Loss 0.11370
Validation: Loss 0.09916 Accuracy 1.00000
Epoch [ 7]: Loss 0.09043
Validation: Loss 0.07981 Accuracy 1.00000
Epoch [ 8]: Loss 0.07377
Validation: Loss 0.06643 Accuracy 1.00000
Epoch [ 9]: Loss 0.06180
Validation: Loss 0.05630 Accuracy 1.00000
Epoch [ 10]: Loss 0.05304
Validation: Loss 0.04873 Accuracy 1.00000
Epoch [ 11]: Loss 0.04609
Validation: Loss 0.04272 Accuracy 1.00000
Epoch [ 12]: Loss 0.04048
Validation: Loss 0.03755 Accuracy 1.00000
Epoch [ 13]: Loss 0.03565
Validation: Loss 0.03308 Accuracy 1.00000
Epoch [ 14]: Loss 0.03125
Validation: Loss 0.02880 Accuracy 1.00000
Epoch [ 15]: Loss 0.02717
Validation: Loss 0.02505 Accuracy 1.00000
Epoch [ 16]: Loss 0.02382
Validation: Loss 0.02231 Accuracy 1.00000
Epoch [ 17]: Loss 0.02132
Validation: Loss 0.02013 Accuracy 1.00000
Epoch [ 18]: Loss 0.01936
Validation: Loss 0.01840 Accuracy 1.00000
Epoch [ 19]: Loss 0.01775
Validation: Loss 0.01694 Accuracy 1.00000
Epoch [ 20]: Loss 0.01636
Validation: Loss 0.01565 Accuracy 1.00000
Epoch [ 21]: Loss 0.01517
Validation: Loss 0.01456 Accuracy 1.00000
Epoch [ 22]: Loss 0.01414
Validation: Loss 0.01358 Accuracy 1.00000
Epoch [ 23]: Loss 0.01321
Validation: Loss 0.01271 Accuracy 1.00000
Epoch [ 24]: Loss 0.01237
Validation: Loss 0.01193 Accuracy 1.00000
Epoch [ 25]: Loss 0.01163
Validation: Loss 0.01122 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
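As an illustrative sketch (not in the original tutorial), the restored parameters can be used by rebuilding the model from code:

# Sketch: rebuild the model and run it with the loaded parameters
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 4)
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))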
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.