Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
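
We can sanity-check the shapes described above by peeking at the first batch (a quick sketch; the actual contents of each batch depend on shuffling):

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,)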

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier

We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the recurrence ourselves, but let's write it out for the sake of the tutorial (see the sketch below).
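
For reference, here is a minimal sketch of that alternative, assuming Recurrence's default behavior of returning only the output at the final time step (SpiralClassifierRecurrence is a name we introduce here):

julia
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    # Recurrence applies the cell across the time dimension and (by default)
    # returns the output at the last time step
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec,
    )
end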

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

# Count predictions that match the labels after thresholding the sigmoid output at 0.5
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
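
If you want the numerically stabler logit formulation mentioned above, here is a minimal sketch, assuming your Lux version supports the logits keyword on BinaryCrossEntropyLoss:

julia
# Compute the loss on raw logits instead of sigmoid outputs (sketch). The
# classifier head would then be Dense(hidden_dims => out_dims) without the
# sigmoid, and probabilities are recovered with sigmoid.(ŷ) for accuracy.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))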

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-05-23 22:57:28.654724: I external/xla/xla/service/service.cc:152] XLA service 0xc537210 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-23 22:57:28.654821: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748041048.655508  800117 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1748041048.655588  800117 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 0 for BFCAllocator.
I0000 00:00:1748041048.655619  800117 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1748041048.666862  800117 cuda_dnn.cc:529] Loaded cuDNN version 90400
Epoch [  1]: Loss 0.72328
Validation:	Loss 0.66961	Accuracy 0.47656
Epoch [  2]: Loss 0.62975
Validation:	Loss 0.60080	Accuracy 0.47656
Epoch [  3]: Loss 0.57419
Validation:	Loss 0.55200	Accuracy 1.00000
Epoch [  4]: Loss 0.52276
Validation:	Loss 0.49260	Accuracy 1.00000
Epoch [  5]: Loss 0.45420
Validation:	Loss 0.40837	Accuracy 1.00000
Epoch [  6]: Loss 0.35786
Validation:	Loss 0.29874	Accuracy 1.00000
Epoch [  7]: Loss 0.25151
Validation:	Loss 0.20315	Accuracy 1.00000
Epoch [  8]: Loss 0.17247
Validation:	Loss 0.14086	Accuracy 1.00000
Epoch [  9]: Loss 0.12162
Validation:	Loss 0.10067	Accuracy 1.00000
Epoch [ 10]: Loss 0.08789
Validation:	Loss 0.07288	Accuracy 1.00000
Epoch [ 11]: Loss 0.06482
Validation:	Loss 0.05478	Accuracy 1.00000
Epoch [ 12]: Loss 0.04981
Validation:	Loss 0.04289	Accuracy 1.00000
Epoch [ 13]: Loss 0.03951
Validation:	Loss 0.03464	Accuracy 1.00000
Epoch [ 14]: Loss 0.03226
Validation:	Loss 0.02865	Accuracy 1.00000
Epoch [ 15]: Loss 0.02687
Validation:	Loss 0.02423	Accuracy 1.00000
Epoch [ 16]: Loss 0.02296
Validation:	Loss 0.02086	Accuracy 1.00000
Epoch [ 17]: Loss 0.01982
Validation:	Loss 0.01826	Accuracy 1.00000
Epoch [ 18]: Loss 0.01743
Validation:	Loss 0.01622	Accuracy 1.00000
Epoch [ 19]: Loss 0.01555
Validation:	Loss 0.01460	Accuracy 1.00000
Epoch [ 20]: Loss 0.01403
Validation:	Loss 0.01328	Accuracy 1.00000
Epoch [ 21]: Loss 0.01278
Validation:	Loss 0.01217	Accuracy 1.00000
Epoch [ 22]: Loss 0.01174
Validation:	Loss 0.01120	Accuracy 1.00000
Epoch [ 23]: Loss 0.01081
Validation:	Loss 0.01035	Accuracy 1.00000
Epoch [ 24]: Loss 0.01000
Validation:	Loss 0.00958	Accuracy 1.00000
Epoch [ 25]: Loss 0.00925
Validation:	Loss 0.00887	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.62104
Validation:	Loss 0.54463	Accuracy 1.00000
Epoch [  2]: Loss 0.48101
Validation:	Loss 0.39142	Accuracy 1.00000
Epoch [  3]: Loss 0.35813
Validation:	Loss 0.30133	Accuracy 1.00000
Epoch [  4]: Loss 0.27453
Validation:	Loss 0.23122	Accuracy 1.00000
Epoch [  5]: Loss 0.21187
Validation:	Loss 0.18427	Accuracy 1.00000
Epoch [  6]: Loss 0.16971
Validation:	Loss 0.14992	Accuracy 1.00000
Epoch [  7]: Loss 0.13876
Validation:	Loss 0.12375	Accuracy 1.00000
Epoch [  8]: Loss 0.11506
Validation:	Loss 0.10300	Accuracy 1.00000
Epoch [  9]: Loss 0.09583
Validation:	Loss 0.08521	Accuracy 1.00000
Epoch [ 10]: Loss 0.07883
Validation:	Loss 0.06890	Accuracy 1.00000
Epoch [ 11]: Loss 0.06324
Validation:	Loss 0.05414	Accuracy 1.00000
Epoch [ 12]: Loss 0.04985
Validation:	Loss 0.04285	Accuracy 1.00000
Epoch [ 13]: Loss 0.04028
Validation:	Loss 0.03510	Accuracy 1.00000
Epoch [ 14]: Loss 0.03357
Validation:	Loss 0.02976	Accuracy 1.00000
Epoch [ 15]: Loss 0.02895
Validation:	Loss 0.02593	Accuracy 1.00000
Epoch [ 16]: Loss 0.02552
Validation:	Loss 0.02306	Accuracy 1.00000
Epoch [ 17]: Loss 0.02291
Validation:	Loss 0.02081	Accuracy 1.00000
Epoch [ 18]: Loss 0.02083
Validation:	Loss 0.01900	Accuracy 1.00000
Epoch [ 19]: Loss 0.01901
Validation:	Loss 0.01748	Accuracy 1.00000
Epoch [ 20]: Loss 0.01762
Validation:	Loss 0.01618	Accuracy 1.00000
Epoch [ 21]: Loss 0.01635
Validation:	Loss 0.01505	Accuracy 1.00000
Epoch [ 22]: Loss 0.01517
Validation:	Loss 0.01406	Accuracy 1.00000
Epoch [ 23]: Loss 0.01426
Validation:	Loss 0.01318	Accuracy 1.00000
Epoch [ 24]: Loss 0.01334
Validation:	Loss 0.01239	Accuracy 1.00000
Epoch [ 25]: Loss 0.01257
Validation:	Loss 0.01168	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
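
To run inference with the loaded parameters, rebuild the model structure and call it in test mode (a sketch; here we reuse the validation loader from get_dataloaders as example input):

julia
model = SpiralClassifier(2, 8, 1)
_, val_loader = get_dataloaders()
x_val, y_val = first(val_loader)
# Run the forward pass with the loaded parameters and states
ŷ, _ = model(x_val, ps_trained, Lux.testmode(st_trained))
accuracy(ŷ, y_val)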

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.