
Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial, you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Precompiling Lux...
   5899.6 ms  ✓ LuxLib
   9313.1 ms  ✓ Lux
  2 dependencies successfully precompiled in 16 seconds. 107 already precompiled.
Precompiling LuxMLUtilsExt...
   2229.9 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 167 already precompiled.
Precompiling LuxLibEnzymeExt...
   1362.5 ms  ✓ LuxLib → LuxLibEnzymeExt
  1 dependency successfully precompiled in 2 seconds. 130 already precompiled.
Precompiling LuxEnzymeExt...
   6890.5 ms  ✓ Lux → LuxEnzymeExt
  1 dependency successfully precompiled in 7 seconds. 146 already precompiled.
Precompiling LuxLibReactantExt...
  13508.2 ms  ✓ Reactant → ReactantNNlibExt
  13758.2 ms  ✓ LuxLib → LuxLibReactantExt
  2 dependencies successfully precompiled in 14 seconds. 142 already precompiled.
Precompiling LuxReactantExt...
   8542.8 ms  ✓ Lux → LuxReactantExt
  1 dependency successfully precompiled in 9 seconds. 161 already precompiled.

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false)
    )
end
get_dataloaders (generic function with 1 method)
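
As a quick sanity check (a hypothetical snippet, not part of the original tutorial), we can inspect the first training batch; with the defaults above, each input batch should have size 2 × 50 × 128:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x), size(y)  # expected: ((2, 50, 128), (128,))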

Creating a Classifier

We will extend the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##230".SpiralClassifier

We could use the built-in Lux blocks (Recurrence(LSTMCell(in_dims => hidden_dims))) instead of defining the forward pass manually, but we write it out below to illustrate how recurrent cells are driven step by step.
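
For reference, here is a minimal sketch of that built-in alternative (hypothetical; it assumes the default Recurrence settings, which consume input of size in_dims × seq_len × batch_size and return only the last output):

julia
# Hypothetical equivalent using the built-in Recurrence wrapper
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),  # run the cell across the sequence
        Dense(hidden_dims => out_dims, sigmoid)        # classifier head
    )
end

Note that this version outputs a 1 × batch_size matrix rather than a vector, so downstream code would additionally need a vec.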

Now we need to define the behavior of the classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
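
Before training, we can sanity-check the forward pass on random data (a hypothetical smoke test, not part of the original tutorial):

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 2, 50, 4)  # 2 features, 50 time steps, batch of 4
y, st_new = model(x, ps, st)
size(y)  # expected: (4,) since `vec` flattens the 1 × 4 classifier output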

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
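
Since @compact names the sub-layers after the captured variables, the compact model should expose the same parameter structure as the manual one; a quick (hypothetical) check:

julia
model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # expected: (:lstm_cell, :classifier)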

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically, logitbinarycrossentropy is recommended since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
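
For reference, a hypothetical sketch of the more numerically stable logit-based alternative mentioned above: drop the sigmoid from the Dense head and make the loss operate on logits instead of probabilities.

julia
# Hypothetical variant (not used in this tutorial): pair this loss with
# `Dense(hidden_dims => out_dims)` (no sigmoid activation) in the model
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))
# probabilities are then recovered via `sigmoid.(ŷ)` only where needed, e.g. for accuracy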

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples)
    end

    return (train_state.parameters, train_state.states) |> cdev
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-01-24 06:01:33.744382: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 15805882848018654454
Epoch [  1]: Loss 0.61726
Validation:	Loss 0.54094	Accuracy 1.00000
Epoch [  2]: Loss 0.50823
Validation:	Loss 0.46061	Accuracy 1.00000
Epoch [  3]: Loss 0.43373
Validation:	Loss 0.39411	Accuracy 1.00000
Epoch [  4]: Loss 0.37233
Validation:	Loss 0.33552	Accuracy 1.00000
Epoch [  5]: Loss 0.31353
Validation:	Loss 0.27997	Accuracy 1.00000
Epoch [  6]: Loss 0.25701
Validation:	Loss 0.22527	Accuracy 1.00000
Epoch [  7]: Loss 0.20482
Validation:	Loss 0.17676	Accuracy 1.00000
Epoch [  8]: Loss 0.16106
Validation:	Loss 0.13826	Accuracy 1.00000
Epoch [  9]: Loss 0.12582
Validation:	Loss 0.10878	Accuracy 1.00000
Epoch [ 10]: Loss 0.09994
Validation:	Loss 0.08616	Accuracy 1.00000
Epoch [ 11]: Loss 0.07903
Validation:	Loss 0.06723	Accuracy 1.00000
Epoch [ 12]: Loss 0.06045
Validation:	Loss 0.05047	Accuracy 1.00000
Epoch [ 13]: Loss 0.04459
Validation:	Loss 0.03730	Accuracy 1.00000
Epoch [ 14]: Loss 0.03376
Validation:	Loss 0.02840	Accuracy 1.00000
Epoch [ 15]: Loss 0.02581
Validation:	Loss 0.02201	Accuracy 1.00000
Epoch [ 16]: Loss 0.02013
Validation:	Loss 0.01738	Accuracy 1.00000
Epoch [ 17]: Loss 0.01603
Validation:	Loss 0.01405	Accuracy 1.00000
Epoch [ 18]: Loss 0.01306
Validation:	Loss 0.01164	Accuracy 1.00000
Epoch [ 19]: Loss 0.01091
Validation:	Loss 0.00988	Accuracy 1.00000
Epoch [ 20]: Loss 0.00935
Validation:	Loss 0.00856	Accuracy 1.00000
Epoch [ 21]: Loss 0.00814
Validation:	Loss 0.00756	Accuracy 1.00000
Epoch [ 22]: Loss 0.00724
Validation:	Loss 0.00681	Accuracy 1.00000
Epoch [ 23]: Loss 0.00657
Validation:	Loss 0.00623	Accuracy 1.00000
Epoch [ 24]: Loss 0.00601
Validation:	Loss 0.00573	Accuracy 1.00000
Epoch [ 25]: Loss 0.00553
Validation:	Loss 0.00529	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.66203
Validation:	Loss 0.57153	Accuracy 0.57031
Epoch [  2]: Loss 0.54116
Validation:	Loss 0.46976	Accuracy 1.00000
Epoch [  3]: Loss 0.43337
Validation:	Loss 0.37439	Accuracy 1.00000
Epoch [  4]: Loss 0.33407
Validation:	Loss 0.28618	Accuracy 1.00000
Epoch [  5]: Loss 0.24639
Validation:	Loss 0.20935	Accuracy 1.00000
Epoch [  6]: Loss 0.17783
Validation:	Loss 0.15226	Accuracy 1.00000
Epoch [  7]: Loss 0.12870
Validation:	Loss 0.11263	Accuracy 1.00000
Epoch [  8]: Loss 0.09590
Validation:	Loss 0.08546	Accuracy 1.00000
Epoch [  9]: Loss 0.07340
Validation:	Loss 0.06627	Accuracy 1.00000
Epoch [ 10]: Loss 0.05773
Validation:	Loss 0.05271	Accuracy 1.00000
Epoch [ 11]: Loss 0.04620
Validation:	Loss 0.04315	Accuracy 1.00000
Epoch [ 12]: Loss 0.03817
Validation:	Loss 0.03601	Accuracy 1.00000
Epoch [ 13]: Loss 0.03216
Validation:	Loss 0.03047	Accuracy 1.00000
Epoch [ 14]: Loss 0.02748
Validation:	Loss 0.02618	Accuracy 1.00000
Epoch [ 15]: Loss 0.02378
Validation:	Loss 0.02296	Accuracy 1.00000
Epoch [ 16]: Loss 0.02090
Validation:	Loss 0.02045	Accuracy 1.00000
Epoch [ 17]: Loss 0.01889
Validation:	Loss 0.01840	Accuracy 1.00000
Epoch [ 18]: Loss 0.01699
Validation:	Loss 0.01672	Accuracy 1.00000
Epoch [ 19]: Loss 0.01554
Validation:	Loss 0.01531	Accuracy 1.00000
Epoch [ 20]: Loss 0.01431
Validation:	Loss 0.01410	Accuracy 1.00000
Epoch [ 21]: Loss 0.01323
Validation:	Loss 0.01306	Accuracy 1.00000
Epoch [ 22]: Loss 0.01226
Validation:	Loss 0.01215	Accuracy 1.00000
Epoch [ 23]: Loss 0.01142
Validation:	Loss 0.01134	Accuracy 1.00000
Epoch [ 24]: Loss 0.01065
Validation:	Loss 0.01063	Accuracy 1.00000
Epoch [ 25]: Loss 0.00998
Validation:	Loss 0.00999	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct; save only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
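
To reuse the restored parameters, rebuild the model and run it in test mode (a hypothetical usage sketch, not part of the original tutorial):

julia
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # one random sequence on the CPU
y, _ = model(x, ps_trained, Lux.testmode(st_trained))
y[1] > 0.5f0  # predicted label: true corresponds to anticlockwise (label 1)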

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.