Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Use the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a MLUtils.DataLoader. Each batch from the dataloader has size 2 × seq_len × batch_size, and the task is to predict a binary label indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
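
To make the batch layout concrete, here is a minimal sketch that uses random arrays as stand-ins for the spiral data (the arrays `x` and `y` below are placeholders, not the tutorial's dataset). It assumes MLUtils is installed and applies the same 80/20 split and `batchsize=128, partial=false` settings as above.

```julia
using MLUtils

# Placeholder data: 1000 sequences of length 50 with 2 features each.
x = rand(Float32, 2, 50, 1000)
y = Float32.(rand(Bool, 1000))

# 80/20 split along the last (observation) dimension.
(x_train, y_train), (x_val, y_val) = splitobs((x, y); at=0.8)

train_loader = DataLoader((x_train, y_train); batchsize=128, shuffle=false, partial=false)
val_loader = DataLoader((x_val, y_val); batchsize=128, shuffle=false, partial=false)

x_batch, y_batch = first(train_loader)
size(x_batch)         # features × seq_len × batch_size
length(train_loader)  # number of full batches in the training split
length(val_loader)    # number of full batches in the validation split
```

Because `partial=false` drops the incomplete final batch, the 800 training samples yield 6 batches of 128, and the 200 validation samples yield a single batch of 128.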

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier

We could use the built-in Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the recurrence ourselves, but we will write it out by hand for illustration.
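
For comparison, a minimal sketch of the built-in route might look like the following. It assumes the default `Recurrence` settings (sequence along the second-to-last dimension, only the final output returned); the variable names and the batch size of 4 are illustrative only.

```julia
using Lux, Random

# Recurrence iterates the LSTM cell over the sequence dimension and, by default,
# returns only the output of the last time step.
model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

# Input laid out as features × seq_len × batch_size.
x = rand(rng, Float32, 2, 50, 4)
y, st_new = model(x, ps, st)
size(y)  # one probability per sequence, stored as a 1 × batch_size matrix
```

Unlike the custom layer defined below, this sketch does not flatten the output with `vec`, so the result is a 1 × batch_size matrix rather than a vector.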

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we run the sequence through the LSTM cell. The first call to the cell
    # creates the initial hidden state. Note that the parameters and states are
    # automatically populated into a field called `lstm_cell`. We use `eachslice` to
    # get the elements in the sequence without copying, and `Iterators.peel` to split
    # out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
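
The `Iterators.peel` pattern can be tried in isolation. The sketch below uses Base's `eachslice` as a stand-in for `LuxOps.eachslice` and a running sum as a stand-in for the LSTM cell; both substitutions are assumptions made to keep the example free of dependencies.

```julia
# Toy input laid out as features × seq_len × batch, like the spiral batches.
x = reshape(Float32.(1:12), 2, 3, 2)

# Split the time steps into the first slice and an iterator over the rest.
x_init, x_rest = Iterators.peel(eachslice(x; dims=2))

# Stand-in for a recurrent cell: the "carry" is just a running sum over time.
carry = copy(x_init)
for x_t in x_rest
    carry .+= x_t
end

size(x_init)                                 # features × batch for one time step
carry == dropdims(sum(x; dims=2); dims=2)    # the loop visited every time step once
```

The first slice seeds the carry, just as the first LSTM call creates the initial hidden state, and the loop then consumes the remaining time steps.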

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, so we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
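
To illustrate the stability remark, the following dependency-free sketch compares a naive cross-entropy on probabilities with the equivalent formula computed directly from logits. The helper names `naive_bce`, `logit_bce` and `sigmoid_f32` are hypothetical and written only for this illustration; library implementations usually add clamping, so they would not necessarily return `Inf` here.

```julia
# Naive binary cross-entropy on probabilities (no epsilon clamping).
naive_bce(p, y) = -(y * log(p) + (1 - y) * log(1 - p))

# The same loss computed from the logit z, where p = sigmoid(z).
logit_bce(z, y) = max(z, 0) - z * y + log1p(exp(-abs(z)))

sigmoid_f32(z) = 1 / (1 + exp(-z))

z, y = 40.0f0, 0.0f0   # a very confident but wrong prediction
p = sigmoid_f32(z)     # rounds to exactly 1.0f0 in Float32

naive_bce(p, y)        # log(1 - p) = log(0), so the loss overflows to Inf
logit_bce(z, y)        # stays finite, approximately 40
```

The logit form never evaluates `log(0)`, which is why it remains finite when the sigmoid saturates.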

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
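
As a quick check of the accuracy helpers, here is a small worked example on hand-picked values (the two vectors are made up for illustration). Predictions above 0.5 count as class 1.

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = Float32[0.9, 0.2, 0.7, 0.4]  # thresholded to 1, 0, 1, 0
y_true = Float32[1, 0, 0, 0]

matches(y_pred, y_true)   # 3 (only the third prediction is wrong)
accuracy(y_pred, y_true)  # 0.75
```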

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-07-01 14:27:41.804953: I external/xla/xla/service/service.cc:153] XLA service 0x372cef10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-01 14:27:41.804985: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1751380061.805741 2437862 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1751380061.805803 2437862 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1751380061.805850 2437862 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1751380061.819117 2437862 cuda_dnn.cc:471] Loaded cuDNN version 90800
Epoch [  1]: Loss 0.64281
Validation:	Loss 0.57897	Accuracy 1.00000
Epoch [  2]: Loss 0.53634
Validation:	Loss 0.47932	Accuracy 1.00000
Epoch [  3]: Loss 0.43283
Validation:	Loss 0.39715	Accuracy 1.00000
Epoch [  4]: Loss 0.35946
Validation:	Loss 0.34208	Accuracy 1.00000
Epoch [  5]: Loss 0.30770
Validation:	Loss 0.29286	Accuracy 1.00000
Epoch [  6]: Loss 0.26009
Validation:	Loss 0.24528	Accuracy 1.00000
Epoch [  7]: Loss 0.21348
Validation:	Loss 0.20193	Accuracy 1.00000
Epoch [  8]: Loss 0.17343
Validation:	Loss 0.16440	Accuracy 1.00000
Epoch [  9]: Loss 0.14148
Validation:	Loss 0.13377	Accuracy 1.00000
Epoch [ 10]: Loss 0.11473
Validation:	Loss 0.10965	Accuracy 1.00000
Epoch [ 11]: Loss 0.09437
Validation:	Loss 0.09058	Accuracy 1.00000
Epoch [ 12]: Loss 0.07801
Validation:	Loss 0.07556	Accuracy 1.00000
Epoch [ 13]: Loss 0.06520
Validation:	Loss 0.06351	Accuracy 1.00000
Epoch [ 14]: Loss 0.05438
Validation:	Loss 0.05357	Accuracy 1.00000
Epoch [ 15]: Loss 0.04622
Validation:	Loss 0.04470	Accuracy 1.00000
Epoch [ 16]: Loss 0.03857
Validation:	Loss 0.03665	Accuracy 1.00000
Epoch [ 17]: Loss 0.03129
Validation:	Loss 0.02887	Accuracy 1.00000
Epoch [ 18]: Loss 0.02439
Validation:	Loss 0.02136	Accuracy 1.00000
Epoch [ 19]: Loss 0.01795
Validation:	Loss 0.01585	Accuracy 1.00000
Epoch [ 20]: Loss 0.01390
Validation:	Loss 0.01265	Accuracy 1.00000
Epoch [ 21]: Loss 0.01142
Validation:	Loss 0.01061	Accuracy 1.00000
Epoch [ 22]: Loss 0.00976
Validation:	Loss 0.00908	Accuracy 1.00000
Epoch [ 23]: Loss 0.00841
Validation:	Loss 0.00786	Accuracy 1.00000
Epoch [ 24]: Loss 0.00735
Validation:	Loss 0.00684	Accuracy 1.00000
Epoch [ 25]: Loss 0.00644
Validation:	Loss 0.00596	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.54187
Validation:	Loss 0.39127	Accuracy 1.00000
Epoch [  2]: Loss 0.34422
Validation:	Loss 0.28346	Accuracy 1.00000
Epoch [  3]: Loss 0.25743
Validation:	Loss 0.22008	Accuracy 1.00000
Epoch [  4]: Loss 0.20414
Validation:	Loss 0.17669	Accuracy 1.00000
Epoch [  5]: Loss 0.16354
Validation:	Loss 0.14113	Accuracy 1.00000
Epoch [  6]: Loss 0.12969
Validation:	Loss 0.11079	Accuracy 1.00000
Epoch [  7]: Loss 0.10071
Validation:	Loss 0.08574	Accuracy 1.00000
Epoch [  8]: Loss 0.07791
Validation:	Loss 0.06685	Accuracy 1.00000
Epoch [  9]: Loss 0.06130
Validation:	Loss 0.05342	Accuracy 1.00000
Epoch [ 10]: Loss 0.04940
Validation:	Loss 0.04357	Accuracy 1.00000
Epoch [ 11]: Loss 0.04040
Validation:	Loss 0.03604	Accuracy 1.00000
Epoch [ 12]: Loss 0.03347
Validation:	Loss 0.03002	Accuracy 1.00000
Epoch [ 13]: Loss 0.02783
Validation:	Loss 0.02507	Accuracy 1.00000
Epoch [ 14]: Loss 0.02329
Validation:	Loss 0.02113	Accuracy 1.00000
Epoch [ 15]: Loss 0.01966
Validation:	Loss 0.01810	Accuracy 1.00000
Epoch [ 16]: Loss 0.01701
Validation:	Loss 0.01578	Accuracy 1.00000
Epoch [ 17]: Loss 0.01488
Validation:	Loss 0.01397	Accuracy 1.00000
Epoch [ 18]: Loss 0.01328
Validation:	Loss 0.01252	Accuracy 1.00000
Epoch [ 19]: Loss 0.01193
Validation:	Loss 0.01129	Accuracy 1.00000
Epoch [ 20]: Loss 0.01076
Validation:	Loss 0.01024	Accuracy 1.00000
Epoch [ 21]: Loss 0.00981
Validation:	Loss 0.00932	Accuracy 1.00000
Epoch [ 22]: Loss 0.00894
Validation:	Loss 0.00850	Accuracy 1.00000
Epoch [ 23]: Loss 0.00815
Validation:	Loss 0.00775	Accuracy 1.00000
Epoch [ 24]: Loss 0.00744
Validation:	Loss 0.00709	Accuracy 1.00000
Epoch [ 25]: Loss 0.00681
Validation:	Loss 0.00649	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
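
If you prefer to get the loaded values back as return values rather than as variables defined by a macro, a round trip could be sketched with JLD2's `jldsave` and `jldopen`. The NamedTuples below are hypothetical stand-ins for trained parameters and states, and the file is written to a temporary directory.

```julia
using JLD2

# Hypothetical stand-ins for trained parameters and states.
ps = (; weight=rand(Float32, 2, 2), bias=zeros(Float32, 2))
st = (; counter=0)

path = joinpath(mktempdir(), "roundtrip.jld2")
jldsave(path; ps, st)

# Read the two entries back by name.
ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end

ps_loaded.weight == ps.weight  # the arrays survive the round trip unchanged
```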

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.