Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train using Optimisers.jl and an automatic differentiation backend (Enzyme via Reactant, or Zygote.jl).
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether each sequence is clockwise or anticlockwise.
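To make the 2 × seq_len × batch_size layout concrete, here is a minimal sketch that builds a tiny batch by hand. The toy_spiral helper is hypothetical: it draws a noiseless Archimedean spiral and is not how MLUtils.Datasets.make_spiral generates its data. It only illustrates how individual 2 × seq_len sequences are stacked along the third dimension, with label 0 for clockwise and 1 for anticlockwise as in get_dataloaders.

```julia
# Hypothetical helper (not MLUtils.Datasets.make_spiral): a noiseless
# Archimedean spiral sampled at n points, returned as a 2 × n matrix.
function toy_spiral(n; clockwise::Bool)
    θ = range(0, 4π; length = n)
    s = clockwise ? -1 : 1  # flip the direction of rotation
    return Float32.(vcat((θ .* cos.(s .* θ))', (θ .* sin.(s .* θ))'))
end

seq_len, batch = 50, 4
spirals = [toy_spiral(seq_len; clockwise = isodd(i)) for i in 1:batch]
# Give each 2 × seq_len sequence a singleton third dimension, then stack along it
x = cat((reshape(s, 2, seq_len, 1) for s in spirals)...; dims = 3)
y = Float32[isodd(i) ? 0 : 1 for i in 1:batch]  # 0 = clockwise, 1 = anticlockwise
size(x)  # (2, 50, 4)
```

Each x[:, t, b] is the 2D point at time step t of sequence b, which is exactly the slice the LSTM consumes at each step.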
function get_dataloaders(; dataset_size = 1000, sequence_length = 50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
        for d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(
            d[1][:, (sequence_length + 1):end], :, sequence_length, 1
        )
        for d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, shuffle = true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize = 128, shuffle = true, partial = false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false, partial = false),
    )
end
get_dataloaders (generic function with 1 method)
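Because both loaders use partial = false, the last incomplete batch is dropped. A quick worked calculation, assuming the 0.8 split puts 800 of the 1000 sequences in the training set:

```julia
dataset_size, batchsize = 1000, 128
n_train = floor(Int, 0.8 * dataset_size)  # assumed: 800 training sequences
n_val = dataset_size - n_train            # 200 validation sequences

# partial = false: only full batches are produced
train_batches = n_train ÷ batchsize  # 6 batches, i.e. 6 * 128 = 768 sequences per epoch
val_batches = n_val ÷ batchsize      # 1 batch, i.e. 128 sequences
dropped = (n_train - train_batches * batchsize, n_val - val_batches * batchsize)  # (32, 72)
```

So each epoch sees 768 training sequences, and the validation metrics are computed on a single batch of 128 sequences. That is why the validation accuracy in the logs below moves in steps of 1/128 (for example, 0.74219 ≈ 95/128).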
Creating a Classifier
We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are populated automatically and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier
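As a quick check of what the container layer gives us, here is a sketch with a hypothetical TinyClassifier that mirrors SpiralClassifier's field names. Lux.setup should return parameter and state NamedTuples keyed by those field names, which is what lets the forward pass use ps.lstm_cell and ps.classifier.

```julia
using Lux, Random

# Hypothetical container with the same field layout as SpiralClassifier
struct TinyClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

model = TinyClassifier(LSTMCell(2 => 4), Dense(4 => 1, sigmoid))
ps, st = Lux.setup(Random.Xoshiro(0), model)

keys(ps)                    # (:lstm_cell, :classifier)
size(ps.classifier.weight)  # (1, 4): Dense stores its weight as out × in
```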
We could use the built-in Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of writing the forward pass below, but we will still define it by hand to see how it works.
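For comparison, here is a sketch of the built-in route. It assumes that Recurrence's default ordering treats the second-to-last dimension as time, so it consumes the features × seq_len × batch_size arrays produced by get_dataloaders, and that it returns only the final output. WrappedFunction(vec) flattens the 1 × batch_size output, mirroring vec(y) in the hand-written model.

```julia
using Lux, Random

# Built-in alternative: Recurrence loops the LSTMCell over the time dimension
model = Chain(
    Recurrence(LSTMCell(2 => 8)),
    Dense(8 => 1, sigmoid),
    WrappedFunction(vec),
)
ps, st = Lux.setup(Random.Xoshiro(0), model)

x = randn(Random.Xoshiro(1), Float32, 2, 50, 4)  # 2 features × 50 steps × 4 sequences
ŷ, _ = model(x, ps, st)
size(ŷ)  # (4,)
```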
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple
    ) where {T}
    # First we run the sequence through the LSTM cell.
    # The first call to the LSTM cell creates the initial hidden state.
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence, we pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to create the updated state
    st = merge(st, (classifier = st_classifier, lstm_cell = st_lstm))
    return vec(y), st
end
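Stripped of Lux, the peel-then-loop pattern above is a left fold over time steps that threads a carry. The sketch below uses Base.eachslice in place of LuxOps.eachslice and a hypothetical running-sum update in place of the LSTM gates, so the result is easy to check: the final carry equals the sum over the time dimension.

```julia
# Hypothetical stand-in for the LSTM loop: the "cell" just keeps a running sum.
function run_sequence(x::AbstractArray{T, 3}) where {T}
    # Split off the first time step; the remaining steps are views along dim 2
    x_init, x_rest = Iterators.peel(eachslice(x; dims = 2))
    carry = copy(x_init)     # first call: build the initial carry from x_init
    for xt in x_rest
        carry = carry .+ xt  # later calls: combine the current input with the carry
    end
    return carry             # features × batch, like the LSTM's final output
end

x = reshape(Float32.(1:12), 2, 3, 2)  # 2 features × 3 time steps × 2 sequences
run_sequence(x)  # Float32[9 27; 12 30]
```

In the real model the carry is the (hidden state, memory) pair and the update is the LSTM cell, but the control flow is the same.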
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. The macro handles the boilerplate for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically logitbinarycrossentropy is recommended because it is more numerically stable, but for simplicity we use binarycrossentropy here, since our classifier head already applies sigmoid.
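To see why the logit form is preferred, here is a hand-written Float32 sketch. toy_sigmoid, naive_bce and logit_bce are hypothetical helpers, not Lux's implementation, which may guard against log(0) differently. When the sigmoid saturates, 1 - sigmoid(z) rounds to exactly 0 and the naive loss becomes Inf, while the algebraically equivalent logit form max(z, 0) - z*y + log1p(exp(-|z|)) stays finite.

```julia
toy_sigmoid(z) = 1 / (1 + exp(-z))

# Probability-space BCE: what a sigmoid head followed by binarycrossentropy computes
naive_bce(z, y) = -(y * log(toy_sigmoid(z)) + (1 - y) * log(1 - toy_sigmoid(z)))

# Logit-space BCE: algebraically the same, but never evaluates log(0)
logit_bce(z, y) = max(z, zero(z)) - z * y + log1p(exp(-abs(z)))

naive_bce(40.0f0, 0.0f0)  # Inf32, because toy_sigmoid(40f0) rounds to exactly 1
logit_bce(40.0f0, 0.0f0)  # 40.0f0
```

For moderate logits the two agree; they only diverge once the sigmoid saturates in floating point.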
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred = ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
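A tiny worked example of the thresholding in matches and accuracy, reproduced under hypothetical names so it runs on its own: sigmoid outputs above 0.5 count as label 1.

```julia
# Same logic as the tutorial's matches/accuracy, under hypothetical names
count_matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
frac_correct(y_pred, y_true) = count_matches(y_pred, y_true) / length(y_pred)

y_pred = Float32[0.9, 0.2, 0.6, 0.4]  # sigmoid outputs
y_true = Float32[1, 0, 0, 0]
frac_correct(y_pred, y_true)  # 0.75: three of the four thresholded predictions match
```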
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()
    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev
    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)
        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (total_acc / total_samples)
    end
    return (train_state.parameters, train_state.states) |> cpu_device()
end
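When dev is not a ReactantDevice, main falls back to AutoZygote. The following is a minimal CPU sketch of that path on a hypothetical, linearly separable toy problem. It also shows that Training.single_train_step! accepts a custom objective with the same (model, ps, st, data) -> (loss, st, stats) shape as compute_loss, in place of the loss object itself. Zygote must be installed, as noted at the top; the toy_* names are illustrative only.

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

rng = Random.Xoshiro(0)
x = randn(rng, Float32, 2, 64)
y = Float32.(x[1, :] .> 0)  # hypothetical labels: 1 when the first feature is positive

model = Chain(Dense(2 => 1, sigmoid), WrappedFunction(vec))
ps, st = Lux.setup(rng, model)
toy_bce = BinaryCrossEntropyLoss()

# Same contract as compute_loss: return (loss, updated state, extra stats)
function toy_objective(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return toy_bce(ŷ, y), st_, (; y_pred = ŷ)
end

# A function body keeps the loop variables local (avoids global soft-scope pitfalls)
function train_toy(train_state, data; steps = 200)
    losses = Float32[]
    for _ in 1:steps
        _, loss, _, train_state = Training.single_train_step!(
            AutoZygote(), toy_objective, data, train_state
        )
        push!(losses, loss)
    end
    return train_state, losses
end

train_state, losses = train_toy(Training.TrainState(model, ps, st, Adam(0.05f0)), (x, y))
```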
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.74438
Validation: Loss 0.68827 Accuracy 0.50000
Epoch [ 2]: Loss 0.64720
Validation: Loss 0.58855 Accuracy 0.74219
Epoch [ 3]: Loss 0.54445
Validation: Loss 0.48560 Accuracy 1.00000
Epoch [ 4]: Loss 0.44412
Validation: Loss 0.39013 Accuracy 1.00000
Epoch [ 5]: Loss 0.36092
Validation: Loss 0.32314 Accuracy 1.00000
Epoch [ 6]: Loss 0.30250
Validation: Loss 0.27289 Accuracy 1.00000
Epoch [ 7]: Loss 0.25654
Validation: Loss 0.23241 Accuracy 1.00000
Epoch [ 8]: Loss 0.21793
Validation: Loss 0.19780 Accuracy 1.00000
Epoch [ 9]: Loss 0.18569
Validation: Loss 0.16624 Accuracy 1.00000
Epoch [ 10]: Loss 0.15337
Validation: Loss 0.13069 Accuracy 1.00000
Epoch [ 11]: Loss 0.11425
Validation: Loss 0.09165 Accuracy 1.00000
Epoch [ 12]: Loss 0.07697
Validation: Loss 0.05846 Accuracy 1.00000
Epoch [ 13]: Loss 0.04897
Validation: Loss 0.03708 Accuracy 1.00000
Epoch [ 14]: Loss 0.03152
Validation: Loss 0.02527 Accuracy 1.00000
Epoch [ 15]: Loss 0.02241
Validation: Loss 0.01909 Accuracy 1.00000
Epoch [ 16]: Loss 0.01754
Validation: Loss 0.01565 Accuracy 1.00000
Epoch [ 17]: Loss 0.01460
Validation: Loss 0.01329 Accuracy 1.00000
Epoch [ 18]: Loss 0.01250
Validation: Loss 0.01154 Accuracy 1.00000
Epoch [ 19]: Loss 0.01090
Validation: Loss 0.01009 Accuracy 1.00000
Epoch [ 20]: Loss 0.00951
Validation: Loss 0.00878 Accuracy 1.00000
Epoch [ 21]: Loss 0.00828
Validation: Loss 0.00768 Accuracy 1.00000
Epoch [ 22]: Loss 0.00728
Validation: Loss 0.00691 Accuracy 1.00000
Epoch [ 23]: Loss 0.00666
Validation: Loss 0.00637 Accuracy 1.00000
Epoch [ 24]: Loss 0.00617
Validation: Loss 0.00595 Accuracy 1.00000
Epoch [ 25]: Loss 0.00580
Validation: Loss 0.00558 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.69127
Validation: Loss 0.62883 Accuracy 1.00000
Epoch [ 2]: Loss 0.59773
Validation: Loss 0.55691 Accuracy 1.00000
Epoch [ 3]: Loss 0.53535
Validation: Loss 0.50087 Accuracy 1.00000
Epoch [ 4]: Loss 0.47979
Validation: Loss 0.45229 Accuracy 1.00000
Epoch [ 5]: Loss 0.43386
Validation: Loss 0.40632 Accuracy 1.00000
Epoch [ 6]: Loss 0.38648
Validation: Loss 0.35964 Accuracy 1.00000
Epoch [ 7]: Loss 0.33984
Validation: Loss 0.30956 Accuracy 1.00000
Epoch [ 8]: Loss 0.28701
Validation: Loss 0.25509 Accuracy 1.00000
Epoch [ 9]: Loss 0.23092
Validation: Loss 0.19590 Accuracy 1.00000
Epoch [ 10]: Loss 0.17012
Validation: Loss 0.13502 Accuracy 1.00000
Epoch [ 11]: Loss 0.11352
Validation: Loss 0.08692 Accuracy 1.00000
Epoch [ 12]: Loss 0.07430
Validation: Loss 0.06068 Accuracy 1.00000
Epoch [ 13]: Loss 0.05426
Validation: Loss 0.04693 Accuracy 1.00000
Epoch [ 14]: Loss 0.04305
Validation: Loss 0.03859 Accuracy 1.00000
Epoch [ 15]: Loss 0.03607
Validation: Loss 0.03279 Accuracy 1.00000
Epoch [ 16]: Loss 0.03084
Validation: Loss 0.02864 Accuracy 1.00000
Epoch [ 17]: Loss 0.02733
Validation: Loss 0.02545 Accuracy 1.00000
Epoch [ 18]: Loss 0.02431
Validation: Loss 0.02293 Accuracy 1.00000
Epoch [ 19]: Loss 0.02205
Validation: Loss 0.02087 Accuracy 1.00000
Epoch [ 20]: Loss 0.02008
Validation: Loss 0.01913 Accuracy 1.00000
Epoch [ 21]: Loss 0.01852
Validation: Loss 0.01766 Accuracy 1.00000
Epoch [ 22]: Loss 0.01712
Validation: Loss 0.01638 Accuracy 1.00000
Epoch [ 23]: Loss 0.01588
Validation: Loss 0.01527 Accuracy 1.00000
Epoch [ 24]: Loss 0.01479
Validation: Loss 0.01428 Accuracy 1.00000
Epoch [ 25]: Loss 0.01393
Validation: Loss 0.01341 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that main already moves the parameters and states to the CPU before returning them, so they are ready to save. Additionally, we recommend saving only the parameters and states, not the model struct.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
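Because only parameters and states are saved, the model struct has to be rebuilt in code before the loaded values can be used. Here is a minimal sketch of that round trip, assuming a small Dense model in place of SpiralClassifier and a temporary file path; it uses JLD2's jldsave and jldopen rather than the @save/@load macros.

```julia
using JLD2, Lux, Random

model = Dense(2 => 1, sigmoid)  # the model struct is rebuilt in code, never saved
ps, st = Lux.setup(Random.Xoshiro(0), model)

path = joinpath(mktempdir(), "toy_model.jld2")  # hypothetical file location
jldsave(path; ps, st)                           # save only parameters and states

ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end

x = randn(Random.Xoshiro(1), Float32, 2, 3)
y_orig, _ = model(x, ps, st)
y_loaded, _ = model(x, ps_loaded, st_loaded)
y_orig == y_loaded  # true
```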
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.