Training a Simple LSTM
In this tutorial we will train a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Precompiling Lux...
5899.6 ms ✓ LuxLib
9313.1 ms ✓ Lux
2 dependencies successfully precompiled in 16 seconds. 107 already precompiled.
Precompiling LuxMLUtilsExt...
2229.9 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 167 already precompiled.
Precompiling LuxLibEnzymeExt...
1362.5 ms ✓ LuxLib → LuxLibEnzymeExt
1 dependency successfully precompiled in 2 seconds. 130 already precompiled.
Precompiling LuxEnzymeExt...
6890.5 ms ✓ Lux → LuxEnzymeExt
1 dependency successfully precompiled in 7 seconds. 146 already precompiled.
Precompiling LuxLibReactantExt...
13508.2 ms ✓ Reactant → ReactantNNlibExt
13758.2 ms ✓ LuxLib → LuxLibReactantExt
2 dependencies successfully precompiled in 14 seconds. 142 already precompiled.
Precompiling LuxReactantExt...
8542.8 ms ✓ Lux → LuxReactantExt
1 dependency successfully precompiled in 9 seconds. 161 already precompiled.
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and from this data we will create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false)
    )
end
get_dataloaders (generic function with 1 method)
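As a quick sanity check, we can peek at the first training batch. The shapes below follow from the settings above (50 time steps, a batch size of 128, and an 80/20 split):
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128) -- features × sequence_length × batch_size
size(y)  # (128,)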
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##230".SpiralClassifier
We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the recurrence manually, but let's still do it for the sake of illustration.
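For reference, here is a minimal sketch of that built-in alternative, assuming Recurrence's default behavior of returning only the output at the last time step (the function name is hypothetical):
# Sketch: the same architecture built from stock Lux layers.
# Recurrence applies the cell across dimension 2 (the time dimension) and,
# by default, returns only the final time step's output -- the same thing
# our hand-written loop computes. (Our custom model below additionally
# `vec`s the 1 × batch_size output of the Dense layer.)
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid)
    )
end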
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to LSTM Cell will create the initial hidden state.
    # Note that the parameters and states are automatically populated into fields
    # called `lstm_cell` and `classifier`. We use `eachslice` to get the elements in
    # the sequence without copying, and `Iterators.peel` to split out the first
    # element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence, we pass the last output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
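To see the automatic population in action, a quick sketch (the branch names come directly from our struct's fieldnames):
# Sketch: setup produces one parameter/state branch per container field
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)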
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
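A quick forward-pass smoke test of the compact model; the dummy batch below is illustrative, following the 2 × seq_len × batch_size layout from our dataloader:
# Sketch: run the compact model on a small dummy batch
model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 2, 50, 3)  # 2 features × 50 time steps × 3 samples
y, st_ = model(x, ps, st)
size(y)  # (3,)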
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples)
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-01-24 06:01:33.744382: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 15805882848018654454
Epoch [ 1]: Loss 0.61726
Validation: Loss 0.54094 Accuracy 1.00000
Epoch [ 2]: Loss 0.50823
Validation: Loss 0.46061 Accuracy 1.00000
Epoch [ 3]: Loss 0.43373
Validation: Loss 0.39411 Accuracy 1.00000
Epoch [ 4]: Loss 0.37233
Validation: Loss 0.33552 Accuracy 1.00000
Epoch [ 5]: Loss 0.31353
Validation: Loss 0.27997 Accuracy 1.00000
Epoch [ 6]: Loss 0.25701
Validation: Loss 0.22527 Accuracy 1.00000
Epoch [ 7]: Loss 0.20482
Validation: Loss 0.17676 Accuracy 1.00000
Epoch [ 8]: Loss 0.16106
Validation: Loss 0.13826 Accuracy 1.00000
Epoch [ 9]: Loss 0.12582
Validation: Loss 0.10878 Accuracy 1.00000
Epoch [ 10]: Loss 0.09994
Validation: Loss 0.08616 Accuracy 1.00000
Epoch [ 11]: Loss 0.07903
Validation: Loss 0.06723 Accuracy 1.00000
Epoch [ 12]: Loss 0.06045
Validation: Loss 0.05047 Accuracy 1.00000
Epoch [ 13]: Loss 0.04459
Validation: Loss 0.03730 Accuracy 1.00000
Epoch [ 14]: Loss 0.03376
Validation: Loss 0.02840 Accuracy 1.00000
Epoch [ 15]: Loss 0.02581
Validation: Loss 0.02201 Accuracy 1.00000
Epoch [ 16]: Loss 0.02013
Validation: Loss 0.01738 Accuracy 1.00000
Epoch [ 17]: Loss 0.01603
Validation: Loss 0.01405 Accuracy 1.00000
Epoch [ 18]: Loss 0.01306
Validation: Loss 0.01164 Accuracy 1.00000
Epoch [ 19]: Loss 0.01091
Validation: Loss 0.00988 Accuracy 1.00000
Epoch [ 20]: Loss 0.00935
Validation: Loss 0.00856 Accuracy 1.00000
Epoch [ 21]: Loss 0.00814
Validation: Loss 0.00756 Accuracy 1.00000
Epoch [ 22]: Loss 0.00724
Validation: Loss 0.00681 Accuracy 1.00000
Epoch [ 23]: Loss 0.00657
Validation: Loss 0.00623 Accuracy 1.00000
Epoch [ 24]: Loss 0.00601
Validation: Loss 0.00573 Accuracy 1.00000
Epoch [ 25]: Loss 0.00553
Validation: Loss 0.00529 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.66203
Validation: Loss 0.57153 Accuracy 0.57031
Epoch [ 2]: Loss 0.54116
Validation: Loss 0.46976 Accuracy 1.00000
Epoch [ 3]: Loss 0.43337
Validation: Loss 0.37439 Accuracy 1.00000
Epoch [ 4]: Loss 0.33407
Validation: Loss 0.28618 Accuracy 1.00000
Epoch [ 5]: Loss 0.24639
Validation: Loss 0.20935 Accuracy 1.00000
Epoch [ 6]: Loss 0.17783
Validation: Loss 0.15226 Accuracy 1.00000
Epoch [ 7]: Loss 0.12870
Validation: Loss 0.11263 Accuracy 1.00000
Epoch [ 8]: Loss 0.09590
Validation: Loss 0.08546 Accuracy 1.00000
Epoch [ 9]: Loss 0.07340
Validation: Loss 0.06627 Accuracy 1.00000
Epoch [ 10]: Loss 0.05773
Validation: Loss 0.05271 Accuracy 1.00000
Epoch [ 11]: Loss 0.04620
Validation: Loss 0.04315 Accuracy 1.00000
Epoch [ 12]: Loss 0.03817
Validation: Loss 0.03601 Accuracy 1.00000
Epoch [ 13]: Loss 0.03216
Validation: Loss 0.03047 Accuracy 1.00000
Epoch [ 14]: Loss 0.02748
Validation: Loss 0.02618 Accuracy 1.00000
Epoch [ 15]: Loss 0.02378
Validation: Loss 0.02296 Accuracy 1.00000
Epoch [ 16]: Loss 0.02090
Validation: Loss 0.02045 Accuracy 1.00000
Epoch [ 17]: Loss 0.01889
Validation: Loss 0.01840 Accuracy 1.00000
Epoch [ 18]: Loss 0.01699
Validation: Loss 0.01672 Accuracy 1.00000
Epoch [ 19]: Loss 0.01554
Validation: Loss 0.01531 Accuracy 1.00000
Epoch [ 20]: Loss 0.01431
Validation: Loss 0.01410 Accuracy 1.00000
Epoch [ 21]: Loss 0.01323
Validation: Loss 0.01306 Accuracy 1.00000
Epoch [ 22]: Loss 0.01226
Validation: Loss 0.01215 Accuracy 1.00000
Epoch [ 23]: Loss 0.01142
Validation: Loss 0.01134 Accuracy 1.00000
Epoch [ 24]: Loss 0.01065
Validation: Loss 0.01063 Accuracy 1.00000
Epoch [ 25]: Loss 0.00998
Validation: Loss 0.00999 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model.
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
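Since we saved only the parameters and states, loading the model for inference means rebuilding the model struct first. A sketch (the dummy batch is illustrative):
# Sketch: rebuild the model and run inference with the loaded parameters/states
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 4)  # any batch in the 2 × seq_len × batch_size layout
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))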
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.