Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and use this data to create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
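As a quick sanity check of the shapes promised above, we can peek at one batch (a sketch, assuming the default keyword arguments):

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence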
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head. We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the built-in Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
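Because SpiralClassifier is a container layer, Lux.setup recursively initializes the parameters and states of both fields. A quick check (a sketch; the exact parameter layout depends on the layer internals):

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier), populated automatically from the field names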
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model manually; a sketch of that formulation is shown below. But we will still write it out by hand to illustrate the recurrent API.
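For reference, here is roughly what the built-in formulation would look like (SpiralClassifierBuiltin is an illustrative name; by default Recurrence returns only the output at the last time step):

function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec,
    )
end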
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to the LSTM Cell will create the initial hidden state.
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
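Before training, we can smoke-test the forward pass with a random batch (a sketch; the shapes follow the dataset convention above):

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 2, 50, 4)  # 2 features × 50 time steps × batch of 4
y, st_new = model(x, ps, st)
size(y)  # (4,): one probability per sequence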
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
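If you later switch to the numerically stabler logit formulation mentioned above, the classifier head should emit raw logits and the decision threshold moves from 0.5 to 0. A minimal sketch (assuming you drop the sigmoid from the Dense layer; logit_lossfn and matches_logits are illustrative names):

# Classifier head without the final sigmoid, e.g. Dense(hidden_dims => out_dims)
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))

# With logits, "probability > 0.5" becomes "logit > 0":
matches_logits(y_pred, y_true) = sum((y_pred .> 0) .== y_true)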
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-08-05 23:25:19.962283: I external/xla/xla/service/service.cc:163] XLA service 0x46473b90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-05 23:25:19.962369: I external/xla/xla/service/service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754436319.963208 319589 se_gpu_pjrt_client.cc:1373] Using BFC allocator.
I0000 00:00:1754436319.963396 319589 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754436319.963549 319589 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-05 23:25:19.979836: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
Epoch [ 1]: Loss 0.49418
Validation: Loss 0.42908 Accuracy 0.50781
Epoch [ 2]: Loss 0.41481
Validation: Loss 0.38093 Accuracy 1.00000
Epoch [ 3]: Loss 0.36448
Validation: Loss 0.33782 Accuracy 1.00000
Epoch [ 4]: Loss 0.31993
Validation: Loss 0.28648 Accuracy 1.00000
Epoch [ 5]: Loss 0.26585
Validation: Loss 0.22743 Accuracy 1.00000
Epoch [ 6]: Loss 0.20549
Validation: Loss 0.16805 Accuracy 1.00000
Epoch [ 7]: Loss 0.14619
Validation: Loss 0.11574 Accuracy 1.00000
Epoch [ 8]: Loss 0.09926
Validation: Loss 0.07505 Accuracy 1.00000
Epoch [ 9]: Loss 0.06308
Validation: Loss 0.04703 Accuracy 1.00000
Epoch [ 10]: Loss 0.03952
Validation: Loss 0.03005 Accuracy 1.00000
Epoch [ 11]: Loss 0.02618
Validation: Loss 0.02109 Accuracy 1.00000
Epoch [ 12]: Loss 0.01878
Validation: Loss 0.01577 Accuracy 1.00000
Epoch [ 13]: Loss 0.01435
Validation: Loss 0.01234 Accuracy 1.00000
Epoch [ 14]: Loss 0.01140
Validation: Loss 0.01007 Accuracy 1.00000
Epoch [ 15]: Loss 0.00946
Validation: Loss 0.00854 Accuracy 1.00000
Epoch [ 16]: Loss 0.00817
Validation: Loss 0.00744 Accuracy 1.00000
Epoch [ 17]: Loss 0.00711
Validation: Loss 0.00662 Accuracy 1.00000
Epoch [ 18]: Loss 0.00636
Validation: Loss 0.00594 Accuracy 1.00000
Epoch [ 19]: Loss 0.00574
Validation: Loss 0.00538 Accuracy 1.00000
Epoch [ 20]: Loss 0.00523
Validation: Loss 0.00490 Accuracy 1.00000
Epoch [ 21]: Loss 0.00476
Validation: Loss 0.00449 Accuracy 1.00000
Epoch [ 22]: Loss 0.00438
Validation: Loss 0.00414 Accuracy 1.00000
Epoch [ 23]: Loss 0.00403
Validation: Loss 0.00382 Accuracy 1.00000
Epoch [ 24]: Loss 0.00374
Validation: Loss 0.00353 Accuracy 1.00000
Epoch [ 25]: Loss 0.00345
Validation: Loss 0.00326 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.75722
Validation: Loss 0.65034 Accuracy 0.50000
Epoch [ 2]: Loss 0.60343
Validation: Loss 0.54516 Accuracy 0.50000
Epoch [ 3]: Loss 0.50653
Validation: Loss 0.45937 Accuracy 1.00000
Epoch [ 4]: Loss 0.42980
Validation: Loss 0.39885 Accuracy 1.00000
Epoch [ 5]: Loss 0.37428
Validation: Loss 0.35542 Accuracy 1.00000
Epoch [ 6]: Loss 0.33829
Validation: Loss 0.32162 Accuracy 1.00000
Epoch [ 7]: Loss 0.30593
Validation: Loss 0.29370 Accuracy 1.00000
Epoch [ 8]: Loss 0.28075
Validation: Loss 0.26765 Accuracy 1.00000
Epoch [ 9]: Loss 0.25346
Validation: Loss 0.24088 Accuracy 1.00000
Epoch [ 10]: Loss 0.22605
Validation: Loss 0.21062 Accuracy 1.00000
Epoch [ 11]: Loss 0.19468
Validation: Loss 0.17477 Accuracy 1.00000
Epoch [ 12]: Loss 0.15650
Validation: Loss 0.13699 Accuracy 1.00000
Epoch [ 13]: Loss 0.12160
Validation: Loss 0.10415 Accuracy 1.00000
Epoch [ 14]: Loss 0.09041
Validation: Loss 0.07536 Accuracy 1.00000
Epoch [ 15]: Loss 0.06361
Validation: Loss 0.04965 Accuracy 1.00000
Epoch [ 16]: Loss 0.04194
Validation: Loss 0.03421 Accuracy 1.00000
Epoch [ 17]: Loss 0.03058
Validation: Loss 0.02691 Accuracy 1.00000
Epoch [ 18]: Loss 0.02480
Validation: Loss 0.02259 Accuracy 1.00000
Epoch [ 19]: Loss 0.02120
Validation: Loss 0.01950 Accuracy 1.00000
Epoch [ 20]: Loss 0.01838
Validation: Loss 0.01715 Accuracy 1.00000
Epoch [ 21]: Loss 0.01631
Validation: Loss 0.01526 Accuracy 1.00000
Epoch [ 22]: Loss 0.01452
Validation: Loss 0.01364 Accuracy 1.00000
Epoch [ 23]: Loss 0.01302
Validation: Loss 0.01224 Accuracy 1.00000
Epoch [ 24]: Loss 0.01166
Validation: Loss 0.01091 Accuracy 1.00000
Epoch [ 25]: Loss 0.01034
Validation: Loss 0.00953 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
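Since we only saved the parameters and states, we rebuild the model struct before running inference with the loaded values. A minimal sketch (x is a placeholder input):

model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 4)  # placeholder batch: 2 features × 50 steps × 4 sequences
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))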
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.