Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
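To sanity-check the layout described above, we can peek at the first training batch. A quick sketch (this inspection code is illustrative and not part of the training pipeline):

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128) — features × sequence_length × batch_size
size(y)  # (128,) — one binary label per sequence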
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
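As a quick sketch of what the container layer buys us: the parameters and states returned by Lux.setup are NamedTuples keyed by the fieldnames we declared above.

ps, st = Lux.setup(Xoshiro(0), SpiralClassifier(2, 8, 1))
keys(ps)  # (:lstm_cell, :classifier) — populated automatically
keys(st)  # (:lstm_cell, :classifier)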
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model by hand (see the sketch below), but let's still write it out ourselves for the sake of this tutorial.
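For reference, a sketch of that built-in alternative (assuming the default Recurrence behavior of returning only the final time step's output):

# Not used below — just the `Recurrence`-based equivalent for comparison.
# `vec` collapses the 1 × batch_size output to a vector, matching our model.
model_alt = Chain(
    Recurrence(LSTMCell(2 => 8)),
    Dense(8 => 1, sigmoid),
    vec)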
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
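Before training, it is worth confirming that a forward pass runs and returns one prediction per sample. A minimal sketch with dummy input (the sizes here are illustrative):

model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
x_dummy = rand(Float32, 2, 50, 4)  # 2 features × 50 time steps × 4 samples
ŷ, st_ = model(x_dummy, ps, st)
size(ŷ)  # (4,) — one probability per sample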
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
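If numerical stability becomes a concern, a sketch of the logit-based alternative (assuming the logits keyword of BinaryCrossEntropyLoss): drop sigmoid from the classifier's Dense head so it outputs raw logits, and construct the loss as

# Sketch only — the classifier head must then omit `sigmoid`.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))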
Training the Model
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.62225
Epoch [ 1]: Loss 0.58452
Epoch [ 1]: Loss 0.57647
Epoch [ 1]: Loss 0.53713
Epoch [ 1]: Loss 0.51691
Epoch [ 1]: Loss 0.49888
Epoch [ 1]: Loss 0.48847
Validation: Loss 0.47424 Accuracy 1.00000
Validation: Loss 0.46706 Accuracy 1.00000
Epoch [ 2]: Loss 0.47281
Epoch [ 2]: Loss 0.45886
Epoch [ 2]: Loss 0.44245
Epoch [ 2]: Loss 0.42243
Epoch [ 2]: Loss 0.40542
Epoch [ 2]: Loss 0.39367
Epoch [ 2]: Loss 0.38477
Validation: Loss 0.37763 Accuracy 1.00000
Validation: Loss 0.36948 Accuracy 1.00000
Epoch [ 3]: Loss 0.36578
Epoch [ 3]: Loss 0.36006
Epoch [ 3]: Loss 0.35264
Epoch [ 3]: Loss 0.33993
Epoch [ 3]: Loss 0.30781
Epoch [ 3]: Loss 0.30599
Epoch [ 3]: Loss 0.30994
Validation: Loss 0.29307 Accuracy 1.00000
Validation: Loss 0.28429 Accuracy 1.00000
Epoch [ 4]: Loss 0.29298
Epoch [ 4]: Loss 0.27158
Epoch [ 4]: Loss 0.26401
Epoch [ 4]: Loss 0.25578
Epoch [ 4]: Loss 0.23678
Epoch [ 4]: Loss 0.23406
Epoch [ 4]: Loss 0.23145
Validation: Loss 0.22292 Accuracy 1.00000
Validation: Loss 0.21429 Accuracy 1.00000
Epoch [ 5]: Loss 0.22214
Epoch [ 5]: Loss 0.20231
Epoch [ 5]: Loss 0.18991
Epoch [ 5]: Loss 0.18320
Epoch [ 5]: Loss 0.19426
Epoch [ 5]: Loss 0.17266
Epoch [ 5]: Loss 0.17777
Validation: Loss 0.16650 Accuracy 1.00000
Validation: Loss 0.15865 Accuracy 1.00000
Epoch [ 6]: Loss 0.15678
Epoch [ 6]: Loss 0.15592
Epoch [ 6]: Loss 0.14733
Epoch [ 6]: Loss 0.13975
Epoch [ 6]: Loss 0.13619
Epoch [ 6]: Loss 0.12549
Epoch [ 6]: Loss 0.12288
Validation: Loss 0.12240 Accuracy 1.00000
Validation: Loss 0.11601 Accuracy 1.00000
Epoch [ 7]: Loss 0.11055
Epoch [ 7]: Loss 0.10557
Epoch [ 7]: Loss 0.10822
Epoch [ 7]: Loss 0.10875
Epoch [ 7]: Loss 0.09867
Epoch [ 7]: Loss 0.09373
Epoch [ 7]: Loss 0.08833
Validation: Loss 0.08776 Accuracy 1.00000
Validation: Loss 0.08307 Accuracy 1.00000
Epoch [ 8]: Loss 0.08656
Epoch [ 8]: Loss 0.07314
Epoch [ 8]: Loss 0.07540
Epoch [ 8]: Loss 0.06530
Epoch [ 8]: Loss 0.07273
Epoch [ 8]: Loss 0.06839
Epoch [ 8]: Loss 0.06574
Validation: Loss 0.06095 Accuracy 1.00000
Validation: Loss 0.05782 Accuracy 1.00000
Epoch [ 9]: Loss 0.06051
Epoch [ 9]: Loss 0.05530
Epoch [ 9]: Loss 0.05537
Epoch [ 9]: Loss 0.04864
Epoch [ 9]: Loss 0.04700
Epoch [ 9]: Loss 0.04662
Epoch [ 9]: Loss 0.04060
Validation: Loss 0.04495 Accuracy 1.00000
Validation: Loss 0.04271 Accuracy 1.00000
Epoch [ 10]: Loss 0.04399
Epoch [ 10]: Loss 0.04247
Epoch [ 10]: Loss 0.03719
Epoch [ 10]: Loss 0.03917
Epoch [ 10]: Loss 0.03678
Epoch [ 10]: Loss 0.03971
Epoch [ 10]: Loss 0.03252
Validation: Loss 0.03637 Accuracy 1.00000
Validation: Loss 0.03451 Accuracy 1.00000
Epoch [ 11]: Loss 0.03297
Epoch [ 11]: Loss 0.03494
Epoch [ 11]: Loss 0.03168
Epoch [ 11]: Loss 0.03138
Epoch [ 11]: Loss 0.03222
Epoch [ 11]: Loss 0.03131
Epoch [ 11]: Loss 0.03733
Validation: Loss 0.03089 Accuracy 1.00000
Validation: Loss 0.02926 Accuracy 1.00000
Epoch [ 12]: Loss 0.03165
Epoch [ 12]: Loss 0.03025
Epoch [ 12]: Loss 0.02871
Epoch [ 12]: Loss 0.02714
Epoch [ 12]: Loss 0.02479
Epoch [ 12]: Loss 0.02566
Epoch [ 12]: Loss 0.02719
Validation: Loss 0.02686 Accuracy 1.00000
Validation: Loss 0.02540 Accuracy 1.00000
Epoch [ 13]: Loss 0.02444
Epoch [ 13]: Loss 0.02725
Epoch [ 13]: Loss 0.02359
Epoch [ 13]: Loss 0.02236
Epoch [ 13]: Loss 0.02687
Epoch [ 13]: Loss 0.02215
Epoch [ 13]: Loss 0.02476
Validation: Loss 0.02377 Accuracy 1.00000
Validation: Loss 0.02245 Accuracy 1.00000
Epoch [ 14]: Loss 0.02161
Epoch [ 14]: Loss 0.02260
Epoch [ 14]: Loss 0.02225
Epoch [ 14]: Loss 0.02342
Epoch [ 14]: Loss 0.02007
Epoch [ 14]: Loss 0.02045
Epoch [ 14]: Loss 0.02144
Validation: Loss 0.02128 Accuracy 1.00000
Validation: Loss 0.02007 Accuracy 1.00000
Epoch [ 15]: Loss 0.02002
Epoch [ 15]: Loss 0.01937
Epoch [ 15]: Loss 0.01978
Epoch [ 15]: Loss 0.02071
Epoch [ 15]: Loss 0.01853
Epoch [ 15]: Loss 0.01894
Epoch [ 15]: Loss 0.01798
Validation: Loss 0.01922 Accuracy 1.00000
Validation: Loss 0.01811 Accuracy 1.00000
Epoch [ 16]: Loss 0.01924
Epoch [ 16]: Loss 0.01846
Epoch [ 16]: Loss 0.01725
Epoch [ 16]: Loss 0.01571
Epoch [ 16]: Loss 0.01788
Epoch [ 16]: Loss 0.01764
Epoch [ 16]: Loss 0.01633
Validation: Loss 0.01749 Accuracy 1.00000
Validation: Loss 0.01645 Accuracy 1.00000
Epoch [ 17]: Loss 0.01749
Epoch [ 17]: Loss 0.01558
Epoch [ 17]: Loss 0.01493
Epoch [ 17]: Loss 0.01669
Epoch [ 17]: Loss 0.01572
Epoch [ 17]: Loss 0.01584
Epoch [ 17]: Loss 0.01665
Validation: Loss 0.01601 Accuracy 1.00000
Validation: Loss 0.01504 Accuracy 1.00000
Epoch [ 18]: Loss 0.01531
Epoch [ 18]: Loss 0.01526
Epoch [ 18]: Loss 0.01408
Epoch [ 18]: Loss 0.01418
Epoch [ 18]: Loss 0.01513
Epoch [ 18]: Loss 0.01441
Epoch [ 18]: Loss 0.01478
Validation: Loss 0.01474 Accuracy 1.00000
Validation: Loss 0.01384 Accuracy 1.00000
Epoch [ 19]: Loss 0.01368
Epoch [ 19]: Loss 0.01337
Epoch [ 19]: Loss 0.01390
Epoch [ 19]: Loss 0.01386
Epoch [ 19]: Loss 0.01306
Epoch [ 19]: Loss 0.01317
Epoch [ 19]: Loss 0.01522
Validation: Loss 0.01364 Accuracy 1.00000
Validation: Loss 0.01279 Accuracy 1.00000
Epoch [ 20]: Loss 0.01352
Epoch [ 20]: Loss 0.01226
Epoch [ 20]: Loss 0.01360
Epoch [ 20]: Loss 0.01160
Epoch [ 20]: Loss 0.01259
Epoch [ 20]: Loss 0.01221
Epoch [ 20]: Loss 0.01126
Validation: Loss 0.01266 Accuracy 1.00000
Validation: Loss 0.01187 Accuracy 1.00000
Epoch [ 21]: Loss 0.01172
Epoch [ 21]: Loss 0.01253
Epoch [ 21]: Loss 0.01160
Epoch [ 21]: Loss 0.01160
Epoch [ 21]: Loss 0.01227
Epoch [ 21]: Loss 0.01068
Epoch [ 21]: Loss 0.01047
Validation: Loss 0.01178 Accuracy 1.00000
Validation: Loss 0.01104 Accuracy 1.00000
Epoch [ 22]: Loss 0.00959
Epoch [ 22]: Loss 0.01146
Epoch [ 22]: Loss 0.01187
Epoch [ 22]: Loss 0.01025
Epoch [ 22]: Loss 0.01066
Epoch [ 22]: Loss 0.01139
Epoch [ 22]: Loss 0.01058
Validation: Loss 0.01093 Accuracy 1.00000
Validation: Loss 0.01026 Accuracy 1.00000
Epoch [ 23]: Loss 0.01145
Epoch [ 23]: Loss 0.01032
Epoch [ 23]: Loss 0.01010
Epoch [ 23]: Loss 0.00949
Epoch [ 23]: Loss 0.01024
Epoch [ 23]: Loss 0.00916
Epoch [ 23]: Loss 0.00858
Validation: Loss 0.01003 Accuracy 1.00000
Validation: Loss 0.00942 Accuracy 1.00000
Epoch [ 24]: Loss 0.00967
Epoch [ 24]: Loss 0.00945
Epoch [ 24]: Loss 0.01021
Epoch [ 24]: Loss 0.00785
Epoch [ 24]: Loss 0.00870
Epoch [ 24]: Loss 0.00925
Epoch [ 24]: Loss 0.00876
Validation: Loss 0.00901 Accuracy 1.00000
Validation: Loss 0.00848 Accuracy 1.00000
Epoch [ 25]: Loss 0.00902
Epoch [ 25]: Loss 0.00929
Epoch [ 25]: Loss 0.00779
Epoch [ 25]: Loss 0.00781
Epoch [ 25]: Loss 0.00791
Epoch [ 25]: Loss 0.00782
Epoch [ 25]: Loss 0.00667
Validation: Loss 0.00801 Accuracy 1.00000
Validation: Loss 0.00756 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.62686
Epoch [ 1]: Loss 0.59189
Epoch [ 1]: Loss 0.55786
Epoch [ 1]: Loss 0.53801
Epoch [ 1]: Loss 0.51328
Epoch [ 1]: Loss 0.49906
Epoch [ 1]: Loss 0.49657
Validation: Loss 0.47909 Accuracy 1.00000
Validation: Loss 0.46856 Accuracy 1.00000
Epoch [ 2]: Loss 0.46447
Epoch [ 2]: Loss 0.45461
Epoch [ 2]: Loss 0.44614
Epoch [ 2]: Loss 0.42496
Epoch [ 2]: Loss 0.40764
Epoch [ 2]: Loss 0.39277
Epoch [ 2]: Loss 0.37983
Validation: Loss 0.38367 Accuracy 1.00000
Validation: Loss 0.37133 Accuracy 1.00000
Epoch [ 3]: Loss 0.37505
Epoch [ 3]: Loss 0.35015
Epoch [ 3]: Loss 0.34280
Epoch [ 3]: Loss 0.33740
Epoch [ 3]: Loss 0.31918
Epoch [ 3]: Loss 0.30545
Epoch [ 3]: Loss 0.29285
Validation: Loss 0.29977 Accuracy 1.00000
Validation: Loss 0.28606 Accuracy 1.00000
Epoch [ 4]: Loss 0.29291
Epoch [ 4]: Loss 0.27805
Epoch [ 4]: Loss 0.25928
Epoch [ 4]: Loss 0.25599
Epoch [ 4]: Loss 0.23213
Epoch [ 4]: Loss 0.23241
Epoch [ 4]: Loss 0.21639
Validation: Loss 0.22957 Accuracy 1.00000
Validation: Loss 0.21572 Accuracy 1.00000
Epoch [ 5]: Loss 0.21074
Epoch [ 5]: Loss 0.20613
Epoch [ 5]: Loss 0.19065
Epoch [ 5]: Loss 0.18171
Epoch [ 5]: Loss 0.18299
Epoch [ 5]: Loss 0.18283
Epoch [ 5]: Loss 0.17709
Validation: Loss 0.17279 Accuracy 1.00000
Validation: Loss 0.15992 Accuracy 1.00000
Epoch [ 6]: Loss 0.15913
Epoch [ 6]: Loss 0.13994
Epoch [ 6]: Loss 0.14437
Epoch [ 6]: Loss 0.14647
Epoch [ 6]: Loss 0.13224
Epoch [ 6]: Loss 0.12985
Epoch [ 6]: Loss 0.12269
Validation: Loss 0.12753 Accuracy 1.00000
Validation: Loss 0.11695 Accuracy 1.00000
Epoch [ 7]: Loss 0.11127
Epoch [ 7]: Loss 0.10331
Epoch [ 7]: Loss 0.10289
Epoch [ 7]: Loss 0.10098
Epoch [ 7]: Loss 0.09919
Epoch [ 7]: Loss 0.10000
Epoch [ 7]: Loss 0.08501
Validation: Loss 0.09140 Accuracy 1.00000
Validation: Loss 0.08359 Accuracy 1.00000
Epoch [ 8]: Loss 0.08467
Epoch [ 8]: Loss 0.07826
Epoch [ 8]: Loss 0.07501
Epoch [ 8]: Loss 0.06869
Epoch [ 8]: Loss 0.06688
Epoch [ 8]: Loss 0.06293
Epoch [ 8]: Loss 0.06530
Validation: Loss 0.06322 Accuracy 1.00000
Validation: Loss 0.05807 Accuracy 1.00000
Epoch [ 9]: Loss 0.05629
Epoch [ 9]: Loss 0.05545
Epoch [ 9]: Loss 0.05066
Epoch [ 9]: Loss 0.05012
Epoch [ 9]: Loss 0.04763
Epoch [ 9]: Loss 0.04766
Epoch [ 9]: Loss 0.04326
Validation: Loss 0.04666 Accuracy 1.00000
Validation: Loss 0.04308 Accuracy 1.00000
Epoch [ 10]: Loss 0.04389
Epoch [ 10]: Loss 0.04210
Epoch [ 10]: Loss 0.03645
Epoch [ 10]: Loss 0.03910
Epoch [ 10]: Loss 0.03769
Epoch [ 10]: Loss 0.03806
Epoch [ 10]: Loss 0.03123
Validation: Loss 0.03779 Accuracy 1.00000
Validation: Loss 0.03489 Accuracy 1.00000
Epoch [ 11]: Loss 0.03265
Epoch [ 11]: Loss 0.03494
Epoch [ 11]: Loss 0.03272
Epoch [ 11]: Loss 0.03270
Epoch [ 11]: Loss 0.03040
Epoch [ 11]: Loss 0.03023
Epoch [ 11]: Loss 0.03437
Validation: Loss 0.03216 Accuracy 1.00000
Validation: Loss 0.02964 Accuracy 1.00000
Epoch [ 12]: Loss 0.02917
Epoch [ 12]: Loss 0.02831
Epoch [ 12]: Loss 0.02678
Epoch [ 12]: Loss 0.02790
Epoch [ 12]: Loss 0.02922
Epoch [ 12]: Loss 0.02570
Epoch [ 12]: Loss 0.02542
Validation: Loss 0.02803 Accuracy 1.00000
Validation: Loss 0.02579 Accuracy 1.00000
Epoch [ 13]: Loss 0.02648
Epoch [ 13]: Loss 0.02515
Epoch [ 13]: Loss 0.02122
Epoch [ 13]: Loss 0.02382
Epoch [ 13]: Loss 0.02528
Epoch [ 13]: Loss 0.02322
Epoch [ 13]: Loss 0.02647
Validation: Loss 0.02483 Accuracy 1.00000
Validation: Loss 0.02281 Accuracy 1.00000
Epoch [ 14]: Loss 0.02164
Epoch [ 14]: Loss 0.02345
Epoch [ 14]: Loss 0.02228
Epoch [ 14]: Loss 0.02193
Epoch [ 14]: Loss 0.01929
Epoch [ 14]: Loss 0.02163
Epoch [ 14]: Loss 0.01850
Validation: Loss 0.02224 Accuracy 1.00000
Validation: Loss 0.02039 Accuracy 1.00000
Epoch [ 15]: Loss 0.01931
Epoch [ 15]: Loss 0.02210
Epoch [ 15]: Loss 0.02034
Epoch [ 15]: Loss 0.02005
Epoch [ 15]: Loss 0.01800
Epoch [ 15]: Loss 0.01757
Epoch [ 15]: Loss 0.01457
Validation: Loss 0.02010 Accuracy 1.00000
Validation: Loss 0.01840 Accuracy 1.00000
Epoch [ 16]: Loss 0.01741
Epoch [ 16]: Loss 0.01850
Epoch [ 16]: Loss 0.01726
Epoch [ 16]: Loss 0.01816
Epoch [ 16]: Loss 0.01891
Epoch [ 16]: Loss 0.01527
Epoch [ 16]: Loss 0.01571
Validation: Loss 0.01833 Accuracy 1.00000
Validation: Loss 0.01674 Accuracy 1.00000
Epoch [ 17]: Loss 0.01488
Epoch [ 17]: Loss 0.01748
Epoch [ 17]: Loss 0.01788
Epoch [ 17]: Loss 0.01582
Epoch [ 17]: Loss 0.01562
Epoch [ 17]: Loss 0.01453
Epoch [ 17]: Loss 0.01434
Validation: Loss 0.01682 Accuracy 1.00000
Validation: Loss 0.01533 Accuracy 1.00000
Epoch [ 18]: Loss 0.01549
Epoch [ 18]: Loss 0.01545
Epoch [ 18]: Loss 0.01429
Epoch [ 18]: Loss 0.01438
Epoch [ 18]: Loss 0.01463
Epoch [ 18]: Loss 0.01420
Epoch [ 18]: Loss 0.01265
Validation: Loss 0.01552 Accuracy 1.00000
Validation: Loss 0.01413 Accuracy 1.00000
Epoch [ 19]: Loss 0.01415
Epoch [ 19]: Loss 0.01343
Epoch [ 19]: Loss 0.01410
Epoch [ 19]: Loss 0.01382
Epoch [ 19]: Loss 0.01262
Epoch [ 19]: Loss 0.01340
Epoch [ 19]: Loss 0.01219
Validation: Loss 0.01439 Accuracy 1.00000
Validation: Loss 0.01309 Accuracy 1.00000
Epoch [ 20]: Loss 0.01310
Epoch [ 20]: Loss 0.01309
Epoch [ 20]: Loss 0.01232
Epoch [ 20]: Loss 0.01270
Epoch [ 20]: Loss 0.01170
Epoch [ 20]: Loss 0.01280
Epoch [ 20]: Loss 0.01085
Validation: Loss 0.01338 Accuracy 1.00000
Validation: Loss 0.01217 Accuracy 1.00000
Epoch [ 21]: Loss 0.01230
Epoch [ 21]: Loss 0.01182
Epoch [ 21]: Loss 0.01157
Epoch [ 21]: Loss 0.01169
Epoch [ 21]: Loss 0.01145
Epoch [ 21]: Loss 0.01122
Epoch [ 21]: Loss 0.01146
Validation: Loss 0.01244 Accuracy 1.00000
Validation: Loss 0.01131 Accuracy 1.00000
Epoch [ 22]: Loss 0.01166
Epoch [ 22]: Loss 0.01033
Epoch [ 22]: Loss 0.01107
Epoch [ 22]: Loss 0.01032
Epoch [ 22]: Loss 0.01042
Epoch [ 22]: Loss 0.01103
Epoch [ 22]: Loss 0.01143
Validation: Loss 0.01151 Accuracy 1.00000
Validation: Loss 0.01047 Accuracy 1.00000
Epoch [ 23]: Loss 0.01118
Epoch [ 23]: Loss 0.01006
Epoch [ 23]: Loss 0.01062
Epoch [ 23]: Loss 0.00959
Epoch [ 23]: Loss 0.00881
Epoch [ 23]: Loss 0.00999
Epoch [ 23]: Loss 0.00856
Validation: Loss 0.01046 Accuracy 1.00000
Validation: Loss 0.00953 Accuracy 1.00000
Epoch [ 24]: Loss 0.00899
Epoch [ 24]: Loss 0.00875
Epoch [ 24]: Loss 0.00960
Epoch [ 24]: Loss 0.00964
Epoch [ 24]: Loss 0.00847
Epoch [ 24]: Loss 0.00865
Epoch [ 24]: Loss 0.00893
Validation: Loss 0.00931 Accuracy 1.00000
Validation: Loss 0.00851 Accuracy 1.00000
Epoch [ 25]: Loss 0.00861
Epoch [ 25]: Loss 0.00737
Epoch [ 25]: Loss 0.00861
Epoch [ 25]: Loss 0.00735
Epoch [ 25]: Loss 0.00812
Epoch [ 25]: Loss 0.00804
Epoch [ 25]: Loss 0.00828
Validation: Loss 0.00829 Accuracy 1.00000
Validation: Loss 0.00760 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
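With the parameters and states loaded, we can run inference directly. A minimal sketch, assuming we rebuild the model with the same hyperparameters used during training:

model = SpiralClassifier(2, 8, 1)
st_inference = Lux.testmode(st_trained)  # disable any training-time behavior
x_sample = rand(Float32, 2, 50, 1)       # an illustrative dummy sequence
ŷ, _ = model(x_sample, ps_trained, st_inference)
ŷ[1] > 0.5f0  # true ⇒ predicted anticlockwise (label 1.0f0)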
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)
This page was generated using Literate.jl.