Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and use this data to create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
# Create the spirals
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
for d in data[1:(dataset_size ÷ 2)]]
anticlockwise_spirals = [reshape(
d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
for d in data[((dataset_size ÷ 2) + 1):end]]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
# Create DataLoaders
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
# Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
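Before building the model, it is worth sanity-checking the shapes the dataloader produces. The following is a minimal sketch; the sizes in the comments assume the default dataset_size=1000 and sequence_length=50:
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one binary label per sequence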
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type, to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
We won't define the model from scratch; instead, we will use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
return SpiralClassifier(
LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##230".SpiralClassifier
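Because we declared the fieldnames on the container type, Lux.setup automatically builds the parameters and states as NamedTuples keyed by those names. A minimal sketch to verify this:
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)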
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass below, but we will write it out manually for the sake of illustration. See the sketch that follows.
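For reference, a sketch of what the built-in-blocks version could look like; it assumes Recurrence's default behavior of iterating the cell over the time dimension and returning only the final output:
model_builtin = Chain(
    Recurrence(LSTMCell(2 => 8)),  # run the LSTM across the sequence, keep the last output
    Dense(8 => 1, sigmoid),        # classifier head
    vec,                           # flatten the 1 × batch_size output to a vector
)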
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
# First we will have to run the sequence through the LSTM Cell
# The first call to LSTM Cell will create the initial hidden state
# See that the parameters and states are automatically populated into a field called
# `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
# and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
# Now that we have the hidden state and memory in `carry` we will pass the input and
# `carry` jointly
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
# After running through the sequence we will pass the output through the classifier
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
# Finally remember to create the updated state
st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
return vec(y), st
end
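A quick smoke test of the forward pass on random data (a sketch; the sizes are illustrative):
rng = Xoshiro(0)
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 50, 3)  # 3 random sequences of length 50
ŷ, st_ = model(x, ps, st)
size(ŷ)  # (3,): one prediction per sequence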
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
@return vec(classifier(y))
end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
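If we instead wanted the numerically stabler variant, a sketch of the change would be to drop sigmoid from the classifier head (i.e. use Dense(hidden_dims => out_dims)) and let the loss operate on the raw logits:
# const lossfn = BinaryCrossEntropyLoss(; logits=Val(true))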
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
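As a quick usage example, predictions above the 0.5 threshold count as class 1:
accuracy(Float32[0.3, 0.9, 0.7, 0.2], Float32[0.0, 1.0, 1.0, 1.0])  # 0.75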
Training the Model
function main(model_type)
dev = gpu_device()
# Get the dataloaders
train_loader, val_loader = get_dataloaders() .|> dev
# Create the model
model = model_type(2, 8, 1)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) |> dev
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
for epoch in 1:25
# Train the model
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
AutoZygote(), lossfn, (x, y), train_state)
@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end
# Validate the model
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model(x, train_state.parameters, st_)
loss = lossfn(ŷ, y)
acc = accuracy(ŷ, y)
@printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
end
end
return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.61466
Epoch [ 1]: Loss 0.59222
Epoch [ 1]: Loss 0.56538
Epoch [ 1]: Loss 0.53192
Epoch [ 1]: Loss 0.52945
Epoch [ 1]: Loss 0.50069
Epoch [ 1]: Loss 0.49103
Validation: Loss 0.46551 Accuracy 1.00000
Validation: Loss 0.47201 Accuracy 1.00000
Epoch [ 2]: Loss 0.46439
Epoch [ 2]: Loss 0.45687
Epoch [ 2]: Loss 0.44415
Epoch [ 2]: Loss 0.42595
Epoch [ 2]: Loss 0.40914
Epoch [ 2]: Loss 0.39567
Epoch [ 2]: Loss 0.39944
Validation: Loss 0.36794 Accuracy 1.00000
Validation: Loss 0.37547 Accuracy 1.00000
Epoch [ 3]: Loss 0.37290
Epoch [ 3]: Loss 0.36134
Epoch [ 3]: Loss 0.33842
Epoch [ 3]: Loss 0.33959
Epoch [ 3]: Loss 0.32097
Epoch [ 3]: Loss 0.30847
Epoch [ 3]: Loss 0.29734
Validation: Loss 0.28265 Accuracy 1.00000
Validation: Loss 0.29107 Accuracy 1.00000
Epoch [ 4]: Loss 0.30058
Epoch [ 4]: Loss 0.27622
Epoch [ 4]: Loss 0.25626
Epoch [ 4]: Loss 0.24315
Epoch [ 4]: Loss 0.24900
Epoch [ 4]: Loss 0.23140
Epoch [ 4]: Loss 0.25694
Validation: Loss 0.21271 Accuracy 1.00000
Validation: Loss 0.22147 Accuracy 1.00000
Epoch [ 5]: Loss 0.21845
Epoch [ 5]: Loss 0.18657
Epoch [ 5]: Loss 0.19382
Epoch [ 5]: Loss 0.19102
Epoch [ 5]: Loss 0.19010
Epoch [ 5]: Loss 0.18979
Epoch [ 5]: Loss 0.19469
Validation: Loss 0.15762 Accuracy 1.00000
Validation: Loss 0.16596 Accuracy 1.00000
Epoch [ 6]: Loss 0.16529
Epoch [ 6]: Loss 0.16055
Epoch [ 6]: Loss 0.13431
Epoch [ 6]: Loss 0.13475
Epoch [ 6]: Loss 0.14147
Epoch [ 6]: Loss 0.13922
Epoch [ 6]: Loss 0.11118
Validation: Loss 0.11546 Accuracy 1.00000
Validation: Loss 0.12248 Accuracy 1.00000
Epoch [ 7]: Loss 0.12262
Epoch [ 7]: Loss 0.11204
Epoch [ 7]: Loss 0.10542
Epoch [ 7]: Loss 0.10456
Epoch [ 7]: Loss 0.09691
Epoch [ 7]: Loss 0.09730
Epoch [ 7]: Loss 0.08353
Validation: Loss 0.08286 Accuracy 1.00000
Validation: Loss 0.08817 Accuracy 1.00000
Epoch [ 8]: Loss 0.08157
Epoch [ 8]: Loss 0.07897
Epoch [ 8]: Loss 0.07694
Epoch [ 8]: Loss 0.07519
Epoch [ 8]: Loss 0.07263
Epoch [ 8]: Loss 0.06703
Epoch [ 8]: Loss 0.06024
Validation: Loss 0.05786 Accuracy 1.00000
Validation: Loss 0.06142 Accuracy 1.00000
Epoch [ 9]: Loss 0.06097
Epoch [ 9]: Loss 0.05504
Epoch [ 9]: Loss 0.05171
Epoch [ 9]: Loss 0.05220
Epoch [ 9]: Loss 0.05011
Epoch [ 9]: Loss 0.04811
Epoch [ 9]: Loss 0.04560
Validation: Loss 0.04275 Accuracy 1.00000
Validation: Loss 0.04519 Accuracy 1.00000
Epoch [ 10]: Loss 0.04774
Epoch [ 10]: Loss 0.04404
Epoch [ 10]: Loss 0.04161
Epoch [ 10]: Loss 0.04061
Epoch [ 10]: Loss 0.03748
Epoch [ 10]: Loss 0.03339
Epoch [ 10]: Loss 0.03148
Validation: Loss 0.03443 Accuracy 1.00000
Validation: Loss 0.03636 Accuracy 1.00000
Epoch [ 11]: Loss 0.03588
Epoch [ 11]: Loss 0.03527
Epoch [ 11]: Loss 0.03468
Epoch [ 11]: Loss 0.03222
Epoch [ 11]: Loss 0.03144
Epoch [ 11]: Loss 0.03076
Epoch [ 11]: Loss 0.02853
Validation: Loss 0.02917 Accuracy 1.00000
Validation: Loss 0.03085 Accuracy 1.00000
Epoch [ 12]: Loss 0.03344
Epoch [ 12]: Loss 0.02880
Epoch [ 12]: Loss 0.02919
Epoch [ 12]: Loss 0.02717
Epoch [ 12]: Loss 0.02616
Epoch [ 12]: Loss 0.02728
Epoch [ 12]: Loss 0.02289
Validation: Loss 0.02535 Accuracy 1.00000
Validation: Loss 0.02684 Accuracy 1.00000
Epoch [ 13]: Loss 0.02693
Epoch [ 13]: Loss 0.02511
Epoch [ 13]: Loss 0.02422
Epoch [ 13]: Loss 0.02572
Epoch [ 13]: Loss 0.02404
Epoch [ 13]: Loss 0.02340
Epoch [ 13]: Loss 0.02421
Validation: Loss 0.02241 Accuracy 1.00000
Validation: Loss 0.02376 Accuracy 1.00000
Epoch [ 14]: Loss 0.02341
Epoch [ 14]: Loss 0.02184
Epoch [ 14]: Loss 0.02189
Epoch [ 14]: Loss 0.02371
Epoch [ 14]: Loss 0.02139
Epoch [ 14]: Loss 0.02143
Epoch [ 14]: Loss 0.01792
Validation: Loss 0.02003 Accuracy 1.00000
Validation: Loss 0.02127 Accuracy 1.00000
Epoch [ 15]: Loss 0.01966
Epoch [ 15]: Loss 0.02140
Epoch [ 15]: Loss 0.02016
Epoch [ 15]: Loss 0.02024
Epoch [ 15]: Loss 0.01977
Epoch [ 15]: Loss 0.01755
Epoch [ 15]: Loss 0.02096
Validation: Loss 0.01807 Accuracy 1.00000
Validation: Loss 0.01922 Accuracy 1.00000
Epoch [ 16]: Loss 0.01791
Epoch [ 16]: Loss 0.02073
Epoch [ 16]: Loss 0.01857
Epoch [ 16]: Loss 0.01715
Epoch [ 16]: Loss 0.01956
Epoch [ 16]: Loss 0.01424
Epoch [ 16]: Loss 0.01647
Validation: Loss 0.01640 Accuracy 1.00000
Validation: Loss 0.01746 Accuracy 1.00000
Epoch [ 17]: Loss 0.01641
Epoch [ 17]: Loss 0.01715
Epoch [ 17]: Loss 0.01676
Epoch [ 17]: Loss 0.01654
Epoch [ 17]: Loss 0.01603
Epoch [ 17]: Loss 0.01584
Epoch [ 17]: Loss 0.01381
Validation: Loss 0.01499 Accuracy 1.00000
Validation: Loss 0.01598 Accuracy 1.00000
Epoch [ 18]: Loss 0.01619
Epoch [ 18]: Loss 0.01526
Epoch [ 18]: Loss 0.01433
Epoch [ 18]: Loss 0.01475
Epoch [ 18]: Loss 0.01442
Epoch [ 18]: Loss 0.01540
Epoch [ 18]: Loss 0.01324
Validation: Loss 0.01378 Accuracy 1.00000
Validation: Loss 0.01471 Accuracy 1.00000
Epoch [ 19]: Loss 0.01337
Epoch [ 19]: Loss 0.01502
Epoch [ 19]: Loss 0.01497
Epoch [ 19]: Loss 0.01372
Epoch [ 19]: Loss 0.01236
Epoch [ 19]: Loss 0.01316
Epoch [ 19]: Loss 0.01486
Validation: Loss 0.01273 Accuracy 1.00000
Validation: Loss 0.01359 Accuracy 1.00000
Epoch [ 20]: Loss 0.01351
Epoch [ 20]: Loss 0.01272
Epoch [ 20]: Loss 0.01258
Epoch [ 20]: Loss 0.01243
Epoch [ 20]: Loss 0.01280
Epoch [ 20]: Loss 0.01265
Epoch [ 20]: Loss 0.01232
Validation: Loss 0.01177 Accuracy 1.00000
Validation: Loss 0.01257 Accuracy 1.00000
Epoch [ 21]: Loss 0.01212
Epoch [ 21]: Loss 0.01248
Epoch [ 21]: Loss 0.01233
Epoch [ 21]: Loss 0.01232
Epoch [ 21]: Loss 0.01204
Epoch [ 21]: Loss 0.00974
Epoch [ 21]: Loss 0.01085
Validation: Loss 0.01083 Accuracy 1.00000
Validation: Loss 0.01157 Accuracy 1.00000
Epoch [ 22]: Loss 0.01096
Epoch [ 22]: Loss 0.01048
Epoch [ 22]: Loss 0.01099
Epoch [ 22]: Loss 0.01031
Epoch [ 22]: Loss 0.01111
Epoch [ 22]: Loss 0.01124
Epoch [ 22]: Loss 0.00949
Validation: Loss 0.00982 Accuracy 1.00000
Validation: Loss 0.01049 Accuracy 1.00000
Epoch [ 23]: Loss 0.01080
Epoch [ 23]: Loss 0.00986
Epoch [ 23]: Loss 0.00933
Epoch [ 23]: Loss 0.00972
Epoch [ 23]: Loss 0.00939
Epoch [ 23]: Loss 0.00949
Epoch [ 23]: Loss 0.00846
Validation: Loss 0.00872 Accuracy 1.00000
Validation: Loss 0.00930 Accuracy 1.00000
Epoch [ 24]: Loss 0.00874
Epoch [ 24]: Loss 0.00815
Epoch [ 24]: Loss 0.00947
Epoch [ 24]: Loss 0.00869
Epoch [ 24]: Loss 0.00847
Epoch [ 24]: Loss 0.00817
Epoch [ 24]: Loss 0.00814
Validation: Loss 0.00778 Accuracy 1.00000
Validation: Loss 0.00828 Accuracy 1.00000
Epoch [ 25]: Loss 0.00794
Epoch [ 25]: Loss 0.00774
Epoch [ 25]: Loss 0.00757
Epoch [ 25]: Loss 0.00783
Epoch [ 25]: Loss 0.00785
Epoch [ 25]: Loss 0.00762
Epoch [ 25]: Loss 0.00712
Validation: Loss 0.00712 Accuracy 1.00000
Validation: Loss 0.00755 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.62815
Epoch [ 1]: Loss 0.59273
Epoch [ 1]: Loss 0.55839
Epoch [ 1]: Loss 0.54737
Epoch [ 1]: Loss 0.51929
Epoch [ 1]: Loss 0.49971
Epoch [ 1]: Loss 0.47626
Validation: Loss 0.46640 Accuracy 1.00000
Validation: Loss 0.46883 Accuracy 1.00000
Epoch [ 2]: Loss 0.46898
Epoch [ 2]: Loss 0.46389
Epoch [ 2]: Loss 0.44597
Epoch [ 2]: Loss 0.43167
Epoch [ 2]: Loss 0.40783
Epoch [ 2]: Loss 0.37916
Epoch [ 2]: Loss 0.39570
Validation: Loss 0.36861 Accuracy 1.00000
Validation: Loss 0.37164 Accuracy 1.00000
Epoch [ 3]: Loss 0.37045
Epoch [ 3]: Loss 0.35656
Epoch [ 3]: Loss 0.35852
Epoch [ 3]: Loss 0.32657
Epoch [ 3]: Loss 0.32571
Epoch [ 3]: Loss 0.30801
Epoch [ 3]: Loss 0.28603
Validation: Loss 0.28345 Accuracy 1.00000
Validation: Loss 0.28672 Accuracy 1.00000
Epoch [ 4]: Loss 0.29379
Epoch [ 4]: Loss 0.28620
Epoch [ 4]: Loss 0.26634
Epoch [ 4]: Loss 0.25464
Epoch [ 4]: Loss 0.23619
Epoch [ 4]: Loss 0.23390
Epoch [ 4]: Loss 0.20474
Validation: Loss 0.21358 Accuracy 1.00000
Validation: Loss 0.21666 Accuracy 1.00000
Epoch [ 5]: Loss 0.22371
Epoch [ 5]: Loss 0.21711
Epoch [ 5]: Loss 0.19038
Epoch [ 5]: Loss 0.18073
Epoch [ 5]: Loss 0.19395
Epoch [ 5]: Loss 0.17035
Epoch [ 5]: Loss 0.17330
Validation: Loss 0.15838 Accuracy 1.00000
Validation: Loss 0.16096 Accuracy 1.00000
Epoch [ 6]: Loss 0.17413
Epoch [ 6]: Loss 0.14452
Epoch [ 6]: Loss 0.15736
Epoch [ 6]: Loss 0.13227
Epoch [ 6]: Loss 0.13055
Epoch [ 6]: Loss 0.12925
Epoch [ 6]: Loss 0.14008
Validation: Loss 0.11597 Accuracy 1.00000
Validation: Loss 0.11791 Accuracy 1.00000
Epoch [ 7]: Loss 0.13106
Epoch [ 7]: Loss 0.10266
Epoch [ 7]: Loss 0.10650
Epoch [ 7]: Loss 0.10380
Epoch [ 7]: Loss 0.10333
Epoch [ 7]: Loss 0.08646
Epoch [ 7]: Loss 0.09514
Validation: Loss 0.08289 Accuracy 1.00000
Validation: Loss 0.08423 Accuracy 1.00000
Epoch [ 8]: Loss 0.08447
Epoch [ 8]: Loss 0.08013
Epoch [ 8]: Loss 0.07841
Epoch [ 8]: Loss 0.07234
Epoch [ 8]: Loss 0.06504
Epoch [ 8]: Loss 0.06937
Epoch [ 8]: Loss 0.05756
Validation: Loss 0.05774 Accuracy 1.00000
Validation: Loss 0.05862 Accuracy 1.00000
Epoch [ 9]: Loss 0.06087
Epoch [ 9]: Loss 0.05691
Epoch [ 9]: Loss 0.05082
Epoch [ 9]: Loss 0.04976
Epoch [ 9]: Loss 0.05170
Epoch [ 9]: Loss 0.04546
Epoch [ 9]: Loss 0.04831
Validation: Loss 0.04294 Accuracy 1.00000
Validation: Loss 0.04362 Accuracy 1.00000
Epoch [ 10]: Loss 0.04800
Epoch [ 10]: Loss 0.04253
Epoch [ 10]: Loss 0.04092
Epoch [ 10]: Loss 0.03956
Epoch [ 10]: Loss 0.03555
Epoch [ 10]: Loss 0.03826
Epoch [ 10]: Loss 0.03033
Validation: Loss 0.03476 Accuracy 1.00000
Validation: Loss 0.03534 Accuracy 1.00000
Epoch [ 11]: Loss 0.03568
Epoch [ 11]: Loss 0.03385
Epoch [ 11]: Loss 0.03366
Epoch [ 11]: Loss 0.03423
Epoch [ 11]: Loss 0.03319
Epoch [ 11]: Loss 0.03007
Epoch [ 11]: Loss 0.02974
Validation: Loss 0.02954 Accuracy 1.00000
Validation: Loss 0.03004 Accuracy 1.00000
Epoch [ 12]: Loss 0.02803
Epoch [ 12]: Loss 0.03045
Epoch [ 12]: Loss 0.02957
Epoch [ 12]: Loss 0.02731
Epoch [ 12]: Loss 0.02797
Epoch [ 12]: Loss 0.02697
Epoch [ 12]: Loss 0.03317
Validation: Loss 0.02572 Accuracy 1.00000
Validation: Loss 0.02618 Accuracy 1.00000
Epoch [ 13]: Loss 0.02459
Epoch [ 13]: Loss 0.02706
Epoch [ 13]: Loss 0.02679
Epoch [ 13]: Loss 0.02481
Epoch [ 13]: Loss 0.02351
Epoch [ 13]: Loss 0.02417
Epoch [ 13]: Loss 0.02273
Validation: Loss 0.02274 Accuracy 1.00000
Validation: Loss 0.02315 Accuracy 1.00000
Epoch [ 14]: Loss 0.02335
Epoch [ 14]: Loss 0.02519
Epoch [ 14]: Loss 0.02322
Epoch [ 14]: Loss 0.02088
Epoch [ 14]: Loss 0.02151
Epoch [ 14]: Loss 0.02080
Epoch [ 14]: Loss 0.01727
Validation: Loss 0.02035 Accuracy 1.00000
Validation: Loss 0.02072 Accuracy 1.00000
Epoch [ 15]: Loss 0.02131
Epoch [ 15]: Loss 0.02079
Epoch [ 15]: Loss 0.02139
Epoch [ 15]: Loss 0.01874
Epoch [ 15]: Loss 0.01938
Epoch [ 15]: Loss 0.01908
Epoch [ 15]: Loss 0.01762
Validation: Loss 0.01838 Accuracy 1.00000
Validation: Loss 0.01873 Accuracy 1.00000
Epoch [ 16]: Loss 0.01960
Epoch [ 16]: Loss 0.01796
Epoch [ 16]: Loss 0.01856
Epoch [ 16]: Loss 0.01815
Epoch [ 16]: Loss 0.01708
Epoch [ 16]: Loss 0.01718
Epoch [ 16]: Loss 0.01922
Validation: Loss 0.01672 Accuracy 1.00000
Validation: Loss 0.01704 Accuracy 1.00000
Epoch [ 17]: Loss 0.01786
Epoch [ 17]: Loss 0.01778
Epoch [ 17]: Loss 0.01614
Epoch [ 17]: Loss 0.01669
Epoch [ 17]: Loss 0.01603
Epoch [ 17]: Loss 0.01552
Epoch [ 17]: Loss 0.01337
Validation: Loss 0.01528 Accuracy 1.00000
Validation: Loss 0.01558 Accuracy 1.00000
Epoch [ 18]: Loss 0.01484
Epoch [ 18]: Loss 0.01525
Epoch [ 18]: Loss 0.01551
Epoch [ 18]: Loss 0.01503
Epoch [ 18]: Loss 0.01673
Epoch [ 18]: Loss 0.01397
Epoch [ 18]: Loss 0.01337
Validation: Loss 0.01406 Accuracy 1.00000
Validation: Loss 0.01433 Accuracy 1.00000
Epoch [ 19]: Loss 0.01476
Epoch [ 19]: Loss 0.01362
Epoch [ 19]: Loss 0.01250
Epoch [ 19]: Loss 0.01496
Epoch [ 19]: Loss 0.01426
Epoch [ 19]: Loss 0.01343
Epoch [ 19]: Loss 0.01501
Validation: Loss 0.01299 Accuracy 1.00000
Validation: Loss 0.01325 Accuracy 1.00000
Epoch [ 20]: Loss 0.01352
Epoch [ 20]: Loss 0.01363
Epoch [ 20]: Loss 0.01388
Epoch [ 20]: Loss 0.01207
Epoch [ 20]: Loss 0.01207
Epoch [ 20]: Loss 0.01305
Epoch [ 20]: Loss 0.01046
Validation: Loss 0.01204 Accuracy 1.00000
Validation: Loss 0.01228 Accuracy 1.00000
Epoch [ 21]: Loss 0.01216
Epoch [ 21]: Loss 0.01301
Epoch [ 21]: Loss 0.01132
Epoch [ 21]: Loss 0.01144
Epoch [ 21]: Loss 0.01225
Epoch [ 21]: Loss 0.01153
Epoch [ 21]: Loss 0.01290
Validation: Loss 0.01117 Accuracy 1.00000
Validation: Loss 0.01138 Accuracy 1.00000
Epoch [ 22]: Loss 0.01061
Epoch [ 22]: Loss 0.01049
Epoch [ 22]: Loss 0.01159
Epoch [ 22]: Loss 0.01091
Epoch [ 22]: Loss 0.01145
Epoch [ 22]: Loss 0.01131
Epoch [ 22]: Loss 0.01189
Validation: Loss 0.01029 Accuracy 1.00000
Validation: Loss 0.01048 Accuracy 1.00000
Epoch [ 23]: Loss 0.01049
Epoch [ 23]: Loss 0.01017
Epoch [ 23]: Loss 0.01092
Epoch [ 23]: Loss 0.00979
Epoch [ 23]: Loss 0.00878
Epoch [ 23]: Loss 0.01075
Epoch [ 23]: Loss 0.01037
Validation: Loss 0.00930 Accuracy 1.00000
Validation: Loss 0.00946 Accuracy 1.00000
Epoch [ 24]: Loss 0.00996
Epoch [ 24]: Loss 0.00860
Epoch [ 24]: Loss 0.00856
Epoch [ 24]: Loss 0.00969
Epoch [ 24]: Loss 0.00961
Epoch [ 24]: Loss 0.00839
Epoch [ 24]: Loss 0.00823
Validation: Loss 0.00825 Accuracy 1.00000
Validation: Loss 0.00839 Accuracy 1.00000
Epoch [ 25]: Loss 0.00840
Epoch [ 25]: Loss 0.00890
Epoch [ 25]: Loss 0.00890
Epoch [ 25]: Loss 0.00780
Epoch [ 25]: Loss 0.00723
Epoch [ 25]: Loss 0.00761
Epoch [ 25]: Loss 0.00710
Validation: Loss 0.00741 Accuracy 1.00000
Validation: Loss 0.00754 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct; save only the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model.
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
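To actually use the loaded parameters, we can rebuild the model and run inference in test mode. A sketch with a dummy input:
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # one dummy sequence
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))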
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.109 GiB / 4.750 GiB available)
This page was generated using Literate.jl.