Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
# Create the spirals
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
for d in data[1:(dataset_size ÷ 2)]]
anticlockwise_spirals = [reshape(
d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
for d in data[((dataset_size ÷ 2) + 1):end]]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
# Create DataLoaders
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
# Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
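For reference, the following is a hedged sketch of roughly what the container-layer machinery generates for us for the SpiralClassifier struct defined in the next block; it is illustrative only and not code we need to write ourselves.
# Illustrative sketch: the container layer derives these definitions from the
# field names (:lstm_cell, :classifier), so we never write them by hand.
# (They would logically sit after the struct definition below.)
function Lux.initialparameters(rng::Random.AbstractRNG, s::SpiralClassifier)
    return (lstm_cell=Lux.initialparameters(rng, s.lstm_cell),
        classifier=Lux.initialparameters(rng, s.classifier))
end
function Lux.initialstates(rng::Random.AbstractRNG, s::SpiralClassifier)
    return (lstm_cell=Lux.initialstates(rng, s.lstm_cell),
        classifier=Lux.initialstates(rng, s.classifier))
end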
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
We won't define the model from scratch, but rather use Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
return SpiralClassifier(
LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass ourselves (a sketch of that alternative is shown below), but let's still write it out for the sake of the tutorial.
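Here is a hedged sketch of that alternative, assembled purely from built-in blocks; it assumes that Chain accepts a plain function such as vec (wrapping it internally) and is not part of the model we define in this tutorial.
# Sketch only: an equivalent classifier built from stock Lux blocks.
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),  # iterate the cell over the sequence dimension, keep the last output
        Dense(hidden_dims => out_dims, sigmoid),       # classifier head
        vec)                                           # flatten the 1 × batch output to a vector (assumes Chain wraps plain functions)
end
Writing the loop by hand, as we do next, makes it explicit how the hidden state (carry) is threaded through the sequence.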
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
# First we will have to run the sequence through the LSTM Cell
# The first call to LSTM Cell will create the initial hidden state
# See that the parameters and states are automatically populated into a field called
# `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
# and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
# Now that we have the hidden state and memory in `carry` we will pass the input and
# `carry` jointly
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
# After running through the sequence we will pass the output through the classifier
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
# Finally remember to create the updated state
st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
return vec(y), st
end
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
@return vec(classifier(y))
end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred=ŷ)
end
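As noted above, the logit variant is preferred in practice because it avoids taking the log of a saturated sigmoid output. The following is a minimal illustrative sketch (not used in this tutorial) of the numerically stable formulation, where z is the raw logit, i.e. the classifier output before the sigmoid.
# Stable binary cross-entropy computed directly from the logit z:
#   bce(sigmoid(z), y) == max(z, 0) - z * y + log1p(exp(-abs(z)))
# log1p(exp(-abs(z))) never overflows, unlike log(sigmoid(z)) once sigmoid(z) underflows to 0.
logit_bce(z, y) = max(z, zero(z)) - z * y + log1p(exp(-abs(z)))
If your Lux version exposes a logits option on BinaryCrossEntropyLoss, the same effect can be had by dropping sigmoid from the classifier head and constructing the loss in logit mode (check the Lux docs for the exact keyword; this is an assumption about the API).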
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
Training the Model
function main(model_type)
dev = gpu_device()
# Get the dataloaders
train_loader, val_loader = get_dataloaders() .|> dev
# Create the model
model = model_type(2, 8, 1)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) |> dev
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
for epoch in 1:25
# Train the model
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
AutoZygote(), lossfn, (x, y), train_state)
@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end
# Validate the model
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model(x, train_state.parameters, st_)
loss = lossfn(ŷ, y)
acc = accuracy(ŷ, y)
@printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
end
end
return (train_state.parameters, train_state.states) |> cpu_device()
end
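For intuition, here is a rough sketch (not part of this tutorial, and less efficient than the real implementation) of what each Training.single_train_step! call does under the hood, written out with Zygote and Optimisers directly; the helper name manual_train_step and its return values are illustrative assumptions.
# Illustrative sketch of one manual training step with Zygote + Optimisers.
function manual_train_step(model, ps, st, opt_state, (x, y))
    # Forward pass, which also gives us the updated layer states
    loss, st, _ = compute_loss(model, ps, st, (x, y))
    # Gradient of the loss with respect to the parameters via Zygote
    # (in practice the forward pass and gradient share work)
    grads = Zygote.gradient(p -> first(compute_loss(model, p, st, (x, y))), ps)[1]
    # Apply the Adam update rule via Optimisers.jl
    opt_state, ps = Optimisers.update(opt_state, ps, grads)
    return loss, ps, st, opt_state
end
# opt_state would be initialized once with Optimisers.setup(Adam(0.01f0), ps)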
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.62718
Epoch [ 1]: Loss 0.60001
Epoch [ 1]: Loss 0.56268
Epoch [ 1]: Loss 0.53181
Epoch [ 1]: Loss 0.51388
Epoch [ 1]: Loss 0.51194
Epoch [ 1]: Loss 0.47500
Validation: Loss 0.46404 Accuracy 1.00000
Validation: Loss 0.47038 Accuracy 1.00000
Epoch [ 2]: Loss 0.47518
Epoch [ 2]: Loss 0.45419
Epoch [ 2]: Loss 0.43895
Epoch [ 2]: Loss 0.41477
Epoch [ 2]: Loss 0.42107
Epoch [ 2]: Loss 0.40054
Epoch [ 2]: Loss 0.37514
Validation: Loss 0.36615 Accuracy 1.00000
Validation: Loss 0.37387 Accuracy 1.00000
Epoch [ 3]: Loss 0.38160
Epoch [ 3]: Loss 0.36141
Epoch [ 3]: Loss 0.33932
Epoch [ 3]: Loss 0.31804
Epoch [ 3]: Loss 0.33407
Epoch [ 3]: Loss 0.31477
Epoch [ 3]: Loss 0.27711
Validation: Loss 0.28073 Accuracy 1.00000
Validation: Loss 0.28935 Accuracy 1.00000
Epoch [ 4]: Loss 0.28997
Epoch [ 4]: Loss 0.28240
Epoch [ 4]: Loss 0.25862
Epoch [ 4]: Loss 0.24393
Epoch [ 4]: Loss 0.25496
Epoch [ 4]: Loss 0.24056
Epoch [ 4]: Loss 0.21530
Validation: Loss 0.21101 Accuracy 1.00000
Validation: Loss 0.21973 Accuracy 1.00000
Epoch [ 5]: Loss 0.22257
Epoch [ 5]: Loss 0.21121
Epoch [ 5]: Loss 0.20601
Epoch [ 5]: Loss 0.18332
Epoch [ 5]: Loss 0.18898
Epoch [ 5]: Loss 0.17421
Epoch [ 5]: Loss 0.14663
Validation: Loss 0.15600 Accuracy 1.00000
Validation: Loss 0.16392 Accuracy 1.00000
Epoch [ 6]: Loss 0.16531
Epoch [ 6]: Loss 0.15557
Epoch [ 6]: Loss 0.14348
Epoch [ 6]: Loss 0.14633
Epoch [ 6]: Loss 0.13218
Epoch [ 6]: Loss 0.13313
Epoch [ 6]: Loss 0.12015
Validation: Loss 0.11405 Accuracy 1.00000
Validation: Loss 0.12051 Accuracy 1.00000
Epoch [ 7]: Loss 0.11521
Epoch [ 7]: Loss 0.10906
Epoch [ 7]: Loss 0.10325
Epoch [ 7]: Loss 0.10190
Epoch [ 7]: Loss 0.10488
Epoch [ 7]: Loss 0.09975
Epoch [ 7]: Loss 0.09609
Validation: Loss 0.08147 Accuracy 1.00000
Validation: Loss 0.08622 Accuracy 1.00000
Epoch [ 8]: Loss 0.08552
Epoch [ 8]: Loss 0.07674
Epoch [ 8]: Loss 0.07377
Epoch [ 8]: Loss 0.07847
Epoch [ 8]: Loss 0.07308
Epoch [ 8]: Loss 0.06279
Epoch [ 8]: Loss 0.05957
Validation: Loss 0.05666 Accuracy 1.00000
Validation: Loss 0.05977 Accuracy 1.00000
Epoch [ 9]: Loss 0.05632
Epoch [ 9]: Loss 0.05646
Epoch [ 9]: Loss 0.05808
Epoch [ 9]: Loss 0.04899
Epoch [ 9]: Loss 0.04856
Epoch [ 9]: Loss 0.04833
Epoch [ 9]: Loss 0.04532
Validation: Loss 0.04231 Accuracy 1.00000
Validation: Loss 0.04448 Accuracy 1.00000
Epoch [ 10]: Loss 0.04456
Epoch [ 10]: Loss 0.04584
Epoch [ 10]: Loss 0.04199
Epoch [ 10]: Loss 0.04006
Epoch [ 10]: Loss 0.03755
Epoch [ 10]: Loss 0.03498
Epoch [ 10]: Loss 0.03402
Validation: Loss 0.03427 Accuracy 1.00000
Validation: Loss 0.03602 Accuracy 1.00000
Epoch [ 11]: Loss 0.03817
Epoch [ 11]: Loss 0.03232
Epoch [ 11]: Loss 0.03371
Epoch [ 11]: Loss 0.03322
Epoch [ 11]: Loss 0.03166
Epoch [ 11]: Loss 0.03205
Epoch [ 11]: Loss 0.03184
Validation: Loss 0.02910 Accuracy 1.00000
Validation: Loss 0.03062 Accuracy 1.00000
Epoch [ 12]: Loss 0.02991
Epoch [ 12]: Loss 0.02674
Epoch [ 12]: Loss 0.02942
Epoch [ 12]: Loss 0.02688
Epoch [ 12]: Loss 0.02973
Epoch [ 12]: Loss 0.02892
Epoch [ 12]: Loss 0.03049
Validation: Loss 0.02532 Accuracy 1.00000
Validation: Loss 0.02667 Accuracy 1.00000
Epoch [ 13]: Loss 0.02762
Epoch [ 13]: Loss 0.02623
Epoch [ 13]: Loss 0.02642
Epoch [ 13]: Loss 0.02441
Epoch [ 13]: Loss 0.02249
Epoch [ 13]: Loss 0.02463
Epoch [ 13]: Loss 0.02190
Validation: Loss 0.02237 Accuracy 1.00000
Validation: Loss 0.02357 Accuracy 1.00000
Epoch [ 14]: Loss 0.02340
Epoch [ 14]: Loss 0.02223
Epoch [ 14]: Loss 0.02384
Epoch [ 14]: Loss 0.02210
Epoch [ 14]: Loss 0.02214
Epoch [ 14]: Loss 0.02104
Epoch [ 14]: Loss 0.01959
Validation: Loss 0.02001 Accuracy 1.00000
Validation: Loss 0.02111 Accuracy 1.00000
Epoch [ 15]: Loss 0.02303
Epoch [ 15]: Loss 0.02102
Epoch [ 15]: Loss 0.01961
Epoch [ 15]: Loss 0.01768
Epoch [ 15]: Loss 0.01868
Epoch [ 15]: Loss 0.01983
Epoch [ 15]: Loss 0.02250
Validation: Loss 0.01805 Accuracy 1.00000
Validation: Loss 0.01906 Accuracy 1.00000
Epoch [ 16]: Loss 0.01995
Epoch [ 16]: Loss 0.01832
Epoch [ 16]: Loss 0.01789
Epoch [ 16]: Loss 0.01749
Epoch [ 16]: Loss 0.01922
Epoch [ 16]: Loss 0.01702
Epoch [ 16]: Loss 0.01469
Validation: Loss 0.01639 Accuracy 1.00000
Validation: Loss 0.01732 Accuracy 1.00000
Epoch [ 17]: Loss 0.01703
Epoch [ 17]: Loss 0.01746
Epoch [ 17]: Loss 0.01629
Epoch [ 17]: Loss 0.01591
Epoch [ 17]: Loss 0.01679
Epoch [ 17]: Loss 0.01582
Epoch [ 17]: Loss 0.01635
Validation: Loss 0.01497 Accuracy 1.00000
Validation: Loss 0.01584 Accuracy 1.00000
Epoch [ 18]: Loss 0.01484
Epoch [ 18]: Loss 0.01554
Epoch [ 18]: Loss 0.01549
Epoch [ 18]: Loss 0.01544
Epoch [ 18]: Loss 0.01598
Epoch [ 18]: Loss 0.01394
Epoch [ 18]: Loss 0.01392
Validation: Loss 0.01374 Accuracy 1.00000
Validation: Loss 0.01456 Accuracy 1.00000
Epoch [ 19]: Loss 0.01416
Epoch [ 19]: Loss 0.01386
Epoch [ 19]: Loss 0.01645
Epoch [ 19]: Loss 0.01301
Epoch [ 19]: Loss 0.01459
Epoch [ 19]: Loss 0.01244
Epoch [ 19]: Loss 0.01071
Validation: Loss 0.01267 Accuracy 1.00000
Validation: Loss 0.01343 Accuracy 1.00000
Epoch [ 20]: Loss 0.01410
Epoch [ 20]: Loss 0.01253
Epoch [ 20]: Loss 0.01290
Epoch [ 20]: Loss 0.01350
Epoch [ 20]: Loss 0.01337
Epoch [ 20]: Loss 0.01160
Epoch [ 20]: Loss 0.00991
Validation: Loss 0.01170 Accuracy 1.00000
Validation: Loss 0.01241 Accuracy 1.00000
Epoch [ 21]: Loss 0.01232
Epoch [ 21]: Loss 0.01145
Epoch [ 21]: Loss 0.01166
Epoch [ 21]: Loss 0.01327
Epoch [ 21]: Loss 0.01196
Epoch [ 21]: Loss 0.01068
Epoch [ 21]: Loss 0.01135
Validation: Loss 0.01074 Accuracy 1.00000
Validation: Loss 0.01139 Accuracy 1.00000
Epoch [ 22]: Loss 0.01161
Epoch [ 22]: Loss 0.01047
Epoch [ 22]: Loss 0.01185
Epoch [ 22]: Loss 0.01052
Epoch [ 22]: Loss 0.01003
Epoch [ 22]: Loss 0.01051
Epoch [ 22]: Loss 0.01061
Validation: Loss 0.00967 Accuracy 1.00000
Validation: Loss 0.01025 Accuracy 1.00000
Epoch [ 23]: Loss 0.01008
Epoch [ 23]: Loss 0.00990
Epoch [ 23]: Loss 0.00952
Epoch [ 23]: Loss 0.00935
Epoch [ 23]: Loss 0.00939
Epoch [ 23]: Loss 0.01016
Epoch [ 23]: Loss 0.00770
Validation: Loss 0.00857 Accuracy 1.00000
Validation: Loss 0.00907 Accuracy 1.00000
Epoch [ 24]: Loss 0.00948
Epoch [ 24]: Loss 0.00843
Epoch [ 24]: Loss 0.00907
Epoch [ 24]: Loss 0.00849
Epoch [ 24]: Loss 0.00797
Epoch [ 24]: Loss 0.00815
Epoch [ 24]: Loss 0.00811
Validation: Loss 0.00771 Accuracy 1.00000
Validation: Loss 0.00814 Accuracy 1.00000
Epoch [ 25]: Loss 0.00842
Epoch [ 25]: Loss 0.00804
Epoch [ 25]: Loss 0.00669
Epoch [ 25]: Loss 0.00810
Epoch [ 25]: Loss 0.00774
Epoch [ 25]: Loss 0.00772
Epoch [ 25]: Loss 0.00744
Validation: Loss 0.00709 Accuracy 1.00000
Validation: Loss 0.00747 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.62108
Epoch [ 1]: Loss 0.58121
Epoch [ 1]: Loss 0.56262
Epoch [ 1]: Loss 0.54752
Epoch [ 1]: Loss 0.52864
Epoch [ 1]: Loss 0.49979
Epoch [ 1]: Loss 0.46702
Validation: Loss 0.47471 Accuracy 1.00000
Validation: Loss 0.46144 Accuracy 1.00000
Epoch [ 2]: Loss 0.46684
Epoch [ 2]: Loss 0.44808
Epoch [ 2]: Loss 0.44013
Epoch [ 2]: Loss 0.42881
Epoch [ 2]: Loss 0.41686
Epoch [ 2]: Loss 0.40113
Epoch [ 2]: Loss 0.39721
Validation: Loss 0.37957 Accuracy 1.00000
Validation: Loss 0.36353 Accuracy 1.00000
Epoch [ 3]: Loss 0.37677
Epoch [ 3]: Loss 0.36023
Epoch [ 3]: Loss 0.34351
Epoch [ 3]: Loss 0.33917
Epoch [ 3]: Loss 0.32752
Epoch [ 3]: Loss 0.30449
Epoch [ 3]: Loss 0.28897
Validation: Loss 0.29574 Accuracy 1.00000
Validation: Loss 0.27846 Accuracy 1.00000
Epoch [ 4]: Loss 0.27458
Epoch [ 4]: Loss 0.27470
Epoch [ 4]: Loss 0.26752
Epoch [ 4]: Loss 0.26272
Epoch [ 4]: Loss 0.24863
Epoch [ 4]: Loss 0.24116
Epoch [ 4]: Loss 0.24974
Validation: Loss 0.22620 Accuracy 1.00000
Validation: Loss 0.20871 Accuracy 1.00000
Epoch [ 5]: Loss 0.21805
Epoch [ 5]: Loss 0.21271
Epoch [ 5]: Loss 0.19800
Epoch [ 5]: Loss 0.20039
Epoch [ 5]: Loss 0.18249
Epoch [ 5]: Loss 0.17278
Epoch [ 5]: Loss 0.18034
Validation: Loss 0.16975 Accuracy 1.00000
Validation: Loss 0.15377 Accuracy 1.00000
Epoch [ 6]: Loss 0.16734
Epoch [ 6]: Loss 0.15922
Epoch [ 6]: Loss 0.15422
Epoch [ 6]: Loss 0.14687
Epoch [ 6]: Loss 0.12528
Epoch [ 6]: Loss 0.13232
Epoch [ 6]: Loss 0.10874
Validation: Loss 0.12526 Accuracy 1.00000
Validation: Loss 0.11210 Accuracy 1.00000
Epoch [ 7]: Loss 0.12720
Epoch [ 7]: Loss 0.11364
Epoch [ 7]: Loss 0.10651
Epoch [ 7]: Loss 0.09527
Epoch [ 7]: Loss 0.09629
Epoch [ 7]: Loss 0.09677
Epoch [ 7]: Loss 0.11688
Validation: Loss 0.09022 Accuracy 1.00000
Validation: Loss 0.08035 Accuracy 1.00000
Epoch [ 8]: Loss 0.07898
Epoch [ 8]: Loss 0.08638
Epoch [ 8]: Loss 0.08011
Epoch [ 8]: Loss 0.07609
Epoch [ 8]: Loss 0.06828
Epoch [ 8]: Loss 0.06253
Epoch [ 8]: Loss 0.07626
Validation: Loss 0.06259 Accuracy 1.00000
Validation: Loss 0.05604 Accuracy 1.00000
Epoch [ 9]: Loss 0.05683
Epoch [ 9]: Loss 0.05699
Epoch [ 9]: Loss 0.05471
Epoch [ 9]: Loss 0.05135
Epoch [ 9]: Loss 0.04590
Epoch [ 9]: Loss 0.05185
Epoch [ 9]: Loss 0.05025
Validation: Loss 0.04591 Accuracy 1.00000
Validation: Loss 0.04145 Accuracy 1.00000
Epoch [ 10]: Loss 0.04658
Epoch [ 10]: Loss 0.04182
Epoch [ 10]: Loss 0.04112
Epoch [ 10]: Loss 0.03920
Epoch [ 10]: Loss 0.03553
Epoch [ 10]: Loss 0.03901
Epoch [ 10]: Loss 0.03768
Validation: Loss 0.03693 Accuracy 1.00000
Validation: Loss 0.03337 Accuracy 1.00000
Epoch [ 11]: Loss 0.03345
Epoch [ 11]: Loss 0.03682
Epoch [ 11]: Loss 0.03255
Epoch [ 11]: Loss 0.03261
Epoch [ 11]: Loss 0.03287
Epoch [ 11]: Loss 0.03009
Epoch [ 11]: Loss 0.03583
Validation: Loss 0.03127 Accuracy 1.00000
Validation: Loss 0.02819 Accuracy 1.00000
Epoch [ 12]: Loss 0.02902
Epoch [ 12]: Loss 0.02865
Epoch [ 12]: Loss 0.02972
Epoch [ 12]: Loss 0.02890
Epoch [ 12]: Loss 0.02737
Epoch [ 12]: Loss 0.02684
Epoch [ 12]: Loss 0.02747
Validation: Loss 0.02718 Accuracy 1.00000
Validation: Loss 0.02445 Accuracy 1.00000
Epoch [ 13]: Loss 0.02694
Epoch [ 13]: Loss 0.02756
Epoch [ 13]: Loss 0.02620
Epoch [ 13]: Loss 0.02382
Epoch [ 13]: Loss 0.02297
Epoch [ 13]: Loss 0.02257
Epoch [ 13]: Loss 0.02096
Validation: Loss 0.02403 Accuracy 1.00000
Validation: Loss 0.02157 Accuracy 1.00000
Epoch [ 14]: Loss 0.02281
Epoch [ 14]: Loss 0.02363
Epoch [ 14]: Loss 0.02247
Epoch [ 14]: Loss 0.02308
Epoch [ 14]: Loss 0.01987
Epoch [ 14]: Loss 0.02068
Epoch [ 14]: Loss 0.02092
Validation: Loss 0.02153 Accuracy 1.00000
Validation: Loss 0.01927 Accuracy 1.00000
Epoch [ 15]: Loss 0.02226
Epoch [ 15]: Loss 0.01992
Epoch [ 15]: Loss 0.02071
Epoch [ 15]: Loss 0.02096
Epoch [ 15]: Loss 0.01800
Epoch [ 15]: Loss 0.01844
Epoch [ 15]: Loss 0.01404
Validation: Loss 0.01946 Accuracy 1.00000
Validation: Loss 0.01737 Accuracy 1.00000
Epoch [ 16]: Loss 0.01877
Epoch [ 16]: Loss 0.01920
Epoch [ 16]: Loss 0.01775
Epoch [ 16]: Loss 0.01801
Epoch [ 16]: Loss 0.01712
Epoch [ 16]: Loss 0.01690
Epoch [ 16]: Loss 0.01712
Validation: Loss 0.01773 Accuracy 1.00000
Validation: Loss 0.01578 Accuracy 1.00000
Epoch [ 17]: Loss 0.01852
Epoch [ 17]: Loss 0.01655
Epoch [ 17]: Loss 0.01633
Epoch [ 17]: Loss 0.01551
Epoch [ 17]: Loss 0.01563
Epoch [ 17]: Loss 0.01572
Epoch [ 17]: Loss 0.01569
Validation: Loss 0.01624 Accuracy 1.00000
Validation: Loss 0.01442 Accuracy 1.00000
Epoch [ 18]: Loss 0.01549
Epoch [ 18]: Loss 0.01472
Epoch [ 18]: Loss 0.01498
Epoch [ 18]: Loss 0.01510
Epoch [ 18]: Loss 0.01462
Epoch [ 18]: Loss 0.01488
Epoch [ 18]: Loss 0.01556
Validation: Loss 0.01497 Accuracy 1.00000
Validation: Loss 0.01326 Accuracy 1.00000
Epoch [ 19]: Loss 0.01426
Epoch [ 19]: Loss 0.01520
Epoch [ 19]: Loss 0.01354
Epoch [ 19]: Loss 0.01244
Epoch [ 19]: Loss 0.01310
Epoch [ 19]: Loss 0.01428
Epoch [ 19]: Loss 0.01439
Validation: Loss 0.01385 Accuracy 1.00000
Validation: Loss 0.01226 Accuracy 1.00000
Epoch [ 20]: Loss 0.01332
Epoch [ 20]: Loss 0.01318
Epoch [ 20]: Loss 0.01304
Epoch [ 20]: Loss 0.01152
Epoch [ 20]: Loss 0.01327
Epoch [ 20]: Loss 0.01263
Epoch [ 20]: Loss 0.01237
Validation: Loss 0.01286 Accuracy 1.00000
Validation: Loss 0.01137 Accuracy 1.00000
Epoch [ 21]: Loss 0.01225
Epoch [ 21]: Loss 0.01212
Epoch [ 21]: Loss 0.01150
Epoch [ 21]: Loss 0.01151
Epoch [ 21]: Loss 0.01209
Epoch [ 21]: Loss 0.01174
Epoch [ 21]: Loss 0.01240
Validation: Loss 0.01195 Accuracy 1.00000
Validation: Loss 0.01056 Accuracy 1.00000
Epoch [ 22]: Loss 0.01239
Epoch [ 22]: Loss 0.01092
Epoch [ 22]: Loss 0.01136
Epoch [ 22]: Loss 0.01035
Epoch [ 22]: Loss 0.01143
Epoch [ 22]: Loss 0.01002
Epoch [ 22]: Loss 0.01020
Validation: Loss 0.01104 Accuracy 1.00000
Validation: Loss 0.00976 Accuracy 1.00000
Epoch [ 23]: Loss 0.01031
Epoch [ 23]: Loss 0.01091
Epoch [ 23]: Loss 0.01018
Epoch [ 23]: Loss 0.00992
Epoch [ 23]: Loss 0.01038
Epoch [ 23]: Loss 0.00967
Epoch [ 23]: Loss 0.00858
Validation: Loss 0.01006 Accuracy 1.00000
Validation: Loss 0.00891 Accuracy 1.00000
Epoch [ 24]: Loss 0.00901
Epoch [ 24]: Loss 0.01045
Epoch [ 24]: Loss 0.00920
Epoch [ 24]: Loss 0.00922
Epoch [ 24]: Loss 0.00902
Epoch [ 24]: Loss 0.00844
Epoch [ 24]: Loss 0.00838
Validation: Loss 0.00896 Accuracy 1.00000
Validation: Loss 0.00796 Accuracy 1.00000
Epoch [ 25]: Loss 0.00896
Epoch [ 25]: Loss 0.00849
Epoch [ 25]: Loss 0.00825
Epoch [ 25]: Loss 0.00782
Epoch [ 25]: Loss 0.00822
Epoch [ 25]: Loss 0.00758
Epoch [ 25]: Loss 0.00717
Validation: Loss 0.00798 Accuracy 1.00000
Validation: Loss 0.00713 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct; only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trainedLet's try loading the model
@load "trained_model.jld2" ps_trained st_trained2-element Vector{Symbol}:
:ps_trained
:st_trained
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)
This page was generated using Literate.jl.