
Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train a model using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Precompiling Lux...
   2693.8 ms  ✓ WeightInitializers
    944.5 ms  ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
   9292.7 ms  ✓ Lux
  3 dependencies successfully precompiled in 14 seconds. 106 already precompiled.
Precompiling JLD2...
   4138.7 ms  ✓ FileIO
  32118.2 ms  ✓ JLD2
  2 dependencies successfully precompiled in 36 seconds. 30 already precompiled.
Precompiling MLUtils...
    421.6 ms  ✓ InverseFunctions → InverseFunctionsDatesExt
    436.2 ms  ✓ ContextVariablesX
    615.6 ms  ✓ InverseFunctions → InverseFunctionsTestExt
   1146.0 ms  ✓ SimpleTraits
    459.6 ms  ✓ LogExpFunctions → LogExpFunctionsInverseFunctionsExt
   1220.0 ms  ✓ SplittablesBase
    602.8 ms  ✓ FLoopsBase
   2275.4 ms  ✓ StatsBase
   2330.2 ms  ✓ Accessors
    684.9 ms  ✓ Accessors → StaticArraysExt
    687.9 ms  ✓ Accessors → TestExt
    853.8 ms  ✓ Accessors → LinearAlgebraExt
    756.4 ms  ✓ BangBang
    525.9 ms  ✓ BangBang → BangBangChainRulesCoreExt
    526.9 ms  ✓ BangBang → BangBangTablesExt
    696.1 ms  ✓ BangBang → BangBangStaticArraysExt
   1041.1 ms  ✓ MicroCollections
   2726.1 ms  ✓ Transducers
    672.3 ms  ✓ Transducers → TransducersAdaptExt
   5306.9 ms  ✓ FLoops
   6256.3 ms  ✓ MLUtils
  21 dependencies successfully precompiled in 22 seconds. 77 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
   1765.5 ms  ✓ MLDataDevices → MLDataDevicesMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 102 already precompiled.
Precompiling LuxMLUtilsExt...
   2073.3 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 167 already precompiled.
Precompiling Reactant...
    624.9 ms  ✓ ReactantCore
    947.1 ms  ✓ CUDA_Driver_jll
   2294.7 ms  ✓ Reactant_jll
 215926.2 ms  ✓ Enzyme
   5594.5 ms  ✓ Enzyme → EnzymeGPUArraysCoreExt
Info Given Reactant was explicitly requested, output will be shown live 
2025-02-05 19:15:47.938338: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:51] XLA (re)initializing LLVM with options fingerprint: 6211547431844565882
  65366.9 ms  ✓ Reactant
  6 dependencies successfully precompiled in 287 seconds. 56 already precompiled.
  1 dependency had output during precompilation:
┌ Reactant
│  [Output was shown above]

Precompiling LuxLibEnzymeExt...
   5851.5 ms  ✓ Enzyme → EnzymeSpecialFunctionsExt
   8247.6 ms  ✓ Enzyme → EnzymeStaticArraysExt
   1302.2 ms  ✓ LuxLib → LuxLibEnzymeExt
  11062.7 ms  ✓ Enzyme → EnzymeChainRulesCoreExt
   5702.5 ms  ✓ Enzyme → EnzymeLogExpFunctionsExt
  5 dependencies successfully precompiled in 12 seconds. 126 already precompiled.
Precompiling LuxEnzymeExt...
   6715.3 ms  ✓ Lux → LuxEnzymeExt
  1 dependency successfully precompiled in 7 seconds. 146 already precompiled.
Precompiling LuxCoreReactantExt...
  12241.9 ms  ✓ LuxCore → LuxCoreReactantExt
  1 dependency successfully precompiled in 12 seconds. 67 already precompiled.
Precompiling MLDataDevicesReactantExt...
  12878.1 ms  ✓ MLDataDevices → MLDataDevicesReactantExt
  1 dependency successfully precompiled in 13 seconds. 64 already precompiled.
Precompiling LuxLibReactantExt...
  13041.1 ms  ✓ Reactant → ReactantStatisticsExt
  13442.3 ms  ✓ Reactant → ReactantNNlibExt
  13794.6 ms  ✓ LuxLib → LuxLibReactantExt
  12904.8 ms  ✓ Reactant → ReactantSpecialFunctionsExt
  13424.6 ms  ✓ Reactant → ReactantKernelAbstractionsExt
  12956.7 ms  ✓ Reactant → ReactantArrayInterfaceExt
  6 dependencies successfully precompiled in 27 seconds. 140 already precompiled.
Precompiling WeightInitializersReactantExt...
  12554.9 ms  ✓ WeightInitializers → WeightInitializersReactantExt
  1 dependency successfully precompiled in 13 seconds. 78 already precompiled.
Precompiling LuxReactantExt...
   8645.6 ms  ✓ Lux → LuxReactantExt
  1 dependency successfully precompiled in 9 seconds. 163 already precompiled.

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false)
    )
end
get_dataloaders (generic function with 1 method)
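
As a quick sanity check of the shapes, here is a minimal sketch (assuming the get_dataloaders definition above) that pulls one batch from the training loader and inspects it:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one binary label per sequence in the batch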

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##230".SpiralClassifier
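
Because we declared the field names on AbstractLuxContainerLayer, the parameters and states returned by Lux.setup are NamedTuples keyed by those names. A minimal sketch (assuming the constructor above):

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)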

We could instead use the built-in Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – rather than defining the forward pass ourselves, but let's still do it for the sake of the tutorial (a sketch of the Recurrence-based version appears after the forward-pass definition below).

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements of the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
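
For comparison, here is a minimal sketch of the Recurrence-based version mentioned above. The name SpiralClassifierRecurrence is just for illustration; by default Recurrence returns only the output at the final time step, which is what the manual loop above computes.

julia
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        # Runs the LSTM cell over the sequence and returns the last output
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        # Plain functions are wrapped by `Chain`; this flattens (1, batch) to a vector
        vec
    )
end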

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy (a sketch of the logit variant appears after the code below).

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
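
If you do want the numerically more stable logit formulation, one possible sketch is shown below. The names logit_lossfn and logit_matches are just for illustration, and this assumes the classifier head drops the sigmoid so the model emits raw logits.

julia
# Loss computed directly on logits (no sigmoid in the model's classifier head)
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))

# With logits the decision threshold moves from 0.5 to 0
logit_matches(y_pred, y_true) = sum((y_pred .> 0.0f0) .== y_true)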

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples)
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-02-05 19:24:08.285912: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:51] XLA (re)initializing LLVM with options fingerprint: 6285648702646104871
Epoch [  1]: Loss 0.69397
Validation:	Loss 0.63135	Accuracy 1.00000
Epoch [  2]: Loss 0.59388
Validation:	Loss 0.54033	Accuracy 1.00000
Epoch [  3]: Loss 0.50327
Validation:	Loss 0.45833	Accuracy 1.00000
Epoch [  4]: Loss 0.42248
Validation:	Loss 0.38206	Accuracy 1.00000
Epoch [  5]: Loss 0.34620
Validation:	Loss 0.30999	Accuracy 1.00000
Epoch [  6]: Loss 0.27829
Validation:	Loss 0.24980	Accuracy 1.00000
Epoch [  7]: Loss 0.22434
Validation:	Loss 0.19454	Accuracy 1.00000
Epoch [  8]: Loss 0.16757
Validation:	Loss 0.13528	Accuracy 1.00000
Epoch [  9]: Loss 0.11305
Validation:	Loss 0.08826	Accuracy 1.00000
Epoch [ 10]: Loss 0.07546
Validation:	Loss 0.06114	Accuracy 1.00000
Epoch [ 11]: Loss 0.05440
Validation:	Loss 0.04619	Accuracy 1.00000
Epoch [ 12]: Loss 0.04211
Validation:	Loss 0.03655	Accuracy 1.00000
Epoch [ 13]: Loss 0.03360
Validation:	Loss 0.02936	Accuracy 1.00000
Epoch [ 14]: Loss 0.02677
Validation:	Loss 0.02296	Accuracy 1.00000
Epoch [ 15]: Loss 0.02133
Validation:	Loss 0.01889	Accuracy 1.00000
Epoch [ 16]: Loss 0.01808
Validation:	Loss 0.01642	Accuracy 1.00000
Epoch [ 17]: Loss 0.01588
Validation:	Loss 0.01451	Accuracy 1.00000
Epoch [ 18]: Loss 0.01405
Validation:	Loss 0.01303	Accuracy 1.00000
Epoch [ 19]: Loss 0.01277
Validation:	Loss 0.01184	Accuracy 1.00000
Epoch [ 20]: Loss 0.01164
Validation:	Loss 0.01086	Accuracy 1.00000
Epoch [ 21]: Loss 0.01071
Validation:	Loss 0.01002	Accuracy 1.00000
Epoch [ 22]: Loss 0.00988
Validation:	Loss 0.00930	Accuracy 1.00000
Epoch [ 23]: Loss 0.00916
Validation:	Loss 0.00867	Accuracy 1.00000
Epoch [ 24]: Loss 0.00856
Validation:	Loss 0.00811	Accuracy 1.00000
Epoch [ 25]: Loss 0.00799
Validation:	Loss 0.00760	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.76119
Validation:	Loss 0.68223	Accuracy 0.85938
Epoch [  2]: Loss 0.62109
Validation:	Loss 0.52857	Accuracy 1.00000
Epoch [  3]: Loss 0.45825
Validation:	Loss 0.35641	Accuracy 1.00000
Epoch [  4]: Loss 0.30262
Validation:	Loss 0.23736	Accuracy 1.00000
Epoch [  5]: Loss 0.20405
Validation:	Loss 0.16532	Accuracy 1.00000
Epoch [  6]: Loss 0.14548
Validation:	Loss 0.12080	Accuracy 1.00000
Epoch [  7]: Loss 0.10650
Validation:	Loss 0.08794	Accuracy 1.00000
Epoch [  8]: Loss 0.07723
Validation:	Loss 0.06416	Accuracy 1.00000
Epoch [  9]: Loss 0.05721
Validation:	Loss 0.04926	Accuracy 1.00000
Epoch [ 10]: Loss 0.04482
Validation:	Loss 0.03982	Accuracy 1.00000
Epoch [ 11]: Loss 0.03692
Validation:	Loss 0.03321	Accuracy 1.00000
Epoch [ 12]: Loss 0.03100
Validation:	Loss 0.02826	Accuracy 1.00000
Epoch [ 13]: Loss 0.02660
Validation:	Loss 0.02442	Accuracy 1.00000
Epoch [ 14]: Loss 0.02300
Validation:	Loss 0.02136	Accuracy 1.00000
Epoch [ 15]: Loss 0.02025
Validation:	Loss 0.01890	Accuracy 1.00000
Epoch [ 16]: Loss 0.01801
Validation:	Loss 0.01686	Accuracy 1.00000
Epoch [ 17]: Loss 0.01613
Validation:	Loss 0.01517	Accuracy 1.00000
Epoch [ 18]: Loss 0.01456
Validation:	Loss 0.01371	Accuracy 1.00000
Epoch [ 19]: Loss 0.01320
Validation:	Loss 0.01244	Accuracy 1.00000
Epoch [ 20]: Loss 0.01188
Validation:	Loss 0.01131	Accuracy 1.00000
Epoch [ 21]: Loss 0.01091
Validation:	Loss 0.01028	Accuracy 1.00000
Epoch [ 22]: Loss 0.00986
Validation:	Loss 0.00932	Accuracy 1.00000
Epoch [ 23]: Loss 0.00894
Validation:	Loss 0.00843	Accuracy 1.00000
Epoch [ 24]: Loss 0.00807
Validation:	Loss 0.00759	Accuracy 1.00000
Epoch [ 25]: Loss 0.00723
Validation:	Loss 0.00683	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model back.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
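
Since we only saved the parameters and states, the model struct itself is rebuilt from code at load time. A minimal inference sketch (assuming the SpiralClassifier definition above and a placeholder input batch x of size 2 × seq_len × batch_size):

julia
model = SpiralClassifier(2, 8, 1)
# `x` is a placeholder input batch of size 2 × seq_len × batch_size
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))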

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.