Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Precompiling Lux...
2693.8 ms ✓ WeightInitializers
944.5 ms ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
9292.7 ms ✓ Lux
3 dependencies successfully precompiled in 14 seconds. 106 already precompiled.
Precompiling JLD2...
4138.7 ms ✓ FileIO
32118.2 ms ✓ JLD2
2 dependencies successfully precompiled in 36 seconds. 30 already precompiled.
Precompiling MLUtils...
421.6 ms ✓ InverseFunctions → InverseFunctionsDatesExt
436.2 ms ✓ ContextVariablesX
615.6 ms ✓ InverseFunctions → InverseFunctionsTestExt
1146.0 ms ✓ SimpleTraits
459.6 ms ✓ LogExpFunctions → LogExpFunctionsInverseFunctionsExt
1220.0 ms ✓ SplittablesBase
602.8 ms ✓ FLoopsBase
2275.4 ms ✓ StatsBase
2330.2 ms ✓ Accessors
684.9 ms ✓ Accessors → StaticArraysExt
687.9 ms ✓ Accessors → TestExt
853.8 ms ✓ Accessors → LinearAlgebraExt
756.4 ms ✓ BangBang
525.9 ms ✓ BangBang → BangBangChainRulesCoreExt
526.9 ms ✓ BangBang → BangBangTablesExt
696.1 ms ✓ BangBang → BangBangStaticArraysExt
1041.1 ms ✓ MicroCollections
2726.1 ms ✓ Transducers
672.3 ms ✓ Transducers → TransducersAdaptExt
5306.9 ms ✓ FLoops
6256.3 ms ✓ MLUtils
21 dependencies successfully precompiled in 22 seconds. 77 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
1765.5 ms ✓ MLDataDevices → MLDataDevicesMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 102 already precompiled.
Precompiling LuxMLUtilsExt...
2073.3 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 167 already precompiled.
Precompiling Reactant...
624.9 ms ✓ ReactantCore
947.1 ms ✓ CUDA_Driver_jll
2294.7 ms ✓ Reactant_jll
215926.2 ms ✓ Enzyme
5594.5 ms ✓ Enzyme → EnzymeGPUArraysCoreExt
Info Given Reactant was explicitly requested, output will be shown live
2025-02-05 19:15:47.938338: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:51] XLA (re)initializing LLVM with options fingerprint: 6211547431844565882
65366.9 ms ✓ Reactant
6 dependencies successfully precompiled in 287 seconds. 56 already precompiled.
1 dependency had output during precompilation:
┌ Reactant
│ [Output was shown above]
└
Precompiling LuxLibEnzymeExt...
5851.5 ms ✓ Enzyme → EnzymeSpecialFunctionsExt
8247.6 ms ✓ Enzyme → EnzymeStaticArraysExt
1302.2 ms ✓ LuxLib → LuxLibEnzymeExt
11062.7 ms ✓ Enzyme → EnzymeChainRulesCoreExt
5702.5 ms ✓ Enzyme → EnzymeLogExpFunctionsExt
5 dependencies successfully precompiled in 12 seconds. 126 already precompiled.
Precompiling LuxEnzymeExt...
6715.3 ms ✓ Lux → LuxEnzymeExt
1 dependency successfully precompiled in 7 seconds. 146 already precompiled.
Precompiling LuxCoreReactantExt...
12241.9 ms ✓ LuxCore → LuxCoreReactantExt
1 dependency successfully precompiled in 12 seconds. 67 already precompiled.
Precompiling MLDataDevicesReactantExt...
12878.1 ms ✓ MLDataDevices → MLDataDevicesReactantExt
1 dependency successfully precompiled in 13 seconds. 64 already precompiled.
Precompiling LuxLibReactantExt...
13041.1 ms ✓ Reactant → ReactantStatisticsExt
13442.3 ms ✓ Reactant → ReactantNNlibExt
13794.6 ms ✓ LuxLib → LuxLibReactantExt
12904.8 ms ✓ Reactant → ReactantSpecialFunctionsExt
13424.6 ms ✓ Reactant → ReactantKernelAbstractionsExt
12956.7 ms ✓ Reactant → ReactantArrayInterfaceExt
6 dependencies successfully precompiled in 27 seconds. 140 already precompiled.
Precompiling WeightInitializersReactantExt...
12554.9 ms ✓ WeightInitializers → WeightInitializersReactantExt
1 dependency successfully precompiled in 13 seconds. 78 already precompiled.
Precompiling LuxReactantExt...
8645.6 ms ✓ Lux → LuxReactantExt
1 dependency successfully precompiled in 9 seconds. 163 already precompiled.
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and create a MLUtils.DataLoader from this data. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether each sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false)
    )
end
get_dataloaders (generic function with 1 method)
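As a quick sanity check, we can look at the shapes of a single batch. This snippet is purely illustrative and not part of the training pipeline:

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): 2 coordinates × sequence_length × batchsize
size(y)  # (128,): one binary label per sequence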
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch; instead, we will use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier
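Because we declared the field names on the container type, setting up the model gives us parameters and states keyed by exactly those names. An illustrative check (not part of the training pipeline):

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)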
We could use the built-in Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model manually, as sketched just below. But let's still write it out by hand for the sake of illustration.
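For reference, here is a minimal sketch of that built-in variant, assuming Recurrence's default behavior of returning only the output of the final time step (we do not use this variant in the rest of the tutorial):

# Hedged sketch: `Recurrence` iterates the cell over the sequence dimension and,
# by default, returns only the output of the last time step.
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec,  # flatten the (1, batch) output to match our manual model below
    )
end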
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM cell.
    # The first call to the LSTM cell will create the initial hidden state.
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we will pass the input
    # and `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
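We can smoke-test the forward pass on a random batch before training; again, an illustrative snippet only:

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 2, 50, 128)  # (features, sequence_length, batchsize)
ŷ, st_new = model(x, ps, st)
size(ŷ)  # (128,): one probability per sequence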
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
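If you do want the numerically stabler variant mentioned above, one option is to drop the final sigmoid from the classifier head and compute the loss directly on the logits. A minimal sketch, assuming the logits keyword of Lux's BinaryCrossEntropyLoss (this variant is not used in the rest of the tutorial):

# Hedged sketch: a logit-based loss paired with a classifier head that has no
# `sigmoid` activation, so it outputs raw logits.
logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))
logit_classifier = Dense(8 => 1)  # hidden_dims => out_dims, no activation

To see the accuracy helpers in action, here is a tiny worked example (illustrative only):

# Two of the three thresholded predictions match the labels.
accuracy(Float32[0.9, 0.2, 0.6], Float32[1, 0, 0])  # ≈ 0.6667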
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (total_acc / total_samples)
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-02-05 19:24:08.285912: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:51] XLA (re)initializing LLVM with options fingerprint: 6285648702646104871
Epoch [ 1]: Loss 0.69397
Validation: Loss 0.63135 Accuracy 1.00000
Epoch [ 2]: Loss 0.59388
Validation: Loss 0.54033 Accuracy 1.00000
Epoch [ 3]: Loss 0.50327
Validation: Loss 0.45833 Accuracy 1.00000
Epoch [ 4]: Loss 0.42248
Validation: Loss 0.38206 Accuracy 1.00000
Epoch [ 5]: Loss 0.34620
Validation: Loss 0.30999 Accuracy 1.00000
Epoch [ 6]: Loss 0.27829
Validation: Loss 0.24980 Accuracy 1.00000
Epoch [ 7]: Loss 0.22434
Validation: Loss 0.19454 Accuracy 1.00000
Epoch [ 8]: Loss 0.16757
Validation: Loss 0.13528 Accuracy 1.00000
Epoch [ 9]: Loss 0.11305
Validation: Loss 0.08826 Accuracy 1.00000
Epoch [ 10]: Loss 0.07546
Validation: Loss 0.06114 Accuracy 1.00000
Epoch [ 11]: Loss 0.05440
Validation: Loss 0.04619 Accuracy 1.00000
Epoch [ 12]: Loss 0.04211
Validation: Loss 0.03655 Accuracy 1.00000
Epoch [ 13]: Loss 0.03360
Validation: Loss 0.02936 Accuracy 1.00000
Epoch [ 14]: Loss 0.02677
Validation: Loss 0.02296 Accuracy 1.00000
Epoch [ 15]: Loss 0.02133
Validation: Loss 0.01889 Accuracy 1.00000
Epoch [ 16]: Loss 0.01808
Validation: Loss 0.01642 Accuracy 1.00000
Epoch [ 17]: Loss 0.01588
Validation: Loss 0.01451 Accuracy 1.00000
Epoch [ 18]: Loss 0.01405
Validation: Loss 0.01303 Accuracy 1.00000
Epoch [ 19]: Loss 0.01277
Validation: Loss 0.01184 Accuracy 1.00000
Epoch [ 20]: Loss 0.01164
Validation: Loss 0.01086 Accuracy 1.00000
Epoch [ 21]: Loss 0.01071
Validation: Loss 0.01002 Accuracy 1.00000
Epoch [ 22]: Loss 0.00988
Validation: Loss 0.00930 Accuracy 1.00000
Epoch [ 23]: Loss 0.00916
Validation: Loss 0.00867 Accuracy 1.00000
Epoch [ 24]: Loss 0.00856
Validation: Loss 0.00811 Accuracy 1.00000
Epoch [ 25]: Loss 0.00799
Validation: Loss 0.00760 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.76119
Validation: Loss 0.68223 Accuracy 0.85938
Epoch [ 2]: Loss 0.62109
Validation: Loss 0.52857 Accuracy 1.00000
Epoch [ 3]: Loss 0.45825
Validation: Loss 0.35641 Accuracy 1.00000
Epoch [ 4]: Loss 0.30262
Validation: Loss 0.23736 Accuracy 1.00000
Epoch [ 5]: Loss 0.20405
Validation: Loss 0.16532 Accuracy 1.00000
Epoch [ 6]: Loss 0.14548
Validation: Loss 0.12080 Accuracy 1.00000
Epoch [ 7]: Loss 0.10650
Validation: Loss 0.08794 Accuracy 1.00000
Epoch [ 8]: Loss 0.07723
Validation: Loss 0.06416 Accuracy 1.00000
Epoch [ 9]: Loss 0.05721
Validation: Loss 0.04926 Accuracy 1.00000
Epoch [ 10]: Loss 0.04482
Validation: Loss 0.03982 Accuracy 1.00000
Epoch [ 11]: Loss 0.03692
Validation: Loss 0.03321 Accuracy 1.00000
Epoch [ 12]: Loss 0.03100
Validation: Loss 0.02826 Accuracy 1.00000
Epoch [ 13]: Loss 0.02660
Validation: Loss 0.02442 Accuracy 1.00000
Epoch [ 14]: Loss 0.02300
Validation: Loss 0.02136 Accuracy 1.00000
Epoch [ 15]: Loss 0.02025
Validation: Loss 0.01890 Accuracy 1.00000
Epoch [ 16]: Loss 0.01801
Validation: Loss 0.01686 Accuracy 1.00000
Epoch [ 17]: Loss 0.01613
Validation: Loss 0.01517 Accuracy 1.00000
Epoch [ 18]: Loss 0.01456
Validation: Loss 0.01371 Accuracy 1.00000
Epoch [ 19]: Loss 0.01320
Validation: Loss 0.01244 Accuracy 1.00000
Epoch [ 20]: Loss 0.01188
Validation: Loss 0.01131 Accuracy 1.00000
Epoch [ 21]: Loss 0.01091
Validation: Loss 0.01028 Accuracy 1.00000
Epoch [ 22]: Loss 0.00986
Validation: Loss 0.00932 Accuracy 1.00000
Epoch [ 23]: Loss 0.00894
Validation: Loss 0.00843 Accuracy 1.00000
Epoch [ 24]: Loss 0.00807
Validation: Loss 0.00759 Accuracy 1.00000
Epoch [ 25]: Loss 0.00723
Validation: Loss 0.00683 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
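With the parameters and states reloaded, we can reconstruct the model and run inference. A minimal sketch with a random input (illustrative only; assumes the SpiralClassifier definition from earlier in this tutorial):

# Rebuild the model structure; only the parameters and states were serialized.
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # a single random sequence
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
ŷ[1] > 0.5f0  # `true` predicts anticlockwise, `false` predicts clockwise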
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.