
Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will yield sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end

function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
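
As a quick sanity check (a minimal sketch, assuming the functions above have been evaluated), we can pull one batch from the training loader and confirm the shapes the model will see:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence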

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To learn more about container layers, see the Container Layer documentation.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; rather, we will use the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
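
Because we declared the field names on AbstractLuxContainerLayer, Lux.setup recursively initializes both sub-layers and returns parameter and state NamedTuples keyed by those fields. A minimal sketch:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)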

We could have used the built-in Recurrence wrapper – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of unrolling the sequence ourselves, but writing the loop by hand is instructive.
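
For reference, a sketch of the Recurrence-based equivalent (assuming Recurrence's default of returning only the final time step, and that our 2 × seq_len × batch layout matches its default batch-last ordering):

julia
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),  # run the cell over the sequence, keep the last output
        Dense(hidden_dims => out_dims, sigmoid),
        vec,  # flatten the (1, batch) output to a vector, matching our manual forward pass
    )
end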

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states were automatically populated into a field
    # called `lstm_cell`. We use `eachslice` to get the elements of the sequence
    # without copying, and `Iterators.peel` to split off the first element for LSTM
    # initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this approach for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
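
The compact layer is set up and called exactly like the hand-written one (a minimal sketch):

julia
model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
y, st_ = model(rand(Float32, 2, 50, 4), ps, st)
size(y)  # (4,): one prediction per sequence in the batch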

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
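
If you do want the numerically stable variant, a sketch would be to drop the sigmoid from the Dense head and construct the loss in logit mode (assuming BinaryCrossEntropyLoss exposes the logits keyword):

julia
# Sketch: numerically stable alternative (not used in this tutorial).
# The classifier head would become `Dense(hidden_dims => out_dims)` (no sigmoid),
# and predictions would be thresholded at 0 instead of 0.5.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))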

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-08-02 15:23:34.133985: I external/xla/xla/service/service.cc:163] XLA service 0x11a3e710 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-02 15:23:34.134024: I external/xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754148214.134984 4072956 se_gpu_pjrt_client.cc:1380] Using BFC allocator.
I0000 00:00:1754148214.135080 4072956 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754148214.135140 4072956 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-02 15:23:34.148950: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
Epoch [  1]: Loss 0.58661
Validation:	Loss 0.51742	Accuracy 1.00000
Epoch [  2]: Loss 0.48195
Validation:	Loss 0.42456	Accuracy 1.00000
Epoch [  3]: Loss 0.39707
Validation:	Loss 0.34845	Accuracy 1.00000
Epoch [  4]: Loss 0.32691
Validation:	Loss 0.28922	Accuracy 1.00000
Epoch [  5]: Loss 0.27600
Validation:	Loss 0.24548	Accuracy 1.00000
Epoch [  6]: Loss 0.23739
Validation:	Loss 0.21099	Accuracy 1.00000
Epoch [  7]: Loss 0.20368
Validation:	Loss 0.18194	Accuracy 1.00000
Epoch [  8]: Loss 0.17671
Validation:	Loss 0.15745	Accuracy 1.00000
Epoch [  9]: Loss 0.15375
Validation:	Loss 0.13762	Accuracy 1.00000
Epoch [ 10]: Loss 0.13480
Validation:	Loss 0.12163	Accuracy 1.00000
Epoch [ 11]: Loss 0.11921
Validation:	Loss 0.10834	Accuracy 1.00000
Epoch [ 12]: Loss 0.10704
Validation:	Loss 0.09710	Accuracy 1.00000
Epoch [ 13]: Loss 0.09640
Validation:	Loss 0.08742	Accuracy 1.00000
Epoch [ 14]: Loss 0.08546
Validation:	Loss 0.07893	Accuracy 1.00000
Epoch [ 15]: Loss 0.07784
Validation:	Loss 0.07129	Accuracy 1.00000
Epoch [ 16]: Loss 0.07010
Validation:	Loss 0.06411	Accuracy 1.00000
Epoch [ 17]: Loss 0.06301
Validation:	Loss 0.05720	Accuracy 1.00000
Epoch [ 18]: Loss 0.05583
Validation:	Loss 0.04998	Accuracy 1.00000
Epoch [ 19]: Loss 0.04787
Validation:	Loss 0.04198	Accuracy 1.00000
Epoch [ 20]: Loss 0.03942
Validation:	Loss 0.03286	Accuracy 1.00000
Epoch [ 21]: Loss 0.02954
Validation:	Loss 0.02403	Accuracy 1.00000
Epoch [ 22]: Loss 0.02133
Validation:	Loss 0.01748	Accuracy 1.00000
Epoch [ 23]: Loss 0.01583
Validation:	Loss 0.01382	Accuracy 1.00000
Epoch [ 24]: Loss 0.01301
Validation:	Loss 0.01208	Accuracy 1.00000
Epoch [ 25]: Loss 0.01156
Validation:	Loss 0.01090	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.49864
Validation:	Loss 0.41509	Accuracy 1.00000
Epoch [  2]: Loss 0.36967
Validation:	Loss 0.31113	Accuracy 1.00000
Epoch [  3]: Loss 0.27462
Validation:	Loss 0.22924	Accuracy 1.00000
Epoch [  4]: Loss 0.20196
Validation:	Loss 0.16798	Accuracy 1.00000
Epoch [  5]: Loss 0.14889
Validation:	Loss 0.12616	Accuracy 1.00000
Epoch [  6]: Loss 0.11370
Validation:	Loss 0.09916	Accuracy 1.00000
Epoch [  7]: Loss 0.09043
Validation:	Loss 0.07981	Accuracy 1.00000
Epoch [  8]: Loss 0.07377
Validation:	Loss 0.06643	Accuracy 1.00000
Epoch [  9]: Loss 0.06180
Validation:	Loss 0.05630	Accuracy 1.00000
Epoch [ 10]: Loss 0.05304
Validation:	Loss 0.04873	Accuracy 1.00000
Epoch [ 11]: Loss 0.04609
Validation:	Loss 0.04272	Accuracy 1.00000
Epoch [ 12]: Loss 0.04048
Validation:	Loss 0.03755	Accuracy 1.00000
Epoch [ 13]: Loss 0.03565
Validation:	Loss 0.03308	Accuracy 1.00000
Epoch [ 14]: Loss 0.03125
Validation:	Loss 0.02880	Accuracy 1.00000
Epoch [ 15]: Loss 0.02717
Validation:	Loss 0.02505	Accuracy 1.00000
Epoch [ 16]: Loss 0.02382
Validation:	Loss 0.02231	Accuracy 1.00000
Epoch [ 17]: Loss 0.02132
Validation:	Loss 0.02013	Accuracy 1.00000
Epoch [ 18]: Loss 0.01936
Validation:	Loss 0.01840	Accuracy 1.00000
Epoch [ 19]: Loss 0.01775
Validation:	Loss 0.01694	Accuracy 1.00000
Epoch [ 20]: Loss 0.01636
Validation:	Loss 0.01565	Accuracy 1.00000
Epoch [ 21]: Loss 0.01517
Validation:	Loss 0.01456	Accuracy 1.00000
Epoch [ 22]: Loss 0.01414
Validation:	Loss 0.01358	Accuracy 1.00000
Epoch [ 23]: Loss 0.01321
Validation:	Loss 0.01271	Accuracy 1.00000
Epoch [ 24]: Loss 0.01237
Validation:	Loss 0.01193	Accuracy 1.00000
Epoch [ 25]: Loss 0.01163
Validation:	Loss 0.01122	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
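
To run inference with the loaded weights, rebuild the model structure and call it in test mode (a minimal sketch on CPU, since we moved the parameters to CPU before saving):

julia
model = SpiralClassifier(2, 8, 1)
x = randn(Float32, 2, 50, 1)  # one dummy sequence: 2 features × 50 steps × 1 sample
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))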

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.