Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary value indicating whether it is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
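To make the 2 × seq_len × batch_size layout concrete, here is a minimal sketch that uses only Base Julia. The array is random stand-in data rather than real spirals, and the names first_sequence and first_step are illustrative only.

```julia
# Toy stand-in for one batch; the batch size here is illustrative, not the tutorial's.
seq_len, batch_size = 50, 4
x = rand(Float32, 2, seq_len, batch_size)   # features × time × batch

# All (x, y) coordinates of the first sequence: a 2 × seq_len matrix.
first_sequence = view(x, :, :, 1)

# The first time step of every sequence in the batch: a 2 × batch_size matrix.
first_step = view(x, :, 1, :)

# One label per sequence: 0 for clockwise, 1 for anticlockwise.
y = vcat(zeros(Float32, batch_size ÷ 2), ones(Float32, batch_size ÷ 2))

println(size(first_sequence), " ", size(first_step), " ", size(y))
```

Slicing along the second dimension is what the recurrent model does later: it consumes one 2 × batch_size matrix per time step.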
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
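As a rough illustration of what "automatically populated" means, the sketch below defines a hypothetical two-field container (TwoLayer is not part of the tutorial) and inspects what Lux.setup returns. It assumes Lux and Random are available in the environment.

```julia
using Lux, Random

# Hypothetical container layer used only for this illustration.
struct TwoLayer{A,B} <: Lux.AbstractLuxContainerLayer{(:first, :second)}
    first::A
    second::B
end

model = TwoLayer(Dense(2 => 4, tanh), Dense(4 => 1))

# No custom initialparameters/initialstates methods were defined:
# the field names listed in the type parameter drive the setup.
ps, st = Lux.setup(Random.default_rng(), model)

println(keys(ps))              # one entry per listed field
println(size(ps.first.weight)) # Dense stores its weight as out_dims × in_dims
```

The parameters and states come back as NamedTuples keyed by the listed field names, which is why the forward pass later indexes them as ps.lstm_cell and st.classifier.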
We won't define the model from scratch but rather use Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier
We could use the default Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the following by hand, but we will write it out explicitly to show how the recurrence works.
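For comparison, a version built from the default blocks might look like the sketch below. It is an assumption-laden illustration, not the tutorial's model: it relies on Recurrence using its default settings (batch as the last dimension, returning only the final output), and the sizes are arbitrary.

```julia
using Lux, Random

# Recurrence unrolls the cell over the time dimension for us.
model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(Random.default_rng(), model)

x = rand(Float32, 2, 50, 16)   # features × time × batch
y, _ = model(x, ps, st)

println(size(y))
```

Unlike the hand-written classifier, this version returns a 1 × batch_size matrix rather than a vector, so the output would need a vec call before being compared against the labels.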
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell` We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
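The peel-then-loop pattern above can be tried in isolation on a bare LSTMCell. The sketch below is illustrative only: run_sequence is a hypothetical helper, the sizes are arbitrary, and Base's eachslice is swapped in for LuxOps.eachslice to keep the example small.

```julia
using Lux, Random

# Hypothetical helper: feed a features × time × batch array through one cell.
function run_sequence(cell, x, ps, st)
    # Split the first time step from the remaining ones.
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    # Called with the input alone, the cell creates the initial hidden state and memory.
    (y, carry), st = cell(x_init, ps, st)
    # Later calls receive the input together with the carried (hidden, memory) tuple.
    for x_t in x_rest
        (y, carry), st = cell((x_t, carry), ps, st)
    end
    return y, carry, st
end

cell = LSTMCell(2 => 8)
ps, st = Lux.setup(Random.default_rng(), cell)
x = rand(Float32, 2, 5, 4)

y, (h, c), _ = run_sequence(cell, x, ps, st)
println(size(y), " ", size(h), " ", size(c))
```

The loop lives inside a function on purpose: at the top level of a script, assigning to y and carry inside a for loop would create new local variables instead of updating the outer ones.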
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, so we recommend defining custom layers this way.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
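To see what the macro takes care of, here is a minimal non-recurrent sketch. The layer names hidden and head and the sizes are made up for illustration. Inside the do block the sub-layers are called with the input alone, and parameters and states are threaded through behind the scenes.

```julia
using Lux, Random

# Two Dense layers wired together; no struct or call overload is written by hand.
model = @compact(; hidden=Dense(2 => 4, tanh), head=Dense(4 => 1)) do x
    @return head(hidden(x))
end

ps, st = Lux.setup(Random.default_rng(), model)
y, _ = model(rand(Float32, 2, 3), ps, st)

println(keys(ps), " ", size(y))
```

From the outside the result behaves like any other Lux layer: it is set up with Lux.setup and called as model(x, ps, st).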
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
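The stability remark can be illustrated with the textbook formulas. The functions below are a plain-Julia sketch of the mathematics, not Lux's implementation; the names sigmoid_naive, bce and logit_bce are introduced only for this example.

```julia
# Plain sigmoid, used only for this illustration.
sigmoid_naive(z) = 1 / (1 + exp(-z))

# Textbook binary cross-entropy on a probability p.
bce(p, y) = -(y * log(p) + (1 - y) * log(1 - p))

# The same quantity computed directly from the logit z,
# rearranged so that log(0) never appears.
logit_bce(z, y) = max(z, 0) - z * y + log1p(exp(-abs(z)))

# For moderate logits both forms agree.
println(bce(sigmoid_naive(2.0f0), 1.0f0), " ", logit_bce(2.0f0, 1.0f0))

# For a large logit the Float32 sigmoid rounds to exactly 1,
# so the probability-based form hits log(0) while the logit form stays finite.
println(bce(sigmoid_naive(40.0f0), 0.0f0), " ", logit_bce(40.0f0, 0.0f0))
```

To use the logit form in a model like this one, the classifier head would have to output raw logits, without the sigmoid activation.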
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
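A quick hand-checkable example of the accuracy helper, with the two one-line definitions repeated so the snippet runs on its own. The prediction and label vectors are made up.

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# Thresholding at 0.5 gives [1, 0, 1, 0]; only the third entry disagrees with the labels.
y_pred = Float32[0.9, 0.2, 0.6, 0.4]
y_true = Float32[1, 0, 0, 0]

println(matches(y_pred, y_true), " ", accuracy(y_pred, y_true))
```

The comparison works because a Bool compared with a Float32 label of 0 or 1 is equal exactly when the thresholded prediction matches the label.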
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()
    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())
    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)
        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end
    return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.67839
Validation: Loss 0.46142 Accuracy 1.00000
Epoch [ 2]: Loss 0.41877
Validation: Loss 0.35907 Accuracy 1.00000
Epoch [ 3]: Loss 0.33219
Validation: Loss 0.28885 Accuracy 1.00000
Epoch [ 4]: Loss 0.26561
Validation: Loss 0.22314 Accuracy 1.00000
Epoch [ 5]: Loss 0.19888
Validation: Loss 0.16119 Accuracy 1.00000
Epoch [ 6]: Loss 0.13870
Validation: Loss 0.10635 Accuracy 1.00000
Epoch [ 7]: Loss 0.08912
Validation: Loss 0.06985 Accuracy 1.00000
Epoch [ 8]: Loss 0.06066
Validation: Loss 0.05064 Accuracy 1.00000
Epoch [ 9]: Loss 0.04496
Validation: Loss 0.03898 Accuracy 1.00000
Epoch [ 10]: Loss 0.03523
Validation: Loss 0.03135 Accuracy 1.00000
Epoch [ 11]: Loss 0.02896
Validation: Loss 0.02609 Accuracy 1.00000
Epoch [ 12]: Loss 0.02410
Validation: Loss 0.02226 Accuracy 1.00000
Epoch [ 13]: Loss 0.02077
Validation: Loss 0.01930 Accuracy 1.00000
Epoch [ 14]: Loss 0.01810
Validation: Loss 0.01700 Accuracy 1.00000
Epoch [ 15]: Loss 0.01603
Validation: Loss 0.01526 Accuracy 1.00000
Epoch [ 16]: Loss 0.01450
Validation: Loss 0.01381 Accuracy 1.00000
Epoch [ 17]: Loss 0.01316
Validation: Loss 0.01260 Accuracy 1.00000
Epoch [ 18]: Loss 0.01205
Validation: Loss 0.01156 Accuracy 1.00000
Epoch [ 19]: Loss 0.01107
Validation: Loss 0.01064 Accuracy 1.00000
Epoch [ 20]: Loss 0.01018
Validation: Loss 0.00982 Accuracy 1.00000
Epoch [ 21]: Loss 0.00938
Validation: Loss 0.00910 Accuracy 1.00000
Epoch [ 22]: Loss 0.00874
Validation: Loss 0.00845 Accuracy 1.00000
Epoch [ 23]: Loss 0.00810
Validation: Loss 0.00788 Accuracy 1.00000
Epoch [ 24]: Loss 0.00759
Validation: Loss 0.00738 Accuracy 1.00000
Epoch [ 25]: Loss 0.00712
Validation: Loss 0.00695 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.53395
Validation: Loss 0.49143 Accuracy 1.00000
Epoch [ 2]: Loss 0.45677
Validation: Loss 0.41279 Accuracy 1.00000
Epoch [ 3]: Loss 0.37888
Validation: Loss 0.34094 Accuracy 1.00000
Epoch [ 4]: Loss 0.31205
Validation: Loss 0.28056 Accuracy 1.00000
Epoch [ 5]: Loss 0.25496
Validation: Loss 0.22801 Accuracy 1.00000
Epoch [ 6]: Loss 0.20520
Validation: Loss 0.17914 Accuracy 1.00000
Epoch [ 7]: Loss 0.15835
Validation: Loss 0.13502 Accuracy 1.00000
Epoch [ 8]: Loss 0.11804
Validation: Loss 0.09946 Accuracy 1.00000
Epoch [ 9]: Loss 0.08704
Validation: Loss 0.07297 Accuracy 1.00000
Epoch [ 10]: Loss 0.06454
Validation: Loss 0.05505 Accuracy 1.00000
Epoch [ 11]: Loss 0.04978
Validation: Loss 0.04388 Accuracy 1.00000
Epoch [ 12]: Loss 0.04013
Validation: Loss 0.03564 Accuracy 1.00000
Epoch [ 13]: Loss 0.03234
Validation: Loss 0.02799 Accuracy 1.00000
Epoch [ 14]: Loss 0.02488
Validation: Loss 0.02115 Accuracy 1.00000
Epoch [ 15]: Loss 0.01921
Validation: Loss 0.01693 Accuracy 1.00000
Epoch [ 16]: Loss 0.01554
Validation: Loss 0.01384 Accuracy 1.00000
Epoch [ 17]: Loss 0.01276
Validation: Loss 0.01150 Accuracy 1.00000
Epoch [ 18]: Loss 0.01068
Validation: Loss 0.00974 Accuracy 1.00000
Epoch [ 19]: Loss 0.00912
Validation: Loss 0.00842 Accuracy 1.00000
Epoch [ 20]: Loss 0.00794
Validation: Loss 0.00741 Accuracy 1.00000
Epoch [ 21]: Loss 0.00701
Validation: Loss 0.00660 Accuracy 1.00000
Epoch [ 22]: Loss 0.00628
Validation: Loss 0.00593 Accuracy 1.00000
Epoch [ 23]: Loss 0.00566
Validation: Loss 0.00536 Accuracy 1.00000
Epoch [ 24]: Loss 0.00513
Validation: Loss 0.00486 Accuracy 1.00000
Epoch [ 25]: Loss 0.00467
Validation: Loss 0.00443 Accuracy 1.00000
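The core of the training loop can also be tried on the CPU without Reactant. The following is a minimal sketch under several assumptions: Zygote is installed and loaded, the model is a single Dense layer rather than the LSTM classifier, the data is random toy data, and train is a hypothetical helper name.

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

# Hypothetical helper: repeat single optimisation steps on one fixed batch.
function train(train_state, lossfn, data; steps=100)
    loss = Inf32
    for _ in 1:steps
        # Returns (gradients, loss, stats, updated train state).
        _, loss, _, train_state = Training.single_train_step!(
            AutoZygote(), lossfn, data, train_state
        )
    end
    return train_state, loss
end

rng = Random.default_rng()
model = Dense(2 => 1, sigmoid)
ps, st = Lux.setup(rng, model)
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

# Toy binary task: label is 1 when the two features sum to more than 1.
x = rand(rng, Float32, 2, 32)
y = Float32.(sum(x; dims=1) .> 1)

train_state, final_loss = train(train_state, BinaryCrossEntropyLoss(), (x, y))
println(final_loss)
```

The call has the same shape as in main: an AD backend, a loss function, a tuple of inputs and targets, and the current train state, with the updated train state returned for the next step.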
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend saving only the parameters and states, not the model struct.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
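The same round trip can be written with JLD2's function-based API, which avoids binding variables in the calling scope. This is a sketch with a made-up parameter NamedTuple and a temporary file, not the trained model from above.

```julia
using JLD2

# Stand-in parameters; a trained model's parameters are nested NamedTuples of arrays like this.
ps = (weight=rand(Float32, 2, 2), bias=zeros(Float32, 2))

path = joinpath(mktempdir(), "params.jld2")

# Store the NamedTuple under the key "ps", then read it back by name.
jldsave(path; ps)
ps_loaded = JLD2.load(path, "ps")

println(ps_loaded.weight == ps.weight)
```

After loading, the parameters can be passed to a freshly constructed model of the same architecture, which is why only the parameters and states need to be stored.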
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.