Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
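To sanity-check the shapes described above, we can peek at a single batch. This snippet is purely illustrative and not part of the tutorial's pipeline:

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): 2 coordinates × sequence_length × batch_size
size(y)  # (128,)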
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier
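Because we subtyped AbstractLuxContainerLayer with the fieldnames (:lstm_cell, :classifier), Lux.setup automatically nests the parameters and states under those names. A quick illustrative check (not part of the tutorial's pipeline):

rng = Random.default_rng()
ps, st = Lux.setup(rng, SpiralClassifier(2, 8, 1))
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)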
We could use the default Lux blocks, i.e. Recurrence(LSTMCell(in_dims => hidden_dims)), instead of defining the model manually, but let's still do it ourselves for illustration.
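For reference, a version assembled purely from stock layers might look like the following sketch. The name SpiralClassifierRecurrence is ours, and we rely on Recurrence's default behavior of returning only the final time step's output:

function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    # Recurrence iterates the cell over the time dimension for us
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec,
    )
end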
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
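For reference, the numerically stabler setup would feed logits to the loss and drop the sigmoid from the classifier head. A sketch (logit_lossfn is our name; this variant is not used below):

# Keep the final Dense linear, i.e. Dense(hidden_dims => out_dims),
# and let the loss apply the sigmoid internally:
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))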
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
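As an example with made-up values, three of the four thresholded predictions below match the labels:

accuracy(Float32[0.9, 0.2, 0.7, 0.4], Float32[1, 0, 0, 0])  # 0.75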
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(first(train_loader)[1], ps, Lux.testmode(st))
        end
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.64281
Validation: Loss 0.57897 Accuracy 1.00000
Epoch [ 2]: Loss 0.53634
Validation: Loss 0.47932 Accuracy 1.00000
Epoch [ 3]: Loss 0.43283
Validation: Loss 0.39715 Accuracy 1.00000
Epoch [ 4]: Loss 0.35946
Validation: Loss 0.34208 Accuracy 1.00000
Epoch [ 5]: Loss 0.30770
Validation: Loss 0.29286 Accuracy 1.00000
Epoch [ 6]: Loss 0.26009
Validation: Loss 0.24528 Accuracy 1.00000
Epoch [ 7]: Loss 0.21348
Validation: Loss 0.20193 Accuracy 1.00000
Epoch [ 8]: Loss 0.17343
Validation: Loss 0.16440 Accuracy 1.00000
Epoch [ 9]: Loss 0.14148
Validation: Loss 0.13377 Accuracy 1.00000
Epoch [ 10]: Loss 0.11473
Validation: Loss 0.10965 Accuracy 1.00000
Epoch [ 11]: Loss 0.09437
Validation: Loss 0.09058 Accuracy 1.00000
Epoch [ 12]: Loss 0.07801
Validation: Loss 0.07556 Accuracy 1.00000
Epoch [ 13]: Loss 0.06520
Validation: Loss 0.06351 Accuracy 1.00000
Epoch [ 14]: Loss 0.05438
Validation: Loss 0.05357 Accuracy 1.00000
Epoch [ 15]: Loss 0.04622
Validation: Loss 0.04470 Accuracy 1.00000
Epoch [ 16]: Loss 0.03857
Validation: Loss 0.03665 Accuracy 1.00000
Epoch [ 17]: Loss 0.03129
Validation: Loss 0.02887 Accuracy 1.00000
Epoch [ 18]: Loss 0.02439
Validation: Loss 0.02136 Accuracy 1.00000
Epoch [ 19]: Loss 0.01795
Validation: Loss 0.01585 Accuracy 1.00000
Epoch [ 20]: Loss 0.01390
Validation: Loss 0.01265 Accuracy 1.00000
Epoch [ 21]: Loss 0.01142
Validation: Loss 0.01061 Accuracy 1.00000
Epoch [ 22]: Loss 0.00976
Validation: Loss 0.00908 Accuracy 1.00000
Epoch [ 23]: Loss 0.00841
Validation: Loss 0.00786 Accuracy 1.00000
Epoch [ 24]: Loss 0.00735
Validation: Loss 0.00684 Accuracy 1.00000
Epoch [ 25]: Loss 0.00644
Validation: Loss 0.00596 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.54187
Validation: Loss 0.39127 Accuracy 1.00000
Epoch [ 2]: Loss 0.34422
Validation: Loss 0.28346 Accuracy 1.00000
Epoch [ 3]: Loss 0.25743
Validation: Loss 0.22008 Accuracy 1.00000
Epoch [ 4]: Loss 0.20414
Validation: Loss 0.17669 Accuracy 1.00000
Epoch [ 5]: Loss 0.16354
Validation: Loss 0.14113 Accuracy 1.00000
Epoch [ 6]: Loss 0.12969
Validation: Loss 0.11079 Accuracy 1.00000
Epoch [ 7]: Loss 0.10071
Validation: Loss 0.08574 Accuracy 1.00000
Epoch [ 8]: Loss 0.07791
Validation: Loss 0.06685 Accuracy 1.00000
Epoch [ 9]: Loss 0.06130
Validation: Loss 0.05342 Accuracy 1.00000
Epoch [ 10]: Loss 0.04940
Validation: Loss 0.04357 Accuracy 1.00000
Epoch [ 11]: Loss 0.04040
Validation: Loss 0.03604 Accuracy 1.00000
Epoch [ 12]: Loss 0.03347
Validation: Loss 0.03002 Accuracy 1.00000
Epoch [ 13]: Loss 0.02783
Validation: Loss 0.02507 Accuracy 1.00000
Epoch [ 14]: Loss 0.02329
Validation: Loss 0.02113 Accuracy 1.00000
Epoch [ 15]: Loss 0.01966
Validation: Loss 0.01810 Accuracy 1.00000
Epoch [ 16]: Loss 0.01701
Validation: Loss 0.01578 Accuracy 1.00000
Epoch [ 17]: Loss 0.01488
Validation: Loss 0.01397 Accuracy 1.00000
Epoch [ 18]: Loss 0.01328
Validation: Loss 0.01252 Accuracy 1.00000
Epoch [ 19]: Loss 0.01193
Validation: Loss 0.01129 Accuracy 1.00000
Epoch [ 20]: Loss 0.01076
Validation: Loss 0.01024 Accuracy 1.00000
Epoch [ 21]: Loss 0.00981
Validation: Loss 0.00932 Accuracy 1.00000
Epoch [ 22]: Loss 0.00894
Validation: Loss 0.00850 Accuracy 1.00000
Epoch [ 23]: Loss 0.00815
Validation: Loss 0.00775 Accuracy 1.00000
Epoch [ 24]: Loss 0.00744
Validation: Loss 0.00709 Accuracy 1.00000
Epoch [ 25]: Loss 0.00681
Validation: Loss 0.00649 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
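The restored parameters and states can be used for inference by rebuilding the model. An illustrative sketch (recall that label 1 corresponds to the anticlockwise class):

model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # one random sequence
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
ŷ[1] > 0.5f0  # true => predicted anticlockwise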
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.