Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. Our dataloader will yield sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
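As a quick sanity check, the two MLUtils primitives used below (splitobs and DataLoader) can be previewed on random arrays; the shapes mirror the spiral data, but the values here are made up:

```julia
using MLUtils

# Fake "sequences": features × seq_len × batch, plus one label per sequence
x = rand(Float32, 2, 50, 10)
y = rand(Float32, 10)

# `splitobs` partitions along the last (observation) dimension
(x_train, y_train), (x_val, y_val) = splitobs((x, y); at=0.8)

# `DataLoader` yields minibatches along the observation dimension
loader = DataLoader((collect(x_train), collect(y_train)); batchsize=4, shuffle=true)
for (xb, yb) in loader
    # the feature and sequence dimensions are preserved in every minibatch
    @assert size(xb, 1) == 2 && size(xb, 2) == 50
end
```

Passing `partial=false` (as the tutorial does below) additionally drops any trailing batch smaller than `batchsize`.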
function get_dataloaders(; dataset_size=1000, sequence_length=50)
# Create the spirals
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
for d in data[1:(dataset_size ÷ 2)]]
anticlockwise_spirals = [reshape(
d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
for d in data[((dataset_size ÷ 2) + 1):end]]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
# Create DataLoaders
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(
collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false),
# Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false)
)
end
get_dataloaders (generic function with 1 method)
Creating a Classifier
We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
We won't define the model from scratch; instead, we will use the built-in Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
return SpiralClassifier(
LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass ourselves, but we will still write it out for the sake of illustration.
Now we need to define the behavior of the classifier when it is invoked.
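The sequence-iteration pattern used in the forward pass below can be seen in isolation with plain Base.eachslice (the tutorial uses LuxOps.eachslice, a similar AD-friendly variant) and Iterators.peel:

```julia
# features × seq_len × batch toy array
x = reshape(collect(1:12), 2, 3, 2)

# Walk the sequence (second) dimension without copying
slices = eachslice(x; dims=2)

# Split off the first timestep (used to initialize the LSTM) from the rest
x_init, x_rest = Iterators.peel(slices)

@assert x_init == x[:, 1, :]
@assert length(collect(x_rest)) == 2  # two remaining timesteps
```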
function (s::SpiralClassifier)(
x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
# First we will have to run the sequence through the LSTM Cell
# The first call to LSTM Cell will create the initial hidden state
# See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
# and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
# Now that we have the hidden state and memory in `carry` we will pass the input and
# `carry` jointly
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
# After running through the sequence we will pass the output through the classifier
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
# Finally remember to create the updated state
st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
return vec(y), st
end
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
@return vec(classifier(y))
end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
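For reference, here is a hand-rolled sketch of what binary cross-entropy computes (without the numerical safeguards, such as clamping, that the library loss adds):

```julia
using Statistics

# BCE(ŷ, y) = mean(-y*log(ŷ) - (1-y)*log(1-ŷ))
bce(ŷ, y) = mean(@. -y * log(ŷ) - (1 - y) * log(1 - ŷ))

ŷ = Float32[0.9, 0.2, 0.8]
y = Float32[1.0, 0.0, 1.0]

@assert bce(ŷ, y) < bce(1 .- ŷ, y)                    # correct predictions score lower
@assert bce(Float32[0.5], Float32[1.0]) ≈ log(2.0f0)  # a maximally uncertain prediction
```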
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
Training the Model
function main(model_type)
dev = reactant_device()
cdev = cpu_device()
# Get the dataloaders
train_loader, val_loader = get_dataloaders() |> dev
# Create the model
model = model_type(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
model_compiled = if dev isa ReactantDevice
@compile model(first(train_loader)[1], ps, Lux.testmode(st))
else
model
end
ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
for epoch in 1:25
# Train the model
total_loss = 0.0f0
total_samples = 0
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
ad, lossfn, (x, y), train_state
)
total_loss += loss * length(y)
total_samples += length(y)
end
@printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples)
# Validate the model
total_acc = 0.0f0
total_loss = 0.0f0
total_samples = 0
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model_compiled(x, train_state.parameters, st_)
ŷ, y = cdev(ŷ), cdev(y)
total_acc += accuracy(ŷ, y) * length(y)
total_loss += lossfn(ŷ, y) * length(y)
total_samples += length(y)
end
@printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples)
end
return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.68110
Validation: Loss 0.63250 Accuracy 1.00000
Epoch [ 2]: Loss 0.60197
Validation: Loss 0.56777 Accuracy 1.00000
Epoch [ 3]: Loss 0.53570
Validation: Loss 0.50769 Accuracy 1.00000
Epoch [ 4]: Loss 0.47968
Validation: Loss 0.45797 Accuracy 1.00000
Epoch [ 5]: Loss 0.43037
Validation: Loss 0.41106 Accuracy 1.00000
Epoch [ 6]: Loss 0.37877
Validation: Loss 0.35582 Accuracy 1.00000
Epoch [ 7]: Loss 0.31701
Validation: Loss 0.27180 Accuracy 1.00000
Epoch [ 8]: Loss 0.21536
Validation: Loss 0.15994 Accuracy 1.00000
Epoch [ 9]: Loss 0.13222
Validation: Loss 0.10416 Accuracy 1.00000
Epoch [ 10]: Loss 0.08801
Validation: Loss 0.07033 Accuracy 1.00000
Epoch [ 11]: Loss 0.06101
Validation: Loss 0.04988 Accuracy 1.00000
Epoch [ 12]: Loss 0.04475
Validation: Loss 0.03786 Accuracy 1.00000
Epoch [ 13]: Loss 0.03493
Validation: Loss 0.03040 Accuracy 1.00000
Epoch [ 14]: Loss 0.02865
Validation: Loss 0.02539 Accuracy 1.00000
Epoch [ 15]: Loss 0.02445
Validation: Loss 0.02175 Accuracy 1.00000
Epoch [ 16]: Loss 0.02105
Validation: Loss 0.01891 Accuracy 1.00000
Epoch [ 17]: Loss 0.01849
Validation: Loss 0.01649 Accuracy 1.00000
Epoch [ 18]: Loss 0.01617
Validation: Loss 0.01426 Accuracy 1.00000
Epoch [ 19]: Loss 0.01392
Validation: Loss 0.01240 Accuracy 1.00000
Epoch [ 20]: Loss 0.01234
Validation: Loss 0.01104 Accuracy 1.00000
Epoch [ 21]: Loss 0.01112
Validation: Loss 0.01002 Accuracy 1.00000
Epoch [ 22]: Loss 0.01016
Validation: Loss 0.00922 Accuracy 1.00000
Epoch [ 23]: Loss 0.00938
Validation: Loss 0.00857 Accuracy 1.00000
Epoch [ 24]: Loss 0.00869
Validation: Loss 0.00802 Accuracy 1.00000
Epoch [ 25]: Loss 0.00835
Validation: Loss 0.00755 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.72050
Validation: Loss 0.66014 Accuracy 0.46875
Epoch [ 2]: Loss 0.60964
Validation: Loss 0.55337 Accuracy 0.77344
Epoch [ 3]: Loss 0.49122
Validation: Loss 0.45393 Accuracy 1.00000
Epoch [ 4]: Loss 0.42413
Validation: Loss 0.40921 Accuracy 1.00000
Epoch [ 5]: Loss 0.37817
Validation: Loss 0.36828 Accuracy 1.00000
Epoch [ 6]: Loss 0.34090
Validation: Loss 0.33366 Accuracy 1.00000
Epoch [ 7]: Loss 0.30455
Validation: Loss 0.30447 Accuracy 1.00000
Epoch [ 8]: Loss 0.27996
Validation: Loss 0.27991 Accuracy 1.00000
Epoch [ 9]: Loss 0.25610
Validation: Loss 0.25802 Accuracy 1.00000
Epoch [ 10]: Loss 0.23492
Validation: Loss 0.23604 Accuracy 1.00000
Epoch [ 11]: Loss 0.21325
Validation: Loss 0.20852 Accuracy 1.00000
Epoch [ 12]: Loss 0.18491
Validation: Loss 0.17322 Accuracy 1.00000
Epoch [ 13]: Loss 0.15022
Validation: Loss 0.13684 Accuracy 1.00000
Epoch [ 14]: Loss 0.11801
Validation: Loss 0.11003 Accuracy 1.00000
Epoch [ 15]: Loss 0.09662
Validation: Loss 0.09060 Accuracy 1.00000
Epoch [ 16]: Loss 0.07800
Validation: Loss 0.07317 Accuracy 1.00000
Epoch [ 17]: Loss 0.06205
Validation: Loss 0.05630 Accuracy 1.00000
Epoch [ 18]: Loss 0.04780
Validation: Loss 0.04332 Accuracy 1.00000
Epoch [ 19]: Loss 0.03726
Validation: Loss 0.03325 Accuracy 1.00000
Epoch [ 20]: Loss 0.02849
Validation: Loss 0.02503 Accuracy 1.00000
Epoch [ 21]: Loss 0.02154
Validation: Loss 0.01873 Accuracy 1.00000
Epoch [ 22]: Loss 0.01614
Validation: Loss 0.01460 Accuracy 1.00000
Epoch [ 23]: Loss 0.01301
Validation: Loss 0.01205 Accuracy 1.00000
Epoch [ 24]: Loss 0.01108
Validation: Loss 0.01033 Accuracy 1.00000
Epoch [ 25]: Loss 0.00961
Validation: Loss 0.00916 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.