Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Use the Lux recurrent neural network API.

  3. Train a model using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. The dataloader yields batches of sequences of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary value indicating whether it is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size = 1000, sequence_length = 50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
            for d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(
                d[1][:, (sequence_length + 1):end], :, sequence_length, 1
            )
            for d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, shuffle = true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize = 128, shuffle = true, partial = false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false, partial = false),
    )
end
get_dataloaders (generic function with 1 method)
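
To make the 2 × seq_len × batch_size layout concrete, here is a minimal sketch of the reshape-and-concatenate step used above. It uses only Base Julia and random stand-in data rather than real spirals; the names `a`, `b`, and `x_data` are illustrative.

```julia
# Two stand-in "sequences" of 2D points, each 2 × seq_len (random data, not real spirals).
seq_len = 50
a = rand(Float32, 2, seq_len)
b = rand(Float32, 2, seq_len)

# Add a trailing singleton batch dimension: 2 × seq_len × 1.
a3 = reshape(a, :, seq_len, 1)
b3 = reshape(b, :, seq_len, 1)

# Concatenate along the third (batch) dimension: 2 × seq_len × 2.
x_data = cat(a3, b3; dims = 3)
size(x_data)  # (2, 50, 2)
```

Each slice `x_data[:, :, i]` is one complete sequence, which is the layout the classifier below expects.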

Creating a Classifier

We will subtype Lux.AbstractLuxContainerLayer for our custom model, since it contains an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
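
As a rough illustration of what passing the field names buys us, the sketch below defines a hypothetical two-field container (`TinyContainer`, not part of the tutorial) and checks that `Lux.setup` returns parameters keyed by those field names. It assumes only that Lux and Random are available.

```julia
using Lux, Random

# Hypothetical container with two sub-layers; the tuple of symbols tells Lux
# which fields hold layers whose parameters and states should be collected.
struct TinyContainer{A, B} <: AbstractLuxContainerLayer{(:first_layer, :second_layer)}
    first_layer::A
    second_layer::B
end

model = TinyContainer(Dense(2 => 4), Dense(4 => 1))
ps, st = Lux.setup(Random.default_rng(), model)

# Parameters are grouped under the field names given in the type parameter.
keys(ps)                      # (:first_layer, :second_layer)
size(ps.first_layer.weight)   # (4, 2)
```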

We won't define the model from scratch; instead we will build it from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier

We could use the default Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the following, but let's still write it by hand for the sake of illustration.
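
For comparison, a minimal sketch of the Recurrence-based alternative might look like the following. It assumes the default settings of `Recurrence` (time along the second-to-last dimension, only the final output returned); the dimensions 2, 8, and 1 and the batch size of 4 are illustrative.

```julia
using Lux, Random

# Recurrence unrolls the cell over the sequence dimension and, by default,
# returns only the output of the last time step.
model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(Random.default_rng(), model)

x = rand(Float32, 2, 50, 4)   # features × seq_len × batch_size
y, _ = model(x, ps, st)
size(y)                        # (1, 4)
```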

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple
    ) where {T}
    # First, run the sequence through the LSTM cell. The first call to the LSTM cell
    # creates the initial hidden state. Note that the parameters and states are
    # automatically populated into a field called `lstm_cell`. We use `eachslice` to get
    # the elements in the sequence without copying, and `Iterators.peel` to split out the
    # first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier = st_classifier, lstm_cell = st_lstm))
    return vec(y), st
end
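
The slicing pattern in the forward pass can be tried in isolation. The sketch below uses Base's `eachslice` as a stand-in for `LuxOps.eachslice(x, Val(2))` (an assumption made to keep the example dependency-free) on a small integer array.

```julia
# A small 2 × 3 × 4 array standing in for features × seq_len × batch_size.
x = reshape(collect(1:24), 2, 3, 4)

# Iterate over time steps (dimension 2); each slice is a features × batch_size view.
steps = eachslice(x; dims = 2)

# `Iterators.peel` returns the first element and an iterator over the rest.
x_init, x_rest = Iterators.peel(steps)

size(x_init)                 # (2, 4)
x_init == x[:, 1, :]         # true
remaining = collect(x_rest)
length(remaining)            # 2
```

The first slice initializes the LSTM carry, and the remaining slices are fed through the loop together with that carry.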

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
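
If the macro is new to you, a stripped-down sketch with a single Dense layer may help. The layer name `dense` and the sizes are illustrative; the point is that the sub-layer is called without explicitly threading parameters and states.

```julia
using Lux, Random

# Inside the `do` block, `dense(x)` is called like an ordinary function;
# @compact takes care of passing its parameters and updating its state.
model = @compact(; dense = Dense(2 => 3)) do x
    @return dense(x) .+ 1
end

ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 2, 5)
y, st_new = model(x, ps, st)
size(y)  # (3, 5)
```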

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred = ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
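
To illustrate the numerical stability remark above, here is a hand-written sketch of both losses in plain Julia. The helper names `sigmoid_`, `bce`, and `logit_bce` are hypothetical and are not Lux's implementations; they only show why working on logits avoids the saturation problem.

```julia
using Statistics: mean

# Hypothetical helpers for illustration only.
sigmoid_(z) = 1 / (1 + exp(-z))

# Binary cross-entropy on probabilities p.
bce(p, y) = -mean(@. y * log(p) + (1 - y) * log(1 - p))

# The same loss computed directly from logits z, using the identity
# max(z, 0) - z * y + log(1 + exp(-|z|)).
logit_bce(z, y) = mean(@. max(z, 0) - z * y + log1p(exp(-abs(z))))

# For moderate logits the two formulations agree.
z = Float32[0.5, -1.0, 2.0]
y = Float32[1, 0, 1]
agree = bce(sigmoid_.(z), y) ≈ logit_bce(z, y)

# A confidently wrong prediction: the sigmoid saturates to exactly 1 in Float32,
# so log(1 - p) is -Inf, while the logit form stays finite.
z_big = Float32[40]
y_big = Float32[0]
prob_loss = bce(sigmoid_.(z_big), y_big)
logit_loss = logit_bce(z_big, y_big)
```

Here `prob_loss` is infinite while `logit_loss` is approximately 40, which is the practical reason to prefer the logit-based loss when the model can output raw logits.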

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (total_acc / total_samples)
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.74438
Validation:	Loss 0.68827	Accuracy 0.50000
Epoch [  2]: Loss 0.64720
Validation:	Loss 0.58855	Accuracy 0.74219
Epoch [  3]: Loss 0.54445
Validation:	Loss 0.48560	Accuracy 1.00000
Epoch [  4]: Loss 0.44412
Validation:	Loss 0.39013	Accuracy 1.00000
Epoch [  5]: Loss 0.36092
Validation:	Loss 0.32314	Accuracy 1.00000
Epoch [  6]: Loss 0.30250
Validation:	Loss 0.27289	Accuracy 1.00000
Epoch [  7]: Loss 0.25654
Validation:	Loss 0.23241	Accuracy 1.00000
Epoch [  8]: Loss 0.21793
Validation:	Loss 0.19780	Accuracy 1.00000
Epoch [  9]: Loss 0.18569
Validation:	Loss 0.16624	Accuracy 1.00000
Epoch [ 10]: Loss 0.15337
Validation:	Loss 0.13069	Accuracy 1.00000
Epoch [ 11]: Loss 0.11425
Validation:	Loss 0.09165	Accuracy 1.00000
Epoch [ 12]: Loss 0.07697
Validation:	Loss 0.05846	Accuracy 1.00000
Epoch [ 13]: Loss 0.04897
Validation:	Loss 0.03708	Accuracy 1.00000
Epoch [ 14]: Loss 0.03152
Validation:	Loss 0.02527	Accuracy 1.00000
Epoch [ 15]: Loss 0.02241
Validation:	Loss 0.01909	Accuracy 1.00000
Epoch [ 16]: Loss 0.01754
Validation:	Loss 0.01565	Accuracy 1.00000
Epoch [ 17]: Loss 0.01460
Validation:	Loss 0.01329	Accuracy 1.00000
Epoch [ 18]: Loss 0.01250
Validation:	Loss 0.01154	Accuracy 1.00000
Epoch [ 19]: Loss 0.01090
Validation:	Loss 0.01009	Accuracy 1.00000
Epoch [ 20]: Loss 0.00951
Validation:	Loss 0.00878	Accuracy 1.00000
Epoch [ 21]: Loss 0.00828
Validation:	Loss 0.00768	Accuracy 1.00000
Epoch [ 22]: Loss 0.00728
Validation:	Loss 0.00691	Accuracy 1.00000
Epoch [ 23]: Loss 0.00666
Validation:	Loss 0.00637	Accuracy 1.00000
Epoch [ 24]: Loss 0.00617
Validation:	Loss 0.00595	Accuracy 1.00000
Epoch [ 25]: Loss 0.00580
Validation:	Loss 0.00558	Accuracy 1.00000
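
For readers without a Reactant setup, a single training step on the CPU can be sketched as follows. This assumes Zygote is installed (see the note under Package Imports) and swaps the LSTM model for one small Dense layer on random data, purely to show the shape of the `Training` API calls used in `main`.

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

rng = Random.default_rng()
model = Dense(2 => 1, sigmoid)          # stand-in model, not the spiral classifier
ps, st = Lux.setup(rng, model)
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

# Random inputs and binary targets, for illustration only.
x = rand(rng, Float32, 2, 16)
y = Float32.(rand(rng, Bool, 1, 16))

lossfn = BinaryCrossEntropyLoss()

# One optimisation step: returns gradients, the loss value, extra statistics,
# and the updated training state.
_, loss, _, train_state = Training.single_train_step!(
    AutoZygote(), lossfn, (x, y), train_state
)
```

Calling this inside a loop over a DataLoader, as `main` does, gives the full training procedure.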

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.69127
Validation:	Loss 0.62883	Accuracy 1.00000
Epoch [  2]: Loss 0.59773
Validation:	Loss 0.55691	Accuracy 1.00000
Epoch [  3]: Loss 0.53535
Validation:	Loss 0.50087	Accuracy 1.00000
Epoch [  4]: Loss 0.47979
Validation:	Loss 0.45229	Accuracy 1.00000
Epoch [  5]: Loss 0.43386
Validation:	Loss 0.40632	Accuracy 1.00000
Epoch [  6]: Loss 0.38648
Validation:	Loss 0.35964	Accuracy 1.00000
Epoch [  7]: Loss 0.33984
Validation:	Loss 0.30956	Accuracy 1.00000
Epoch [  8]: Loss 0.28701
Validation:	Loss 0.25509	Accuracy 1.00000
Epoch [  9]: Loss 0.23092
Validation:	Loss 0.19590	Accuracy 1.00000
Epoch [ 10]: Loss 0.17012
Validation:	Loss 0.13502	Accuracy 1.00000
Epoch [ 11]: Loss 0.11352
Validation:	Loss 0.08692	Accuracy 1.00000
Epoch [ 12]: Loss 0.07430
Validation:	Loss 0.06068	Accuracy 1.00000
Epoch [ 13]: Loss 0.05426
Validation:	Loss 0.04693	Accuracy 1.00000
Epoch [ 14]: Loss 0.04305
Validation:	Loss 0.03859	Accuracy 1.00000
Epoch [ 15]: Loss 0.03607
Validation:	Loss 0.03279	Accuracy 1.00000
Epoch [ 16]: Loss 0.03084
Validation:	Loss 0.02864	Accuracy 1.00000
Epoch [ 17]: Loss 0.02733
Validation:	Loss 0.02545	Accuracy 1.00000
Epoch [ 18]: Loss 0.02431
Validation:	Loss 0.02293	Accuracy 1.00000
Epoch [ 19]: Loss 0.02205
Validation:	Loss 0.02087	Accuracy 1.00000
Epoch [ 20]: Loss 0.02008
Validation:	Loss 0.01913	Accuracy 1.00000
Epoch [ 21]: Loss 0.01852
Validation:	Loss 0.01766	Accuracy 1.00000
Epoch [ 22]: Loss 0.01712
Validation:	Loss 0.01638	Accuracy 1.00000
Epoch [ 23]: Loss 0.01588
Validation:	Loss 0.01527	Accuracy 1.00000
Epoch [ 24]: Loss 0.01479
Validation:	Loss 0.01428	Accuracy 1.00000
Epoch [ 25]: Loss 0.01393
Validation:	Loss 0.01341	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you save only the parameters and states, not the model struct.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
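
As an alternative to the macros, JLD2 also offers a function-based interface. The sketch below round-trips a small hypothetical parameter NamedTuple through a temporary file using `jldsave` and `load`; the name `ps_demo` and the file name are illustrative.

```julia
using JLD2

# A stand-in for trained parameters: plain arrays inside a NamedTuple.
ps_demo = (weight = rand(Float32, 2, 2), bias = zeros(Float32, 2))

# Write to a temporary location so the example leaves nothing behind in the project.
path = joinpath(mktempdir(), "params_demo.jld2")
jldsave(path; ps_demo)

# Read back a single named entry from the file.
ps_loaded = load(path, "ps_demo")
ps_loaded.weight == ps_demo.weight  # true
```

Because only plain arrays and NamedTuples are stored, the file can be loaded without the model definition being available.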

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.