Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and from this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
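As a quick sanity check on the shapes (a sketch, not part of the tutorial's pipeline), we can pull one batch from the training loader:
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x), size(y)  # expect ((2, 50, 128), (128,)) -- 2 features × 50 steps × 128 samples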
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch; rather, we will use Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier
We could use default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following, but let's write it out from scratch for the sake of the tutorial.
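For reference, here is a minimal sketch of that built-in route, assuming Recurrence's defaults (batch-last input ordering, returning only the final time step's output) match our 2 × seq_len × batch_size data:
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        # Run the cell over the time dimension; by default only the final output is kept
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec,  # flatten the 1 × batch_size output, matching our custom model below
    )
end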
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we run the sequence through the LSTM cell. The first call to the LSTM cell
    # creates the initial hidden state. Note that the parameters and states are
    # automatically populated into fields called `lstm_cell` and `classifier`. We use
    # `eachslice` to get the elements of the sequence without copying, and
    # `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
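As a quick smoke test (a sketch, not part of the training pipeline below), we can set up the model and run a dummy batch through it:
rng = Random.default_rng()
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 50, 4)  # (features, seq_len, batch)
y, st_new = model(x, ps, st)
size(y)  # (4,) -- one probability per sequence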
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
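If we instead wanted the more stable logit formulation mentioned above, a minimal sketch (assuming BinaryCrossEntropyLoss's logits keyword; matches_logits is a hypothetical helper) would drop the sigmoid from the classifier head and threshold at 0 instead of 0.5:
# The classifier head would become Dense(hidden_dims => out_dims) (no sigmoid),
# and the loss would be constructed to operate on raw logits:
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))

# Hypothetical helper: with logits the decision boundary is 0, since sigmoid(0) == 0.5
matches_logits(y_pred, y_true) = sum((y_pred .> 0.0f0) .== y_true)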
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-05-23 22:57:28.654724: I external/xla/xla/service/service.cc:152] XLA service 0xc537210 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-23 22:57:28.654821: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748041048.655508 800117 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1748041048.655588 800117 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 0 for BFCAllocator.
I0000 00:00:1748041048.655619 800117 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1748041048.666862 800117 cuda_dnn.cc:529] Loaded cuDNN version 90400
Epoch [ 1]: Loss 0.72328
Validation: Loss 0.66961 Accuracy 0.47656
Epoch [ 2]: Loss 0.62975
Validation: Loss 0.60080 Accuracy 0.47656
Epoch [ 3]: Loss 0.57419
Validation: Loss 0.55200 Accuracy 1.00000
Epoch [ 4]: Loss 0.52276
Validation: Loss 0.49260 Accuracy 1.00000
Epoch [ 5]: Loss 0.45420
Validation: Loss 0.40837 Accuracy 1.00000
Epoch [ 6]: Loss 0.35786
Validation: Loss 0.29874 Accuracy 1.00000
Epoch [ 7]: Loss 0.25151
Validation: Loss 0.20315 Accuracy 1.00000
Epoch [ 8]: Loss 0.17247
Validation: Loss 0.14086 Accuracy 1.00000
Epoch [ 9]: Loss 0.12162
Validation: Loss 0.10067 Accuracy 1.00000
Epoch [ 10]: Loss 0.08789
Validation: Loss 0.07288 Accuracy 1.00000
Epoch [ 11]: Loss 0.06482
Validation: Loss 0.05478 Accuracy 1.00000
Epoch [ 12]: Loss 0.04981
Validation: Loss 0.04289 Accuracy 1.00000
Epoch [ 13]: Loss 0.03951
Validation: Loss 0.03464 Accuracy 1.00000
Epoch [ 14]: Loss 0.03226
Validation: Loss 0.02865 Accuracy 1.00000
Epoch [ 15]: Loss 0.02687
Validation: Loss 0.02423 Accuracy 1.00000
Epoch [ 16]: Loss 0.02296
Validation: Loss 0.02086 Accuracy 1.00000
Epoch [ 17]: Loss 0.01982
Validation: Loss 0.01826 Accuracy 1.00000
Epoch [ 18]: Loss 0.01743
Validation: Loss 0.01622 Accuracy 1.00000
Epoch [ 19]: Loss 0.01555
Validation: Loss 0.01460 Accuracy 1.00000
Epoch [ 20]: Loss 0.01403
Validation: Loss 0.01328 Accuracy 1.00000
Epoch [ 21]: Loss 0.01278
Validation: Loss 0.01217 Accuracy 1.00000
Epoch [ 22]: Loss 0.01174
Validation: Loss 0.01120 Accuracy 1.00000
Epoch [ 23]: Loss 0.01081
Validation: Loss 0.01035 Accuracy 1.00000
Epoch [ 24]: Loss 0.01000
Validation: Loss 0.00958 Accuracy 1.00000
Epoch [ 25]: Loss 0.00925
Validation: Loss 0.00887 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.62104
Validation: Loss 0.54463 Accuracy 1.00000
Epoch [ 2]: Loss 0.48101
Validation: Loss 0.39142 Accuracy 1.00000
Epoch [ 3]: Loss 0.35813
Validation: Loss 0.30133 Accuracy 1.00000
Epoch [ 4]: Loss 0.27453
Validation: Loss 0.23122 Accuracy 1.00000
Epoch [ 5]: Loss 0.21187
Validation: Loss 0.18427 Accuracy 1.00000
Epoch [ 6]: Loss 0.16971
Validation: Loss 0.14992 Accuracy 1.00000
Epoch [ 7]: Loss 0.13876
Validation: Loss 0.12375 Accuracy 1.00000
Epoch [ 8]: Loss 0.11506
Validation: Loss 0.10300 Accuracy 1.00000
Epoch [ 9]: Loss 0.09583
Validation: Loss 0.08521 Accuracy 1.00000
Epoch [ 10]: Loss 0.07883
Validation: Loss 0.06890 Accuracy 1.00000
Epoch [ 11]: Loss 0.06324
Validation: Loss 0.05414 Accuracy 1.00000
Epoch [ 12]: Loss 0.04985
Validation: Loss 0.04285 Accuracy 1.00000
Epoch [ 13]: Loss 0.04028
Validation: Loss 0.03510 Accuracy 1.00000
Epoch [ 14]: Loss 0.03357
Validation: Loss 0.02976 Accuracy 1.00000
Epoch [ 15]: Loss 0.02895
Validation: Loss 0.02593 Accuracy 1.00000
Epoch [ 16]: Loss 0.02552
Validation: Loss 0.02306 Accuracy 1.00000
Epoch [ 17]: Loss 0.02291
Validation: Loss 0.02081 Accuracy 1.00000
Epoch [ 18]: Loss 0.02083
Validation: Loss 0.01900 Accuracy 1.00000
Epoch [ 19]: Loss 0.01901
Validation: Loss 0.01748 Accuracy 1.00000
Epoch [ 20]: Loss 0.01762
Validation: Loss 0.01618 Accuracy 1.00000
Epoch [ 21]: Loss 0.01635
Validation: Loss 0.01505 Accuracy 1.00000
Epoch [ 22]: Loss 0.01517
Validation: Loss 0.01406 Accuracy 1.00000
Epoch [ 23]: Loss 0.01426
Validation: Loss 0.01318 Accuracy 1.00000
Epoch [ 24]: Loss 0.01334
Validation: Loss 0.01239 Accuracy 1.00000
Epoch [ 25]: Loss 0.01257
Validation: Loss 0.01168 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
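Since we only saved the parameters and states, restoring the model means rebuilding the struct and pairing it with the loaded values. A minimal sketch (the dummy input is purely illustrative):
model = SpiralClassifier(2, 8, 1)  # must match the architecture used for training
x = rand(Float32, 2, 50, 1)        # one dummy input sequence
y, _ = model(x, ps_trained, Lux.testmode(st_trained))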
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.