Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial, you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will yield sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end

function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
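
As a quick sanity check (this snippet is illustrative and not part of the original script), we can pull one batch and confirm the shapes described above:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,)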

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
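
Because we declared the field names on the container type, Lux.setup keys the parameters and states by those names. A quick illustrative check (not in the original tutorial):

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)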

We could have used the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass below, but we will write it out manually for the sake of illustration; a sketch of that alternative follows.
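
For reference, here is a minimal sketch of that built-in route (SpiralClassifierBuiltin is a hypothetical name, and we assume the default Recurrence behavior of returning only the last time step's output):

julia
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        # Runs the cell over the sequence dimension; by default only the final
        # time step's output is returned
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec,
    )
end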

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
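
To see the forward pass in action, we can set up the model and call it on a random batch (an illustrative snippet, not part of the original script):

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
y, st_ = model(randn(Float32, 2, 50, 4), ps, st)
size(y)  # (4,): one probability per sequence in the batch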

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
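
The layer returned by @compact follows the usual Lux (output, state) calling convention, so setup and invocation look identical to the hand-written version (illustrative):

julia
model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
y, _ = model(randn(Float32, 2, 50, 4), ps, st)  # same convention as before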

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. It is typically recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy here; a sketch of the logit variant appears after the code below.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
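
For completeness, here is a hedged sketch of the logit formulation mentioned above; it assumes you drop the sigmoid from the classifier head so the model emits raw logits:

julia
# Sketch only: pair a logit-valued head with the logit form of the loss
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))
# ...and in the model constructor use `Dense(hidden_dims => out_dims)` (no sigmoid)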

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56466
Validation:	Loss 0.46252	Accuracy 1.00000
Epoch [  2]: Loss 0.41123
Validation:	Loss 0.36907	Accuracy 1.00000
Epoch [  3]: Loss 0.33550
Validation:	Loss 0.28746	Accuracy 1.00000
Epoch [  4]: Loss 0.23725
Validation:	Loss 0.17010	Accuracy 1.00000
Epoch [  5]: Loss 0.14055
Validation:	Loss 0.10802	Accuracy 1.00000
Epoch [  6]: Loss 0.09527
Validation:	Loss 0.07789	Accuracy 1.00000
Epoch [  7]: Loss 0.06806
Validation:	Loss 0.05487	Accuracy 1.00000
Epoch [  8]: Loss 0.04801
Validation:	Loss 0.03822	Accuracy 1.00000
Epoch [  9]: Loss 0.03439
Validation:	Loss 0.02883	Accuracy 1.00000
Epoch [ 10]: Loss 0.02655
Validation:	Loss 0.02308	Accuracy 1.00000
Epoch [ 11]: Loss 0.02132
Validation:	Loss 0.01895	Accuracy 1.00000
Epoch [ 12]: Loss 0.01792
Validation:	Loss 0.01626	Accuracy 1.00000
Epoch [ 13]: Loss 0.01549
Validation:	Loss 0.01425	Accuracy 1.00000
Epoch [ 14]: Loss 0.01370
Validation:	Loss 0.01270	Accuracy 1.00000
Epoch [ 15]: Loss 0.01226
Validation:	Loss 0.01149	Accuracy 1.00000
Epoch [ 16]: Loss 0.01112
Validation:	Loss 0.01048	Accuracy 1.00000
Epoch [ 17]: Loss 0.01020
Validation:	Loss 0.00964	Accuracy 1.00000
Epoch [ 18]: Loss 0.00939
Validation:	Loss 0.00892	Accuracy 1.00000
Epoch [ 19]: Loss 0.00872
Validation:	Loss 0.00830	Accuracy 1.00000
Epoch [ 20]: Loss 0.00812
Validation:	Loss 0.00775	Accuracy 1.00000
Epoch [ 21]: Loss 0.00758
Validation:	Loss 0.00726	Accuracy 1.00000
Epoch [ 22]: Loss 0.00712
Validation:	Loss 0.00683	Accuracy 1.00000
Epoch [ 23]: Loss 0.00670
Validation:	Loss 0.00643	Accuracy 1.00000
Epoch [ 24]: Loss 0.00632
Validation:	Loss 0.00608	Accuracy 1.00000
Epoch [ 25]: Loss 0.00598
Validation:	Loss 0.00575	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.60442
Validation:	Loss 0.54005	Accuracy 0.51562
Epoch [  2]: Loss 0.51477
Validation:	Loss 0.46139	Accuracy 1.00000
Epoch [  3]: Loss 0.43498
Validation:	Loss 0.38172	Accuracy 1.00000
Epoch [  4]: Loss 0.34525
Validation:	Loss 0.27904	Accuracy 1.00000
Epoch [  5]: Loss 0.23912
Validation:	Loss 0.17893	Accuracy 1.00000
Epoch [  6]: Loss 0.14934
Validation:	Loss 0.11233	Accuracy 1.00000
Epoch [  7]: Loss 0.09825
Validation:	Loss 0.08111	Accuracy 1.00000
Epoch [  8]: Loss 0.07412
Validation:	Loss 0.06303	Accuracy 1.00000
Epoch [  9]: Loss 0.05821
Validation:	Loss 0.05079	Accuracy 1.00000
Epoch [ 10]: Loss 0.04781
Validation:	Loss 0.04243	Accuracy 1.00000
Epoch [ 11]: Loss 0.04056
Validation:	Loss 0.03627	Accuracy 1.00000
Epoch [ 12]: Loss 0.03465
Validation:	Loss 0.03139	Accuracy 1.00000
Epoch [ 13]: Loss 0.03023
Validation:	Loss 0.02747	Accuracy 1.00000
Epoch [ 14]: Loss 0.02644
Validation:	Loss 0.02427	Accuracy 1.00000
Epoch [ 15]: Loss 0.02327
Validation:	Loss 0.02157	Accuracy 1.00000
Epoch [ 16]: Loss 0.02095
Validation:	Loss 0.01928	Accuracy 1.00000
Epoch [ 17]: Loss 0.01876
Validation:	Loss 0.01735	Accuracy 1.00000
Epoch [ 18]: Loss 0.01683
Validation:	Loss 0.01566	Accuracy 1.00000
Epoch [ 19]: Loss 0.01535
Validation:	Loss 0.01420	Accuracy 1.00000
Epoch [ 20]: Loss 0.01382
Validation:	Loss 0.01293	Accuracy 1.00000
Epoch [ 21]: Loss 0.01260
Validation:	Loss 0.01183	Accuracy 1.00000
Epoch [ 22]: Loss 0.01165
Validation:	Loss 0.01086	Accuracy 1.00000
Epoch [ 23]: Loss 0.01067
Validation:	Loss 0.01001	Accuracy 1.00000
Epoch [ 24]: Loss 0.00986
Validation:	Loss 0.00926	Accuracy 1.00000
Epoch [ 25]: Loss 0.00912
Validation:	Loss 0.00859	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct; save only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
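
To run inference with the reloaded parameters, rebuild the model and call it in test mode (an illustrative sketch; x_batch is a hypothetical input shaped like the training data):

julia
model = SpiralClassifier(2, 8, 1)
x_batch = randn(Float32, 2, 50, 4)  # hypothetical 2 × seq_len × batch input
y_pred, _ = model(x_batch, ps_trained, Lux.testmode(st_trained))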

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.