Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Precompiling JLD2...
   4297.8 ms  ✓ FileIO
  32260.2 ms  ✓ JLD2
  2 dependencies successfully precompiled in 37 seconds. 30 already precompiled.
Precompiling MLUtils...
    445.6 ms  ✓ DelimitedFiles
    841.4 ms  ✓ BangBang
    589.9 ms  ✓ ContextVariablesX
   1193.0 ms  ✓ SimpleTraits
    787.6 ms  ✓ InverseFunctions → InverseFunctionsTestExt
    663.9 ms  ✓ Accessors → TestExt
    682.7 ms  ✓ BangBang → BangBangStaticArraysExt
    506.9 ms  ✓ BangBang → BangBangChainRulesCoreExt
   1398.3 ms  ✓ SplittablesBase
    509.5 ms  ✓ BangBang → BangBangTablesExt
    599.8 ms  ✓ FLoopsBase
   1052.9 ms  ✓ MicroCollections
   1071.6 ms  ✓ MLCore
   2708.3 ms  ✓ Transducers
    674.2 ms  ✓ Transducers → TransducersAdaptExt
   5380.0 ms  ✓ FLoops
   5958.9 ms  ✓ MLUtils
  17 dependencies successfully precompiled in 18 seconds. 83 already precompiled.
Precompiling MLDataDevicesSparseArraysExt...
    718.1 ms  ✓ MLDataDevices → MLDataDevicesSparseArraysExt
  1 dependency successfully precompiled in 1 seconds. 18 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
   1720.8 ms  ✓ MLDataDevices → MLDataDevicesMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 104 already precompiled.
Precompiling LuxMLUtilsExt...
   2349.8 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 3 seconds. 170 already precompiled.
Precompiling Reactant...
    599.6 ms  ✓ ReactantCore
    639.7 ms  ✓ LLVMOpenMP_jll
    937.7 ms  ✓ CUDA_Driver_jll
   2249.3 ms  ✓ Reactant_jll
 219755.0 ms  ✓ Enzyme
   5864.0 ms  ✓ Enzyme → EnzymeGPUArraysCoreExt
  77146.1 ms  ✓ Reactant
  7 dependencies successfully precompiled in 303 seconds. 57 already precompiled.
Precompiling LuxLibEnzymeExt...
   6077.1 ms  ✓ Enzyme → EnzymeSpecialFunctionsExt
   8324.4 ms  ✓ Enzyme → EnzymeStaticArraysExt
   1315.3 ms  ✓ LuxLib → LuxLibEnzymeExt
  11395.0 ms  ✓ Enzyme → EnzymeChainRulesCoreExt
   6089.9 ms  ✓ Enzyme → EnzymeLogExpFunctionsExt
  5 dependencies successfully precompiled in 13 seconds. 128 already precompiled.
Precompiling LuxEnzymeExt...
   6904.0 ms  ✓ Lux → LuxEnzymeExt
  1 dependency successfully precompiled in 7 seconds. 148 already precompiled.
Precompiling LuxCoreReactantExt...
  12365.0 ms  ✓ LuxCore → LuxCoreReactantExt
  1 dependency successfully precompiled in 13 seconds. 69 already precompiled.
Precompiling MLDataDevicesReactantExt...
  12808.9 ms  ✓ MLDataDevices → MLDataDevicesReactantExt
  1 dependency successfully precompiled in 13 seconds. 66 already precompiled.
Precompiling LuxLibReactantExt...
  12681.4 ms  ✓ Reactant → ReactantStatisticsExt
  13285.7 ms  ✓ Reactant → ReactantNNlibExt
  13317.6 ms  ✓ LuxLib → LuxLibReactantExt
  12721.6 ms  ✓ Reactant → ReactantKernelAbstractionsExt
  12625.0 ms  ✓ Reactant → ReactantArrayInterfaceExt
  12737.0 ms  ✓ Reactant → ReactantSpecialFunctionsExt
  6 dependencies successfully precompiled in 26 seconds. 143 already precompiled.
Precompiling WeightInitializersReactantExt...
  12780.0 ms  ✓ WeightInitializers → WeightInitializersReactantExt
  1 dependency successfully precompiled in 13 seconds. 80 already precompiled.
Precompiling LuxReactantExt...
  10361.9 ms  ✓ Lux → LuxReactantExt
  1 dependency successfully precompiled in 11 seconds. 166 already precompiled.

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. Our dataloader will yield sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size = 1000, sequence_length = 50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
            for d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(
                d[1][:, (sequence_length + 1):end], :, sequence_length, 1
            )
            for d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, shuffle = true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize = 128, shuffle = true, partial = false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false, partial = false),
    )
end
get_dataloaders (generic function with 1 method)
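
As a quick sanity check, we can peek at one batch; a minimal sketch (the shapes assume the default keyword arguments above):

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one label per sequence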

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will compose the built-in Lux.LSTMCell and Lux.Dense layers.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier

We could use the stock Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass below, but we will write it out by hand to demonstrate the recurrent API.
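
For reference, here is a minimal sketch of that stock-blocks version (the name SpiralClassifierStock is ours; by default, Recurrence feeds only the final time step's output forward):

julia
# Hypothetical equivalent built from stock Lux blocks
SpiralClassifierStock(in_dims, hidden_dims, out_dims) = Chain(
    # Iterates the cell over the time dimension and returns the final output
    Recurrence(LSTMCell(in_dims => hidden_dims)),
    Dense(hidden_dims => out_dims, sigmoid),
    vec,
)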

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple
    ) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier = st_classifier, lstm_cell = st_lstm))
    return vec(y), st
end
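
Before wiring this into a training loop, we can smoke-test the forward pass on random inputs; a minimal sketch:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 50, 4)  # 4 hypothetical sequences of length 50
ŷ, st_new = model(x, ps, st)
size(ŷ)  # (4,): one probability per sequence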

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate for you, so we recommend this approach for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
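
From the caller's perspective the compact model is a drop-in replacement for the hand-written one; a quick sketch:

julia
model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
ŷ, _ = model(randn(Float32, 2, 50, 4), ps, st)
size(ŷ)  # (4,): same interface as SpiralClassifier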

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically, using logitbinarycrossentropy is recommended since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred = ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
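
If you do want the numerically stabler formulation, a minimal sketch of the change (this assumes you also drop the sigmoid from the classifier head, since the loss then expects raw logits):

julia
# Hypothetical alternative: keep the classifier head linear, i.e.
# Dense(hidden_dims => out_dims) without the sigmoid, and let the loss
# apply the numerically stable log-sigmoid internally.
const logit_lossfn = BinaryCrossEntropyLoss(; logits = Val(true))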

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (total_acc / total_samples)
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-03-07 23:46:55.628278: I external/xla/xla/service/service.cc:152] XLA service 0x9ae9d50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-07 23:46:55.628438: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741391215.629255 3638329 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741391215.629320 3638329 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741391215.629364 3638329 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741391215.644784 3638329 cuda_dnn.cc:529] Loaded cuDNN version 90400
Epoch [  1]: Loss 0.65806
Validation:	Loss 0.58377	Accuracy 1.00000
Epoch [  2]: Loss 0.53889
Validation:	Loss 0.48049	Accuracy 1.00000
Epoch [  3]: Loss 0.44765
Validation:	Loss 0.39908	Accuracy 1.00000
Epoch [  4]: Loss 0.37392
Validation:	Loss 0.33268	Accuracy 1.00000
Epoch [  5]: Loss 0.31215
Validation:	Loss 0.27409	Accuracy 1.00000
Epoch [  6]: Loss 0.25143
Validation:	Loss 0.20997	Accuracy 1.00000
Epoch [  7]: Loss 0.18228
Validation:	Loss 0.15024	Accuracy 1.00000
Epoch [  8]: Loss 0.13617
Validation:	Loss 0.11492	Accuracy 1.00000
Epoch [  9]: Loss 0.10375
Validation:	Loss 0.08718	Accuracy 1.00000
Epoch [ 10]: Loss 0.07819
Validation:	Loss 0.06623	Accuracy 1.00000
Epoch [ 11]: Loss 0.05933
Validation:	Loss 0.05051	Accuracy 1.00000
Epoch [ 12]: Loss 0.04508
Validation:	Loss 0.03884	Accuracy 1.00000
Epoch [ 13]: Loss 0.03473
Validation:	Loss 0.03074	Accuracy 1.00000
Epoch [ 14]: Loss 0.02794
Validation:	Loss 0.02558	Accuracy 1.00000
Epoch [ 15]: Loss 0.02355
Validation:	Loss 0.02204	Accuracy 1.00000
Epoch [ 16]: Loss 0.02055
Validation:	Loss 0.01939	Accuracy 1.00000
Epoch [ 17]: Loss 0.01814
Validation:	Loss 0.01732	Accuracy 1.00000
Epoch [ 18]: Loss 0.01625
Validation:	Loss 0.01565	Accuracy 1.00000
Epoch [ 19]: Loss 0.01473
Validation:	Loss 0.01429	Accuracy 1.00000
Epoch [ 20]: Loss 0.01348
Validation:	Loss 0.01314	Accuracy 1.00000
Epoch [ 21]: Loss 0.01241
Validation:	Loss 0.01217	Accuracy 1.00000
Epoch [ 22]: Loss 0.01141
Validation:	Loss 0.01133	Accuracy 1.00000
Epoch [ 23]: Loss 0.01076
Validation:	Loss 0.01060	Accuracy 1.00000
Epoch [ 24]: Loss 0.01004
Validation:	Loss 0.00996	Accuracy 1.00000
Epoch [ 25]: Loss 0.00941
Validation:	Loss 0.00939	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.56069
Validation:	Loss 0.50002	Accuracy 0.50781
Epoch [  2]: Loss 0.45935
Validation:	Loss 0.40520	Accuracy 0.50781
Epoch [  3]: Loss 0.39022
Validation:	Loss 0.34996	Accuracy 1.00000
Epoch [  4]: Loss 0.32765
Validation:	Loss 0.28566	Accuracy 1.00000
Epoch [  5]: Loss 0.26878
Validation:	Loss 0.22812	Accuracy 1.00000
Epoch [  6]: Loss 0.20919
Validation:	Loss 0.17706	Accuracy 1.00000
Epoch [  7]: Loss 0.15947
Validation:	Loss 0.13204	Accuracy 1.00000
Epoch [  8]: Loss 0.11734
Validation:	Loss 0.09225	Accuracy 1.00000
Epoch [  9]: Loss 0.08175
Validation:	Loss 0.06725	Accuracy 1.00000
Epoch [ 10]: Loss 0.06103
Validation:	Loss 0.05031	Accuracy 1.00000
Epoch [ 11]: Loss 0.04458
Validation:	Loss 0.03722	Accuracy 1.00000
Epoch [ 12]: Loss 0.03371
Validation:	Loss 0.02848	Accuracy 1.00000
Epoch [ 13]: Loss 0.02671
Validation:	Loss 0.02289	Accuracy 1.00000
Epoch [ 14]: Loss 0.02178
Validation:	Loss 0.01941	Accuracy 1.00000
Epoch [ 15]: Loss 0.01880
Validation:	Loss 0.01707	Accuracy 1.00000
Epoch [ 16]: Loss 0.01657
Validation:	Loss 0.01535	Accuracy 1.00000
Epoch [ 17]: Loss 0.01505
Validation:	Loss 0.01396	Accuracy 1.00000
Epoch [ 18]: Loss 0.01353
Validation:	Loss 0.01277	Accuracy 1.00000
Epoch [ 19]: Loss 0.01258
Validation:	Loss 0.01171	Accuracy 1.00000
Epoch [ 20]: Loss 0.01151
Validation:	Loss 0.01066	Accuracy 1.00000
Epoch [ 21]: Loss 0.01035
Validation:	Loss 0.00954	Accuracy 1.00000
Epoch [ 22]: Loss 0.00910
Validation:	Loss 0.00826	Accuracy 1.00000
Epoch [ 23]: Loss 0.00791
Validation:	Loss 0.00705	Accuracy 1.00000
Epoch [ 24]: Loss 0.00676
Validation:	Loss 0.00615	Accuracy 1.00000
Epoch [ 25]: Loss 0.00591
Validation:	Loss 0.00544	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
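
Once loaded, the parameters and states plug straight back into a freshly constructed model for inference. A minimal sketch (the input here is random and purely illustrative):

julia
model = SpiralClassifier(2, 8, 1)
x = randn(Float32, 2, 50, 1)  # one hypothetical sequence of length 50
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
ŷ[1] > 0.5f0  # probability > 0.5 ⇒ predicted anticlockwise (label 1)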

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.