Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial, you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and load it with using Zygote.
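
For example, assuming a standard Pkg-based workflow (a sketch; adjust to your own environment):

julia
using Pkg
Pkg.add("Zygote")  # add Zygote to the active project environment

using Zygote       # makes AutoZygote() usable as the AD backend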

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and wrap the data in an MLUtils.DataLoader. The dataloader will yield sequences of size 2 × seq_len × batch_size, and our task is to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end

function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
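
As a quick sanity check on the shapes (a sketch, assuming the default dataset_size=1000 and sequence_length=50):

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): 2 coordinates × 50 time steps × 128 samples per batch
size(y)  # (128,): one binary label per sample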

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the layers from scratch; instead, we will use the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
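
Because we declared the field names on the container type, Lux.setup returns parameters and states keyed by those names (a quick sketch):

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)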

We could use the built-in Recurrence(LSTMCell(in_dims => hidden_dims)) block instead of defining the recurrence ourselves (see the sketch below), but let's still write it out for the sake of illustration.
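
For reference, here is a minimal sketch of that alternative. We assume the default Recurrence behavior of folding the cell over the time dimension of a (features, time, batch) input and returning only the final output:

julia
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),  # runs the cell over the sequence
        Dense(hidden_dims => out_dims, sigmoid),       # same classifier head as above
    )
end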

Now we need to define the behavior of the classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
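
A minimal usage sketch (assuming dummy data in the same 2 × seq_len × batch_size layout as our dataset):

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 2, 50, 4)   # 4 dummy sequences of length 50
ŷ, st_new = model(x, ps, st)
size(ŷ)  # (4,): one probability per sequence, thanks to `vec`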

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. It is typically recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy here (a logit-based sketch follows the code below).

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
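
For reference, a sketch of the numerically stabler logit-based setup (our assumption, not what this tutorial runs): drop the sigmoid from the classifier head and let the loss apply it internally in a fused, stable way.

julia
# Hypothetical variant: the classifier head emits raw logits (no sigmoid) ...
logit_classifier = Dense(8 => 1)
# ... and the loss fuses the sigmoid into the cross-entropy computation
logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))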

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-08-05 23:25:19.962283: I external/xla/xla/service/service.cc:163] XLA service 0x46473b90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-05 23:25:19.962369: I external/xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754436319.963208  319589 se_gpu_pjrt_client.cc:1373] Using BFC allocator.
I0000 00:00:1754436319.963396  319589 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754436319.963549  319589 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-05 23:25:19.979836: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
Epoch [  1]: Loss 0.49418
Validation:	Loss 0.42908	Accuracy 0.50781
Epoch [  2]: Loss 0.41481
Validation:	Loss 0.38093	Accuracy 1.00000
Epoch [  3]: Loss 0.36448
Validation:	Loss 0.33782	Accuracy 1.00000
Epoch [  4]: Loss 0.31993
Validation:	Loss 0.28648	Accuracy 1.00000
Epoch [  5]: Loss 0.26585
Validation:	Loss 0.22743	Accuracy 1.00000
Epoch [  6]: Loss 0.20549
Validation:	Loss 0.16805	Accuracy 1.00000
Epoch [  7]: Loss 0.14619
Validation:	Loss 0.11574	Accuracy 1.00000
Epoch [  8]: Loss 0.09926
Validation:	Loss 0.07505	Accuracy 1.00000
Epoch [  9]: Loss 0.06308
Validation:	Loss 0.04703	Accuracy 1.00000
Epoch [ 10]: Loss 0.03952
Validation:	Loss 0.03005	Accuracy 1.00000
Epoch [ 11]: Loss 0.02618
Validation:	Loss 0.02109	Accuracy 1.00000
Epoch [ 12]: Loss 0.01878
Validation:	Loss 0.01577	Accuracy 1.00000
Epoch [ 13]: Loss 0.01435
Validation:	Loss 0.01234	Accuracy 1.00000
Epoch [ 14]: Loss 0.01140
Validation:	Loss 0.01007	Accuracy 1.00000
Epoch [ 15]: Loss 0.00946
Validation:	Loss 0.00854	Accuracy 1.00000
Epoch [ 16]: Loss 0.00817
Validation:	Loss 0.00744	Accuracy 1.00000
Epoch [ 17]: Loss 0.00711
Validation:	Loss 0.00662	Accuracy 1.00000
Epoch [ 18]: Loss 0.00636
Validation:	Loss 0.00594	Accuracy 1.00000
Epoch [ 19]: Loss 0.00574
Validation:	Loss 0.00538	Accuracy 1.00000
Epoch [ 20]: Loss 0.00523
Validation:	Loss 0.00490	Accuracy 1.00000
Epoch [ 21]: Loss 0.00476
Validation:	Loss 0.00449	Accuracy 1.00000
Epoch [ 22]: Loss 0.00438
Validation:	Loss 0.00414	Accuracy 1.00000
Epoch [ 23]: Loss 0.00403
Validation:	Loss 0.00382	Accuracy 1.00000
Epoch [ 24]: Loss 0.00374
Validation:	Loss 0.00353	Accuracy 1.00000
Epoch [ 25]: Loss 0.00345
Validation:	Loss 0.00326	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.75722
Validation:	Loss 0.65034	Accuracy 0.50000
Epoch [  2]: Loss 0.60343
Validation:	Loss 0.54516	Accuracy 0.50000
Epoch [  3]: Loss 0.50653
Validation:	Loss 0.45937	Accuracy 1.00000
Epoch [  4]: Loss 0.42980
Validation:	Loss 0.39885	Accuracy 1.00000
Epoch [  5]: Loss 0.37428
Validation:	Loss 0.35542	Accuracy 1.00000
Epoch [  6]: Loss 0.33829
Validation:	Loss 0.32162	Accuracy 1.00000
Epoch [  7]: Loss 0.30593
Validation:	Loss 0.29370	Accuracy 1.00000
Epoch [  8]: Loss 0.28075
Validation:	Loss 0.26765	Accuracy 1.00000
Epoch [  9]: Loss 0.25346
Validation:	Loss 0.24088	Accuracy 1.00000
Epoch [ 10]: Loss 0.22605
Validation:	Loss 0.21062	Accuracy 1.00000
Epoch [ 11]: Loss 0.19468
Validation:	Loss 0.17477	Accuracy 1.00000
Epoch [ 12]: Loss 0.15650
Validation:	Loss 0.13699	Accuracy 1.00000
Epoch [ 13]: Loss 0.12160
Validation:	Loss 0.10415	Accuracy 1.00000
Epoch [ 14]: Loss 0.09041
Validation:	Loss 0.07536	Accuracy 1.00000
Epoch [ 15]: Loss 0.06361
Validation:	Loss 0.04965	Accuracy 1.00000
Epoch [ 16]: Loss 0.04194
Validation:	Loss 0.03421	Accuracy 1.00000
Epoch [ 17]: Loss 0.03058
Validation:	Loss 0.02691	Accuracy 1.00000
Epoch [ 18]: Loss 0.02480
Validation:	Loss 0.02259	Accuracy 1.00000
Epoch [ 19]: Loss 0.02120
Validation:	Loss 0.01950	Accuracy 1.00000
Epoch [ 20]: Loss 0.01838
Validation:	Loss 0.01715	Accuracy 1.00000
Epoch [ 21]: Loss 0.01631
Validation:	Loss 0.01526	Accuracy 1.00000
Epoch [ 22]: Loss 0.01452
Validation:	Loss 0.01364	Accuracy 1.00000
Epoch [ 23]: Loss 0.01302
Validation:	Loss 0.01224	Accuracy 1.00000
Epoch [ 24]: Loss 0.01166
Validation:	Loss 0.01091	Accuracy 1.00000
Epoch [ 25]: Loss 0.01034
Validation:	Loss 0.00953	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
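
Since only the parameters and states were saved, restoring the model means reconstructing the layer structure in code and reattaching the loaded values (a sketch, assuming the SpiralClassifier definition from above is in scope):

julia
model = SpiralClassifier(2, 8, 1)        # must match the trained architecture
@load "trained_model.jld2" ps_trained st_trained
st_inference = Lux.testmode(st_trained)  # switch states to inference mode
x = rand(Float32, 2, 50, 4)              # dummy batch in the expected layout
ŷ, _ = model(x, ps_trained, st_inference)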

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.