
Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial, you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train a model using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
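
For example, a minimal (illustrative) setup for the Zygote backend — only needed if you plan to pass AutoZygote() as the AD backend rather than the Reactant/Enzyme path used below:

julia
# import Pkg; Pkg.add("Zygote")  # hypothetical project setup: add the dependency first
using Zygote                      # makes AutoZygote() usable with Training.single_train_step!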

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and wrap the data in an MLUtils.DataLoader. The dataloader yields batches of size 2 × seq_len × batch_size, from which we must predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function create_dataset(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    return x_data, labels
end

function get_dataloaders(; dataset_size=1000, sequence_length=50)
    x_data, labels = create_dataset(; dataset_size, sequence_length)
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
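
As a quick (illustrative) sanity check, we can inspect one batch from the training loader and confirm the shapes described above:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # expected (2, 50, 128): features × sequence_length × batch_size
size(y)  # expected (128,): one binary label per sequence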

Creating a Classifier

We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it contains an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the LSTM cell and dense layer from scratch; instead, we will use the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
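
Purely for illustration, the model can already be instantiated and its parameters and states initialized at this point:

julia
rng = Random.default_rng()
model = SpiralClassifier(2, 8, 1)   # 2 input features, 8 hidden units, 1 output
ps, st = Lux.setup(rng, model)      # fields `lstm_cell` and `classifier` are populated automatically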

We could use the built-in Recurrence(LSTMCell(in_dims => hidden_dims)) block instead of writing the sequence loop ourselves, but we will implement it manually for illustration; a sketch of the built-in alternative follows.
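
A minimal sketch of that built-in approach, assuming Recurrence keeps only the final time-step output (its default behavior):

julia
model_builtin = Chain(
    Recurrence(LSTMCell(2 => 8)),   # run the LSTM over the sequence, keep the last output
    Dense(8 => 1, sigmoid),         # classifier head; output is 1 × batch_size
)
# The output would still need `vec` to match the return shape of the custom model below.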

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states are automatically populated into fields called
    # `lstm_cell` and `classifier`. We use `eachslice` to get the elements of the sequence
    # without copying, and `Iterators.peel` to split off the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. The macro handles the boilerplate for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
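
As a quick (illustrative) check, the compact model is constructed and set up exactly like the struct-based version:

julia
model_compact = SpiralClassifierCompact(2, 8, 1)
ps_c, st_c = Lux.setup(Random.default_rng(), model_compact)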

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. It is typically recommended to use the logit version (logitbinarycrossentropy) since it is more numerically stable, but for simplicity we will use binarycrossentropy here.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
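
To see what these helpers do, here is a tiny (illustrative) example with hand-written predictions:

julia
ŷ_demo = Float32[0.2, 0.8, 0.6, 0.4]   # hypothetical model outputs (probabilities)
y_demo = Float32[0.0, 1.0, 1.0, 1.0]   # ground-truth labels
matches(ŷ_demo, y_demo)                 # 3 thresholded predictions agree with the labels
accuracy(ŷ_demo, y_demo)                # 0.75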

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf("Epoch [%3d]: Loss %4.5f\n", epoch, total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf(
            "Validation:\tLoss %4.5f\tAccuracy %4.5f\n",
            total_loss / total_samples,
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56034
Validation:	Loss 0.49995	Accuracy 1.00000
Epoch [  2]: Loss 0.47919
Validation:	Loss 0.43457	Accuracy 1.00000
Epoch [  3]: Loss 0.41732
Validation:	Loss 0.37449	Accuracy 1.00000
Epoch [  4]: Loss 0.36183
Validation:	Loss 0.31881	Accuracy 1.00000
Epoch [  5]: Loss 0.30706
Validation:	Loss 0.26932	Accuracy 1.00000
Epoch [  6]: Loss 0.26053
Validation:	Loss 0.22607	Accuracy 1.00000
Epoch [  7]: Loss 0.21532
Validation:	Loss 0.18761	Accuracy 1.00000
Epoch [  8]: Loss 0.18034
Validation:	Loss 0.15336	Accuracy 1.00000
Epoch [  9]: Loss 0.14510
Validation:	Loss 0.12397	Accuracy 1.00000
Epoch [ 10]: Loss 0.11857
Validation:	Loss 0.09905	Accuracy 1.00000
Epoch [ 11]: Loss 0.09325
Validation:	Loss 0.07704	Accuracy 1.00000
Epoch [ 12]: Loss 0.07055
Validation:	Loss 0.05727	Accuracy 1.00000
Epoch [ 13]: Loss 0.05224
Validation:	Loss 0.04244	Accuracy 1.00000
Epoch [ 14]: Loss 0.03941
Validation:	Loss 0.03247	Accuracy 1.00000
Epoch [ 15]: Loss 0.03041
Validation:	Loss 0.02610	Accuracy 1.00000
Epoch [ 16]: Loss 0.02476
Validation:	Loss 0.02157	Accuracy 1.00000
Epoch [ 17]: Loss 0.02061
Validation:	Loss 0.01812	Accuracy 1.00000
Epoch [ 18]: Loss 0.01751
Validation:	Loss 0.01541	Accuracy 1.00000
Epoch [ 19]: Loss 0.01479
Validation:	Loss 0.01334	Accuracy 1.00000
Epoch [ 20]: Loss 0.01293
Validation:	Loss 0.01174	Accuracy 1.00000
Epoch [ 21]: Loss 0.01157
Validation:	Loss 0.01048	Accuracy 1.00000
Epoch [ 22]: Loss 0.01030
Validation:	Loss 0.00948	Accuracy 1.00000
Epoch [ 23]: Loss 0.00935
Validation:	Loss 0.00867	Accuracy 1.00000
Epoch [ 24]: Loss 0.00865
Validation:	Loss 0.00799	Accuracy 1.00000
Epoch [ 25]: Loss 0.00796
Validation:	Loss 0.00741	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.48760
Validation:	Loss 0.41910	Accuracy 1.00000
Epoch [  2]: Loss 0.34760
Validation:	Loss 0.29167	Accuracy 1.00000
Epoch [  3]: Loss 0.25437
Validation:	Loss 0.23289	Accuracy 1.00000
Epoch [  4]: Loss 0.20502
Validation:	Loss 0.19176	Accuracy 1.00000
Epoch [  5]: Loss 0.16703
Validation:	Loss 0.15554	Accuracy 1.00000
Epoch [  6]: Loss 0.13056
Validation:	Loss 0.11228	Accuracy 1.00000
Epoch [  7]: Loss 0.09093
Validation:	Loss 0.07573	Accuracy 1.00000
Epoch [  8]: Loss 0.06429
Validation:	Loss 0.05762	Accuracy 1.00000
Epoch [  9]: Loss 0.04965
Validation:	Loss 0.04575	Accuracy 1.00000
Epoch [ 10]: Loss 0.03996
Validation:	Loss 0.03729	Accuracy 1.00000
Epoch [ 11]: Loss 0.03271
Validation:	Loss 0.03073	Accuracy 1.00000
Epoch [ 12]: Loss 0.02712
Validation:	Loss 0.02544	Accuracy 1.00000
Epoch [ 13]: Loss 0.02261
Validation:	Loss 0.02104	Accuracy 1.00000
Epoch [ 14]: Loss 0.01875
Validation:	Loss 0.01744	Accuracy 1.00000
Epoch [ 15]: Loss 0.01578
Validation:	Loss 0.01472	Accuracy 1.00000
Epoch [ 16]: Loss 0.01352
Validation:	Loss 0.01279	Accuracy 1.00000
Epoch [ 17]: Loss 0.01193
Validation:	Loss 0.01139	Accuracy 1.00000
Epoch [ 18]: Loss 0.01072
Validation:	Loss 0.01030	Accuracy 1.00000
Epoch [ 19]: Loss 0.00978
Validation:	Loss 0.00939	Accuracy 1.00000
Epoch [ 20]: Loss 0.00893
Validation:	Loss 0.00860	Accuracy 1.00000
Epoch [ 21]: Loss 0.00820
Validation:	Loss 0.00789	Accuracy 1.00000
Epoch [ 22]: Loss 0.00753
Validation:	Loss 0.00725	Accuracy 1.00000
Epoch [ 23]: Loss 0.00693
Validation:	Loss 0.00666	Accuracy 1.00000
Epoch [ 24]: Loss 0.00637
Validation:	Loss 0.00612	Accuracy 1.00000
Epoch [ 25]: Loss 0.00587
Validation:	Loss 0.00563	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
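
As a final (illustrative) sanity check, the restored parameters and states can drive inference through a freshly constructed model on the CPU:

julia
model = SpiralClassifier(2, 8, 1)
train_loader, _ = get_dataloaders()
x, _ = first(train_loader)
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
size(ŷ)   # expected (128,): one probability per sequence in the batch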

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.