Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and wrap the data in an MLUtils.DataLoader. Our dataloader will yield sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
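As a quick sanity check (a sketch, not part of the original tutorial), we can pull one batch from the training loader and confirm the shapes described above. With the default sequence_length and batchsize this should give:

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128) — features × sequence length × batch size
size(y)  # (128,) — one binary label per sequence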
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch, but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
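Since we subtyped AbstractLuxContainerLayer with the fieldnames (:lstm_cell, :classifier), Lux.setup should populate the parameters and states under exactly those keys. A quick sketch to verify:

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)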
We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the recurrence loop ourselves, but we will still spell it out below for the sake of illustration.
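For reference, a minimal sketch of that built-in alternative (under the assumption that Recurrence, by default, iterates the cell over the second dimension of the input and returns only the last time step's output):

function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    # `Recurrence` loops the cell over the sequence; `vec` flattens the
    # 1 × batch_size output of the Dense head, mirroring `vec(y)` below.
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end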
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to the LSTM Cell will create the initial hidden state.
    # Note that the parameters and states are automatically populated into a field
    # called `lstm_cell`. We use `eachslice` to get the elements in the sequence
    # without copying, and `Iterators.peel` to split out the first element for
    # LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input
    # and `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to return the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
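Calling the model on a dummy batch (reusing the model, ps, and st from the sketch above) should return one probability per sequence:

x = rand(Float32, 2, 50, 4)  # 2 features × 50 time steps × 4 sequences
y, st_new = model(x, ps, st)
size(y)  # (4,)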
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
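If we did want the numerically stable variant, one option (a sketch; not used in the rest of this tutorial) is to drop the sigmoid from the classifier head and let the loss apply it internally:

# Assumes the classifier head is built as Dense(hidden_dims => out_dims),
# i.e. without the sigmoid; the loss then applies it in a stable way.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))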
Training the Model
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.61378
Epoch [ 1]: Loss 0.61847
Epoch [ 1]: Loss 0.57032
Epoch [ 1]: Loss 0.54322
Epoch [ 1]: Loss 0.51738
Epoch [ 1]: Loss 0.49400
Epoch [ 1]: Loss 0.49185
Validation: Loss 0.45480 Accuracy 1.00000
Validation: Loss 0.47453 Accuracy 1.00000
Epoch [ 2]: Loss 0.47216
Epoch [ 2]: Loss 0.43763
Epoch [ 2]: Loss 0.44298
Epoch [ 2]: Loss 0.42261
Epoch [ 2]: Loss 0.41180
Epoch [ 2]: Loss 0.42190
Epoch [ 2]: Loss 0.42016
Validation: Loss 0.35533 Accuracy 1.00000
Validation: Loss 0.37903 Accuracy 1.00000
Epoch [ 3]: Loss 0.37999
Epoch [ 3]: Loss 0.35903
Epoch [ 3]: Loss 0.34447
Epoch [ 3]: Loss 0.34428
Epoch [ 3]: Loss 0.31666
Epoch [ 3]: Loss 0.31763
Epoch [ 3]: Loss 0.30522
Validation: Loss 0.26980 Accuracy 1.00000
Validation: Loss 0.29549 Accuracy 1.00000
Epoch [ 4]: Loss 0.28968
Epoch [ 4]: Loss 0.28080
Epoch [ 4]: Loss 0.26499
Epoch [ 4]: Loss 0.25127
Epoch [ 4]: Loss 0.24872
Epoch [ 4]: Loss 0.25242
Epoch [ 4]: Loss 0.23939
Validation: Loss 0.19994 Accuracy 1.00000
Validation: Loss 0.22619 Accuracy 1.00000
Epoch [ 5]: Loss 0.22451
Epoch [ 5]: Loss 0.20902
Epoch [ 5]: Loss 0.20315
Epoch [ 5]: Loss 0.20046
Epoch [ 5]: Loss 0.19254
Epoch [ 5]: Loss 0.17871
Epoch [ 5]: Loss 0.15594
Validation: Loss 0.14601 Accuracy 1.00000
Validation: Loss 0.17044 Accuracy 1.00000
Epoch [ 6]: Loss 0.16866
Epoch [ 6]: Loss 0.16190
Epoch [ 6]: Loss 0.15189
Epoch [ 6]: Loss 0.13589
Epoch [ 6]: Loss 0.15081
Epoch [ 6]: Loss 0.13322
Epoch [ 6]: Loss 0.11285
Validation: Loss 0.10586 Accuracy 1.00000
Validation: Loss 0.12644 Accuracy 1.00000
Epoch [ 7]: Loss 0.11841
Epoch [ 7]: Loss 0.12407
Epoch [ 7]: Loss 0.11227
Epoch [ 7]: Loss 0.11104
Epoch [ 7]: Loss 0.09337
Epoch [ 7]: Loss 0.10119
Epoch [ 7]: Loss 0.08812
Validation: Loss 0.07567 Accuracy 1.00000
Validation: Loss 0.09117 Accuracy 1.00000
Epoch [ 8]: Loss 0.08371
Epoch [ 8]: Loss 0.08902
Epoch [ 8]: Loss 0.08305
Epoch [ 8]: Loss 0.07813
Epoch [ 8]: Loss 0.07189
Epoch [ 8]: Loss 0.06359
Epoch [ 8]: Loss 0.06339
Validation: Loss 0.05300 Accuracy 1.00000
Validation: Loss 0.06344 Accuracy 1.00000
Epoch [ 9]: Loss 0.06268
Epoch [ 9]: Loss 0.05926
Epoch [ 9]: Loss 0.05604
Epoch [ 9]: Loss 0.05444
Epoch [ 9]: Loss 0.05001
Epoch [ 9]: Loss 0.04881
Epoch [ 9]: Loss 0.04332
Validation: Loss 0.03955 Accuracy 1.00000
Validation: Loss 0.04699 Accuracy 1.00000
Epoch [ 10]: Loss 0.04381
Epoch [ 10]: Loss 0.04381
Epoch [ 10]: Loss 0.04311
Epoch [ 10]: Loss 0.04110
Epoch [ 10]: Loss 0.04220
Epoch [ 10]: Loss 0.03807
Epoch [ 10]: Loss 0.03988
Validation: Loss 0.03194 Accuracy 1.00000
Validation: Loss 0.03800 Accuracy 1.00000
Epoch [ 11]: Loss 0.03509
Epoch [ 11]: Loss 0.04033
Epoch [ 11]: Loss 0.03274
Epoch [ 11]: Loss 0.03695
Epoch [ 11]: Loss 0.03296
Epoch [ 11]: Loss 0.02967
Epoch [ 11]: Loss 0.03474
Validation: Loss 0.02701 Accuracy 1.00000
Validation: Loss 0.03226 Accuracy 1.00000
Epoch [ 12]: Loss 0.03138
Epoch [ 12]: Loss 0.03313
Epoch [ 12]: Loss 0.03035
Epoch [ 12]: Loss 0.02687
Epoch [ 12]: Loss 0.02963
Epoch [ 12]: Loss 0.02759
Epoch [ 12]: Loss 0.02627
Validation: Loss 0.02343 Accuracy 1.00000
Validation: Loss 0.02807 Accuracy 1.00000
Epoch [ 13]: Loss 0.02680
Epoch [ 13]: Loss 0.02584
Epoch [ 13]: Loss 0.02700
Epoch [ 13]: Loss 0.02559
Epoch [ 13]: Loss 0.02473
Epoch [ 13]: Loss 0.02564
Epoch [ 13]: Loss 0.02675
Validation: Loss 0.02068 Accuracy 1.00000
Validation: Loss 0.02486 Accuracy 1.00000
Epoch [ 14]: Loss 0.02438
Epoch [ 14]: Loss 0.02269
Epoch [ 14]: Loss 0.02441
Epoch [ 14]: Loss 0.02154
Epoch [ 14]: Loss 0.02305
Epoch [ 14]: Loss 0.02208
Epoch [ 14]: Loss 0.02473
Validation: Loss 0.01848 Accuracy 1.00000
Validation: Loss 0.02227 Accuracy 1.00000
Epoch [ 15]: Loss 0.01965
Epoch [ 15]: Loss 0.02092
Epoch [ 15]: Loss 0.02061
Epoch [ 15]: Loss 0.02220
Epoch [ 15]: Loss 0.02020
Epoch [ 15]: Loss 0.02103
Epoch [ 15]: Loss 0.02016
Validation: Loss 0.01666 Accuracy 1.00000
Validation: Loss 0.02013 Accuracy 1.00000
Epoch [ 16]: Loss 0.02080
Epoch [ 16]: Loss 0.01909
Epoch [ 16]: Loss 0.01847
Epoch [ 16]: Loss 0.01717
Epoch [ 16]: Loss 0.01919
Epoch [ 16]: Loss 0.01841
Epoch [ 16]: Loss 0.01769
Validation: Loss 0.01511 Accuracy 1.00000
Validation: Loss 0.01832 Accuracy 1.00000
Epoch [ 17]: Loss 0.01778
Epoch [ 17]: Loss 0.01754
Epoch [ 17]: Loss 0.01666
Epoch [ 17]: Loss 0.01549
Epoch [ 17]: Loss 0.01909
Epoch [ 17]: Loss 0.01541
Epoch [ 17]: Loss 0.02053
Validation: Loss 0.01378 Accuracy 1.00000
Validation: Loss 0.01676 Accuracy 1.00000
Epoch [ 18]: Loss 0.01688
Epoch [ 18]: Loss 0.01558
Epoch [ 18]: Loss 0.01599
Epoch [ 18]: Loss 0.01535
Epoch [ 18]: Loss 0.01605
Epoch [ 18]: Loss 0.01483
Epoch [ 18]: Loss 0.01362
Validation: Loss 0.01260 Accuracy 1.00000
Validation: Loss 0.01536 Accuracy 1.00000
Epoch [ 19]: Loss 0.01547
Epoch [ 19]: Loss 0.01457
Epoch [ 19]: Loss 0.01468
Epoch [ 19]: Loss 0.01463
Epoch [ 19]: Loss 0.01363
Epoch [ 19]: Loss 0.01386
Epoch [ 19]: Loss 0.01249
Validation: Loss 0.01154 Accuracy 1.00000
Validation: Loss 0.01410 Accuracy 1.00000
Epoch [ 20]: Loss 0.01552
Epoch [ 20]: Loss 0.01312
Epoch [ 20]: Loss 0.01267
Epoch [ 20]: Loss 0.01277
Epoch [ 20]: Loss 0.01211
Epoch [ 20]: Loss 0.01347
Epoch [ 20]: Loss 0.01139
Validation: Loss 0.01054 Accuracy 1.00000
Validation: Loss 0.01288 Accuracy 1.00000
Epoch [ 21]: Loss 0.01309
Epoch [ 21]: Loss 0.01268
Epoch [ 21]: Loss 0.01225
Epoch [ 21]: Loss 0.01212
Epoch [ 21]: Loss 0.01098
Epoch [ 21]: Loss 0.01103
Epoch [ 21]: Loss 0.01161
Validation: Loss 0.00949 Accuracy 1.00000
Validation: Loss 0.01157 Accuracy 1.00000
Epoch [ 22]: Loss 0.01222
Epoch [ 22]: Loss 0.01049
Epoch [ 22]: Loss 0.00989
Epoch [ 22]: Loss 0.01097
Epoch [ 22]: Loss 0.01044
Epoch [ 22]: Loss 0.00987
Epoch [ 22]: Loss 0.01182
Validation: Loss 0.00840 Accuracy 1.00000
Validation: Loss 0.01018 Accuracy 1.00000
Epoch [ 23]: Loss 0.00913
Epoch [ 23]: Loss 0.01062
Epoch [ 23]: Loss 0.00989
Epoch [ 23]: Loss 0.00920
Epoch [ 23]: Loss 0.00921
Epoch [ 23]: Loss 0.00865
Epoch [ 23]: Loss 0.00879
Validation: Loss 0.00751 Accuracy 1.00000
Validation: Loss 0.00905 Accuracy 1.00000
Epoch [ 24]: Loss 0.00903
Epoch [ 24]: Loss 0.00796
Epoch [ 24]: Loss 0.00847
Epoch [ 24]: Loss 0.00835
Epoch [ 24]: Loss 0.00855
Epoch [ 24]: Loss 0.00836
Epoch [ 24]: Loss 0.00855
Validation: Loss 0.00687 Accuracy 1.00000
Validation: Loss 0.00826 Accuracy 1.00000
Epoch [ 25]: Loss 0.00831
Epoch [ 25]: Loss 0.00850
Epoch [ 25]: Loss 0.00756
Epoch [ 25]: Loss 0.00755
Epoch [ 25]: Loss 0.00775
Epoch [ 25]: Loss 0.00712
Epoch [ 25]: Loss 0.00739
Validation: Loss 0.00639 Accuracy 1.00000
Validation: Loss 0.00767 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.63161
Epoch [ 1]: Loss 0.59157
Epoch [ 1]: Loss 0.56177
Epoch [ 1]: Loss 0.53878
Epoch [ 1]: Loss 0.51506
Epoch [ 1]: Loss 0.50421
Epoch [ 1]: Loss 0.48370
Validation: Loss 0.46435 Accuracy 1.00000
Validation: Loss 0.47612 Accuracy 1.00000
Epoch [ 2]: Loss 0.47300
Epoch [ 2]: Loss 0.44561
Epoch [ 2]: Loss 0.44431
Epoch [ 2]: Loss 0.41993
Epoch [ 2]: Loss 0.41492
Epoch [ 2]: Loss 0.39915
Epoch [ 2]: Loss 0.38991
Validation: Loss 0.36601 Accuracy 1.00000
Validation: Loss 0.37949 Accuracy 1.00000
Epoch [ 3]: Loss 0.37226
Epoch [ 3]: Loss 0.35362
Epoch [ 3]: Loss 0.35615
Epoch [ 3]: Loss 0.33198
Epoch [ 3]: Loss 0.30412
Epoch [ 3]: Loss 0.31466
Epoch [ 3]: Loss 0.30811
Validation: Loss 0.28010 Accuracy 1.00000
Validation: Loss 0.29456 Accuracy 1.00000
Epoch [ 4]: Loss 0.28383
Epoch [ 4]: Loss 0.27321
Epoch [ 4]: Loss 0.26839
Epoch [ 4]: Loss 0.24670
Epoch [ 4]: Loss 0.24752
Epoch [ 4]: Loss 0.22688
Epoch [ 4]: Loss 0.25679
Validation: Loss 0.20971 Accuracy 1.00000
Validation: Loss 0.22409 Accuracy 1.00000
Epoch [ 5]: Loss 0.21850
Epoch [ 5]: Loss 0.20251
Epoch [ 5]: Loss 0.19916
Epoch [ 5]: Loss 0.19424
Epoch [ 5]: Loss 0.17411
Epoch [ 5]: Loss 0.16995
Epoch [ 5]: Loss 0.18597
Validation: Loss 0.15423 Accuracy 1.00000
Validation: Loss 0.16730 Accuracy 1.00000
Epoch [ 6]: Loss 0.15835
Epoch [ 6]: Loss 0.14846
Epoch [ 6]: Loss 0.14356
Epoch [ 6]: Loss 0.14558
Epoch [ 6]: Loss 0.12427
Epoch [ 6]: Loss 0.13242
Epoch [ 6]: Loss 0.13393
Validation: Loss 0.11213 Accuracy 1.00000
Validation: Loss 0.12277 Accuracy 1.00000
Epoch [ 7]: Loss 0.11278
Epoch [ 7]: Loss 0.11189
Epoch [ 7]: Loss 0.10748
Epoch [ 7]: Loss 0.10338
Epoch [ 7]: Loss 0.09165
Epoch [ 7]: Loss 0.09369
Epoch [ 7]: Loss 0.08419
Validation: Loss 0.07983 Accuracy 1.00000
Validation: Loss 0.08754 Accuracy 1.00000
Epoch [ 8]: Loss 0.08700
Epoch [ 8]: Loss 0.07787
Epoch [ 8]: Loss 0.07338
Epoch [ 8]: Loss 0.07191
Epoch [ 8]: Loss 0.06659
Epoch [ 8]: Loss 0.05968
Epoch [ 8]: Loss 0.06656
Validation: Loss 0.05568 Accuracy 1.00000
Validation: Loss 0.06079 Accuracy 1.00000
Epoch [ 9]: Loss 0.05713
Epoch [ 9]: Loss 0.05704
Epoch [ 9]: Loss 0.04976
Epoch [ 9]: Loss 0.04687
Epoch [ 9]: Loss 0.04846
Epoch [ 9]: Loss 0.04931
Epoch [ 9]: Loss 0.04599
Validation: Loss 0.04171 Accuracy 1.00000
Validation: Loss 0.04545 Accuracy 1.00000
Epoch [ 10]: Loss 0.04481
Epoch [ 10]: Loss 0.04029
Epoch [ 10]: Loss 0.03901
Epoch [ 10]: Loss 0.03979
Epoch [ 10]: Loss 0.03917
Epoch [ 10]: Loss 0.03718
Epoch [ 10]: Loss 0.03122
Validation: Loss 0.03387 Accuracy 1.00000
Validation: Loss 0.03699 Accuracy 1.00000
Epoch [ 11]: Loss 0.03528
Epoch [ 11]: Loss 0.03387
Epoch [ 11]: Loss 0.03192
Epoch [ 11]: Loss 0.03362
Epoch [ 11]: Loss 0.03126
Epoch [ 11]: Loss 0.03112
Epoch [ 11]: Loss 0.03346
Validation: Loss 0.02880 Accuracy 1.00000
Validation: Loss 0.03154 Accuracy 1.00000
Epoch [ 12]: Loss 0.03106
Epoch [ 12]: Loss 0.02795
Epoch [ 12]: Loss 0.03002
Epoch [ 12]: Loss 0.02898
Epoch [ 12]: Loss 0.02863
Epoch [ 12]: Loss 0.02439
Epoch [ 12]: Loss 0.02284
Validation: Loss 0.02506 Accuracy 1.00000
Validation: Loss 0.02751 Accuracy 1.00000
Epoch [ 13]: Loss 0.02571
Epoch [ 13]: Loss 0.02592
Epoch [ 13]: Loss 0.02504
Epoch [ 13]: Loss 0.02589
Epoch [ 13]: Loss 0.02314
Epoch [ 13]: Loss 0.02314
Epoch [ 13]: Loss 0.02421
Validation: Loss 0.02219 Accuracy 1.00000
Validation: Loss 0.02441 Accuracy 1.00000
Epoch [ 14]: Loss 0.02222
Epoch [ 14]: Loss 0.02462
Epoch [ 14]: Loss 0.02128
Epoch [ 14]: Loss 0.02325
Epoch [ 14]: Loss 0.02062
Epoch [ 14]: Loss 0.02083
Epoch [ 14]: Loss 0.02015
Validation: Loss 0.01986 Accuracy 1.00000
Validation: Loss 0.02190 Accuracy 1.00000
Epoch [ 15]: Loss 0.02059
Epoch [ 15]: Loss 0.01920
Epoch [ 15]: Loss 0.01965
Epoch [ 15]: Loss 0.01989
Epoch [ 15]: Loss 0.01837
Epoch [ 15]: Loss 0.02177
Epoch [ 15]: Loss 0.01750
Validation: Loss 0.01794 Accuracy 1.00000
Validation: Loss 0.01983 Accuracy 1.00000
Epoch [ 16]: Loss 0.01861
Epoch [ 16]: Loss 0.01836
Epoch [ 16]: Loss 0.01845
Epoch [ 16]: Loss 0.01721
Epoch [ 16]: Loss 0.01631
Epoch [ 16]: Loss 0.01805
Epoch [ 16]: Loss 0.02105
Validation: Loss 0.01631 Accuracy 1.00000
Validation: Loss 0.01806 Accuracy 1.00000
Epoch [ 17]: Loss 0.01821
Epoch [ 17]: Loss 0.01545
Epoch [ 17]: Loss 0.01495
Epoch [ 17]: Loss 0.01644
Epoch [ 17]: Loss 0.01715
Epoch [ 17]: Loss 0.01604
Epoch [ 17]: Loss 0.01629
Validation: Loss 0.01490 Accuracy 1.00000
Validation: Loss 0.01654 Accuracy 1.00000
Epoch [ 18]: Loss 0.01624
Epoch [ 18]: Loss 0.01426
Epoch [ 18]: Loss 0.01628
Epoch [ 18]: Loss 0.01508
Epoch [ 18]: Loss 0.01575
Epoch [ 18]: Loss 0.01302
Epoch [ 18]: Loss 0.01273
Validation: Loss 0.01370 Accuracy 1.00000
Validation: Loss 0.01522 Accuracy 1.00000
Epoch [ 19]: Loss 0.01591
Epoch [ 19]: Loss 0.01340
Epoch [ 19]: Loss 0.01338
Epoch [ 19]: Loss 0.01309
Epoch [ 19]: Loss 0.01254
Epoch [ 19]: Loss 0.01386
Epoch [ 19]: Loss 0.01697
Validation: Loss 0.01266 Accuracy 1.00000
Validation: Loss 0.01408 Accuracy 1.00000
Epoch [ 20]: Loss 0.01291
Epoch [ 20]: Loss 0.01290
Epoch [ 20]: Loss 0.01288
Epoch [ 20]: Loss 0.01337
Epoch [ 20]: Loss 0.01210
Epoch [ 20]: Loss 0.01234
Epoch [ 20]: Loss 0.01379
Validation: Loss 0.01172 Accuracy 1.00000
Validation: Loss 0.01304 Accuracy 1.00000
Epoch [ 21]: Loss 0.01256
Epoch [ 21]: Loss 0.01172
Epoch [ 21]: Loss 0.01277
Epoch [ 21]: Loss 0.01133
Epoch [ 21]: Loss 0.01086
Epoch [ 21]: Loss 0.01189
Epoch [ 21]: Loss 0.01185
Validation: Loss 0.01085 Accuracy 1.00000
Validation: Loss 0.01207 Accuracy 1.00000
Epoch [ 22]: Loss 0.01152
Epoch [ 22]: Loss 0.01116
Epoch [ 22]: Loss 0.01092
Epoch [ 22]: Loss 0.01111
Epoch [ 22]: Loss 0.01054
Epoch [ 22]: Loss 0.01052
Epoch [ 22]: Loss 0.01050
Validation: Loss 0.00996 Accuracy 1.00000
Validation: Loss 0.01105 Accuracy 1.00000
Epoch [ 23]: Loss 0.00983
Epoch [ 23]: Loss 0.01054
Epoch [ 23]: Loss 0.00956
Epoch [ 23]: Loss 0.01042
Epoch [ 23]: Loss 0.00950
Epoch [ 23]: Loss 0.01035
Epoch [ 23]: Loss 0.00835
Validation: Loss 0.00897 Accuracy 1.00000
Validation: Loss 0.00992 Accuracy 1.00000
Epoch [ 24]: Loss 0.00933
Epoch [ 24]: Loss 0.00958
Epoch [ 24]: Loss 0.00883
Epoch [ 24]: Loss 0.00861
Epoch [ 24]: Loss 0.00862
Epoch [ 24]: Loss 0.00892
Epoch [ 24]: Loss 0.00719
Validation: Loss 0.00797 Accuracy 1.00000
Validation: Loss 0.00879 Accuracy 1.00000
Epoch [ 25]: Loss 0.00805
Epoch [ 25]: Loss 0.00858
Epoch [ 25]: Loss 0.00722
Epoch [ 25]: Loss 0.00761
Epoch [ 25]: Loss 0.00835
Epoch [ 25]: Loss 0.00802
Epoch [ 25]: Loss 0.00731
Validation: Loss 0.00722 Accuracy 1.00000
Validation: Loss 0.00794 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
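To run inference with the loaded parameters, we can rebuild the model and call it in test mode (a sketch assuming everything stays on the CPU):

model = SpiralClassifier(2, 8, 1)
st_test = Lux.testmode(st_trained)
x = rand(Float32, 2, 50, 4)  # dummy input; real spirals would come from the dataloader
ŷ, _ = model(x, ps_trained, st_test)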
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.391 GiB / 4.750 GiB available)
This page was generated using Literate.jl.