Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, AMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
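As a quick sanity check, we can peek at the first training batch and confirm the shapes described above (an illustrative sketch, not part of the original pipeline):

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one binary label per sequence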
Creating a Classifier
We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the built-in Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
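Because we subtyped AbstractExplicitContainerLayer with the fieldnames (:lstm_cell, :classifier), Lux.setup nests the parameters and states under those names automatically. A minimal sketch to verify this:

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)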
We could use the stock Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the recurrence loop ourselves, but let's still do it manually for illustration.
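For reference, such a Recurrence-based model could be sketched as follows (SpiralClassifierRecurrence is a hypothetical name; by default Recurrence returns only the output at the final time step, so this should mirror the manual loop we define below):

function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end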
Now we need to define the behavior of the classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to LSTM Cell will create the initial hidden state.
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
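To see the forward pass in action, here is a small sketch that runs a random batch through a freshly initialized model (the shapes are illustrative):

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
x_dummy = randn(Float32, 2, 50, 16)  # features × sequence_length × batch_size
ŷ, st_ = model(x_dummy, ps, st)
size(ŷ)  # (16,): one probability per sequence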
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
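As an aside, the numerically stable variant mentioned above would drop the sigmoid from the classifier head, e.g. Dense(hidden_dims => out_dims), and compute the loss directly on the logits. A minimal sketch:

# Sketch of the logit-based loss; pair it with a classifier head without sigmoid
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))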
Training the Model
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()
    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Training.TrainState(rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.73673
Epoch [ 1]: Loss 0.70158
Epoch [ 1]: Loss 0.63264
Epoch [ 1]: Loss 0.60179
Epoch [ 1]: Loss 0.56753
Epoch [ 1]: Loss 0.54071
Epoch [ 1]: Loss 0.50607
Validation: Loss 0.47787 Accuracy 1.00000
Validation: Loss 0.47686 Accuracy 1.00000
Epoch [ 2]: Loss 0.47749
Epoch [ 2]: Loss 0.45250
Epoch [ 2]: Loss 0.43124
Epoch [ 2]: Loss 0.40331
Epoch [ 2]: Loss 0.38042
Epoch [ 2]: Loss 0.37468
Epoch [ 2]: Loss 0.35117
Validation: Loss 0.33918 Accuracy 1.00000
Validation: Loss 0.33153 Accuracy 1.00000
Epoch [ 3]: Loss 0.32913
Epoch [ 3]: Loss 0.32219
Epoch [ 3]: Loss 0.30392
Epoch [ 3]: Loss 0.27819
Epoch [ 3]: Loss 0.27343
Epoch [ 3]: Loss 0.25045
Epoch [ 3]: Loss 0.23073
Validation: Loss 0.23854 Accuracy 1.00000
Validation: Loss 0.23018 Accuracy 1.00000
Epoch [ 4]: Loss 0.22694
Epoch [ 4]: Loss 0.22916
Epoch [ 4]: Loss 0.21602
Epoch [ 4]: Loss 0.19636
Epoch [ 4]: Loss 0.19386
Epoch [ 4]: Loss 0.18130
Epoch [ 4]: Loss 0.17992
Validation: Loss 0.17364 Accuracy 1.00000
Validation: Loss 0.16782 Accuracy 1.00000
Epoch [ 5]: Loss 0.16791
Epoch [ 5]: Loss 0.16248
Epoch [ 5]: Loss 0.15579
Epoch [ 5]: Loss 0.14473
Epoch [ 5]: Loss 0.14685
Epoch [ 5]: Loss 0.13079
Epoch [ 5]: Loss 0.13320
Validation: Loss 0.12662 Accuracy 1.00000
Validation: Loss 0.12291 Accuracy 1.00000
Epoch [ 6]: Loss 0.12614
Epoch [ 6]: Loss 0.12016
Epoch [ 6]: Loss 0.10961
Epoch [ 6]: Loss 0.10798
Epoch [ 6]: Loss 0.10401
Epoch [ 6]: Loss 0.10008
Epoch [ 6]: Loss 0.08811
Validation: Loss 0.09265 Accuracy 1.00000
Validation: Loss 0.09001 Accuracy 1.00000
Epoch [ 7]: Loss 0.09078
Epoch [ 7]: Loss 0.08520
Epoch [ 7]: Loss 0.08291
Epoch [ 7]: Loss 0.08204
Epoch [ 7]: Loss 0.07455
Epoch [ 7]: Loss 0.07242
Epoch [ 7]: Loss 0.06951
Validation: Loss 0.06840 Accuracy 1.00000
Validation: Loss 0.06612 Accuracy 1.00000
Epoch [ 8]: Loss 0.06580
Epoch [ 8]: Loss 0.06198
Epoch [ 8]: Loss 0.06142
Epoch [ 8]: Loss 0.05998
Epoch [ 8]: Loss 0.05557
Epoch [ 8]: Loss 0.05431
Epoch [ 8]: Loss 0.05521
Validation: Loss 0.05141 Accuracy 1.00000
Validation: Loss 0.04937 Accuracy 1.00000
Epoch [ 9]: Loss 0.05207
Epoch [ 9]: Loss 0.04888
Epoch [ 9]: Loss 0.04669
Epoch [ 9]: Loss 0.04152
Epoch [ 9]: Loss 0.04283
Epoch [ 9]: Loss 0.04094
Epoch [ 9]: Loss 0.03625
Validation: Loss 0.04003 Accuracy 1.00000
Validation: Loss 0.03828 Accuracy 1.00000
Epoch [ 10]: Loss 0.03828
Epoch [ 10]: Loss 0.03910
Epoch [ 10]: Loss 0.03497
Epoch [ 10]: Loss 0.03545
Epoch [ 10]: Loss 0.03477
Epoch [ 10]: Loss 0.03153
Epoch [ 10]: Loss 0.03228
Validation: Loss 0.03246 Accuracy 1.00000
Validation: Loss 0.03097 Accuracy 1.00000
Epoch [ 11]: Loss 0.03021
Epoch [ 11]: Loss 0.02952
Epoch [ 11]: Loss 0.02825
Epoch [ 11]: Loss 0.02870
Epoch [ 11]: Loss 0.02959
Epoch [ 11]: Loss 0.02824
Epoch [ 11]: Loss 0.02882
Validation: Loss 0.02713 Accuracy 1.00000
Validation: Loss 0.02587 Accuracy 1.00000
Epoch [ 12]: Loss 0.02659
Epoch [ 12]: Loss 0.02676
Epoch [ 12]: Loss 0.02497
Epoch [ 12]: Loss 0.02342
Epoch [ 12]: Loss 0.02317
Epoch [ 12]: Loss 0.02217
Epoch [ 12]: Loss 0.02518
Validation: Loss 0.02319 Accuracy 1.00000
Validation: Loss 0.02212 Accuracy 1.00000
Epoch [ 13]: Loss 0.02281
Epoch [ 13]: Loss 0.02095
Epoch [ 13]: Loss 0.02106
Epoch [ 13]: Loss 0.02160
Epoch [ 13]: Loss 0.02065
Epoch [ 13]: Loss 0.02028
Epoch [ 13]: Loss 0.01900
Validation: Loss 0.02023 Accuracy 1.00000
Validation: Loss 0.01931 Accuracy 1.00000
Epoch [ 14]: Loss 0.01852
Epoch [ 14]: Loss 0.01959
Epoch [ 14]: Loss 0.01824
Epoch [ 14]: Loss 0.01854
Epoch [ 14]: Loss 0.01904
Epoch [ 14]: Loss 0.01759
Epoch [ 14]: Loss 0.01812
Validation: Loss 0.01796 Accuracy 1.00000
Validation: Loss 0.01715 Accuracy 1.00000
Epoch [ 15]: Loss 0.01727
Epoch [ 15]: Loss 0.01769
Epoch [ 15]: Loss 0.01605
Epoch [ 15]: Loss 0.01674
Epoch [ 15]: Loss 0.01585
Epoch [ 15]: Loss 0.01592
Epoch [ 15]: Loss 0.01627
Validation: Loss 0.01615 Accuracy 1.00000
Validation: Loss 0.01542 Accuracy 1.00000
Epoch [ 16]: Loss 0.01573
Epoch [ 16]: Loss 0.01489
Epoch [ 16]: Loss 0.01565
Epoch [ 16]: Loss 0.01580
Epoch [ 16]: Loss 0.01494
Epoch [ 16]: Loss 0.01321
Epoch [ 16]: Loss 0.01322
Validation: Loss 0.01466 Accuracy 1.00000
Validation: Loss 0.01399 Accuracy 1.00000
Epoch [ 17]: Loss 0.01442
Epoch [ 17]: Loss 0.01405
Epoch [ 17]: Loss 0.01361
Epoch [ 17]: Loss 0.01267
Epoch [ 17]: Loss 0.01327
Epoch [ 17]: Loss 0.01355
Epoch [ 17]: Loss 0.01417
Validation: Loss 0.01342 Accuracy 1.00000
Validation: Loss 0.01280 Accuracy 1.00000
Epoch [ 18]: Loss 0.01312
Epoch [ 18]: Loss 0.01291
Epoch [ 18]: Loss 0.01174
Epoch [ 18]: Loss 0.01241
Epoch [ 18]: Loss 0.01284
Epoch [ 18]: Loss 0.01235
Epoch [ 18]: Loss 0.01083
Validation: Loss 0.01236 Accuracy 1.00000
Validation: Loss 0.01179 Accuracy 1.00000
Epoch [ 19]: Loss 0.01162
Epoch [ 19]: Loss 0.01194
Epoch [ 19]: Loss 0.01151
Epoch [ 19]: Loss 0.01092
Epoch [ 19]: Loss 0.01185
Epoch [ 19]: Loss 0.01089
Epoch [ 19]: Loss 0.01325
Validation: Loss 0.01145 Accuracy 1.00000
Validation: Loss 0.01091 Accuracy 1.00000
Epoch [ 20]: Loss 0.01127
Epoch [ 20]: Loss 0.01051
Epoch [ 20]: Loss 0.01103
Epoch [ 20]: Loss 0.01049
Epoch [ 20]: Loss 0.01059
Epoch [ 20]: Loss 0.01032
Epoch [ 20]: Loss 0.01039
Validation: Loss 0.01064 Accuracy 1.00000
Validation: Loss 0.01014 Accuracy 1.00000
Epoch [ 21]: Loss 0.01038
Epoch [ 21]: Loss 0.00982
Epoch [ 21]: Loss 0.00939
Epoch [ 21]: Loss 0.01058
Epoch [ 21]: Loss 0.00969
Epoch [ 21]: Loss 0.00991
Epoch [ 21]: Loss 0.00961
Validation: Loss 0.00994 Accuracy 1.00000
Validation: Loss 0.00946 Accuracy 1.00000
Epoch [ 22]: Loss 0.00944
Epoch [ 22]: Loss 0.00937
Epoch [ 22]: Loss 0.00959
Epoch [ 22]: Loss 0.00944
Epoch [ 22]: Loss 0.00904
Epoch [ 22]: Loss 0.00864
Epoch [ 22]: Loss 0.01035
Validation: Loss 0.00931 Accuracy 1.00000
Validation: Loss 0.00886 Accuracy 1.00000
Epoch [ 23]: Loss 0.00858
Epoch [ 23]: Loss 0.00836
Epoch [ 23]: Loss 0.00958
Epoch [ 23]: Loss 0.00844
Epoch [ 23]: Loss 0.00887
Epoch [ 23]: Loss 0.00815
Epoch [ 23]: Loss 0.00992
Validation: Loss 0.00874 Accuracy 1.00000
Validation: Loss 0.00832 Accuracy 1.00000
Epoch [ 24]: Loss 0.00839
Epoch [ 24]: Loss 0.00805
Epoch [ 24]: Loss 0.00835
Epoch [ 24]: Loss 0.00809
Epoch [ 24]: Loss 0.00814
Epoch [ 24]: Loss 0.00807
Epoch [ 24]: Loss 0.00842
Validation: Loss 0.00823 Accuracy 1.00000
Validation: Loss 0.00783 Accuracy 1.00000
Epoch [ 25]: Loss 0.00825
Epoch [ 25]: Loss 0.00820
Epoch [ 25]: Loss 0.00780
Epoch [ 25]: Loss 0.00770
Epoch [ 25]: Loss 0.00724
Epoch [ 25]: Loss 0.00707
Epoch [ 25]: Loss 0.00799
Validation: Loss 0.00777 Accuracy 1.00000
Validation: Loss 0.00739 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.73144
Epoch [ 1]: Loss 0.67839
Epoch [ 1]: Loss 0.64396
Epoch [ 1]: Loss 0.59659
Epoch [ 1]: Loss 0.57715
Epoch [ 1]: Loss 0.53814
Epoch [ 1]: Loss 0.50134
Validation: Loss 0.47486 Accuracy 1.00000
Validation: Loss 0.47575 Accuracy 1.00000
Epoch [ 2]: Loss 0.47732
Epoch [ 2]: Loss 0.44956
Epoch [ 2]: Loss 0.42565
Epoch [ 2]: Loss 0.40488
Epoch [ 2]: Loss 0.38910
Epoch [ 2]: Loss 0.36951
Epoch [ 2]: Loss 0.35933
Validation: Loss 0.33479 Accuracy 1.00000
Validation: Loss 0.32591 Accuracy 1.00000
Epoch [ 3]: Loss 0.32664
Epoch [ 3]: Loss 0.31709
Epoch [ 3]: Loss 0.29801
Epoch [ 3]: Loss 0.28722
Epoch [ 3]: Loss 0.26316
Epoch [ 3]: Loss 0.25891
Epoch [ 3]: Loss 0.23415
Validation: Loss 0.23345 Accuracy 1.00000
Validation: Loss 0.22358 Accuracy 1.00000
Epoch [ 4]: Loss 0.23313
Epoch [ 4]: Loss 0.22011
Epoch [ 4]: Loss 0.20610
Epoch [ 4]: Loss 0.20197
Epoch [ 4]: Loss 0.18954
Epoch [ 4]: Loss 0.18644
Epoch [ 4]: Loss 0.16965
Validation: Loss 0.16929 Accuracy 1.00000
Validation: Loss 0.16289 Accuracy 1.00000
Epoch [ 5]: Loss 0.17248
Epoch [ 5]: Loss 0.15691
Epoch [ 5]: Loss 0.14939
Epoch [ 5]: Loss 0.15211
Epoch [ 5]: Loss 0.13778
Epoch [ 5]: Loss 0.13366
Epoch [ 5]: Loss 0.12813
Validation: Loss 0.12354 Accuracy 1.00000
Validation: Loss 0.11952 Accuracy 1.00000
Epoch [ 6]: Loss 0.12129
Epoch [ 6]: Loss 0.11648
Epoch [ 6]: Loss 0.11025
Epoch [ 6]: Loss 0.11038
Epoch [ 6]: Loss 0.10240
Epoch [ 6]: Loss 0.09844
Epoch [ 6]: Loss 0.09438
Validation: Loss 0.09038 Accuracy 1.00000
Validation: Loss 0.08732 Accuracy 1.00000
Epoch [ 7]: Loss 0.08690
Epoch [ 7]: Loss 0.08627
Epoch [ 7]: Loss 0.08206
Epoch [ 7]: Loss 0.08071
Epoch [ 7]: Loss 0.07612
Epoch [ 7]: Loss 0.07176
Epoch [ 7]: Loss 0.06464
Validation: Loss 0.06657 Accuracy 1.00000
Validation: Loss 0.06384 Accuracy 1.00000
Epoch [ 8]: Loss 0.06580
Epoch [ 8]: Loss 0.06142
Epoch [ 8]: Loss 0.06099
Epoch [ 8]: Loss 0.05952
Epoch [ 8]: Loss 0.05652
Epoch [ 8]: Loss 0.05247
Epoch [ 8]: Loss 0.04920
Validation: Loss 0.04998 Accuracy 1.00000
Validation: Loss 0.04735 Accuracy 1.00000
Epoch [ 9]: Loss 0.04984
Epoch [ 9]: Loss 0.04869
Epoch [ 9]: Loss 0.04513
Epoch [ 9]: Loss 0.04462
Epoch [ 9]: Loss 0.04283
Epoch [ 9]: Loss 0.03819
Epoch [ 9]: Loss 0.04063
Validation: Loss 0.03902 Accuracy 1.00000
Validation: Loss 0.03663 Accuracy 1.00000
Epoch [ 10]: Loss 0.03702
Epoch [ 10]: Loss 0.03864
Epoch [ 10]: Loss 0.03552
Epoch [ 10]: Loss 0.03625
Epoch [ 10]: Loss 0.03313
Epoch [ 10]: Loss 0.03199
Epoch [ 10]: Loss 0.03259
Validation: Loss 0.03162 Accuracy 1.00000
Validation: Loss 0.02954 Accuracy 1.00000
Epoch [ 11]: Loss 0.03234
Epoch [ 11]: Loss 0.02995
Epoch [ 11]: Loss 0.03071
Epoch [ 11]: Loss 0.02743
Epoch [ 11]: Loss 0.02662
Epoch [ 11]: Loss 0.02684
Epoch [ 11]: Loss 0.02702
Validation: Loss 0.02637 Accuracy 1.00000
Validation: Loss 0.02460 Accuracy 1.00000
Epoch [ 12]: Loss 0.02498
Epoch [ 12]: Loss 0.02625
Epoch [ 12]: Loss 0.02586
Epoch [ 12]: Loss 0.02266
Epoch [ 12]: Loss 0.02356
Epoch [ 12]: Loss 0.02290
Epoch [ 12]: Loss 0.02310
Validation: Loss 0.02257 Accuracy 1.00000
Validation: Loss 0.02104 Accuracy 1.00000
Epoch [ 13]: Loss 0.02235
Epoch [ 13]: Loss 0.02041
Epoch [ 13]: Loss 0.02204
Epoch [ 13]: Loss 0.02106
Epoch [ 13]: Loss 0.02094
Epoch [ 13]: Loss 0.01886
Epoch [ 13]: Loss 0.02162
Validation: Loss 0.01971 Accuracy 1.00000
Validation: Loss 0.01839 Accuracy 1.00000
Epoch [ 14]: Loss 0.01974
Epoch [ 14]: Loss 0.01898
Epoch [ 14]: Loss 0.01828
Epoch [ 14]: Loss 0.01841
Epoch [ 14]: Loss 0.01823
Epoch [ 14]: Loss 0.01780
Epoch [ 14]: Loss 0.01521
Validation: Loss 0.01749 Accuracy 1.00000
Validation: Loss 0.01632 Accuracy 1.00000
Epoch [ 15]: Loss 0.01649
Epoch [ 15]: Loss 0.01631
Epoch [ 15]: Loss 0.01711
Epoch [ 15]: Loss 0.01699
Epoch [ 15]: Loss 0.01602
Epoch [ 15]: Loss 0.01552
Epoch [ 15]: Loss 0.01727
Validation: Loss 0.01573 Accuracy 1.00000
Validation: Loss 0.01468 Accuracy 1.00000
Epoch [ 16]: Loss 0.01557
Epoch [ 16]: Loss 0.01564
Epoch [ 16]: Loss 0.01498
Epoch [ 16]: Loss 0.01444
Epoch [ 16]: Loss 0.01493
Epoch [ 16]: Loss 0.01363
Epoch [ 16]: Loss 0.01454
Validation: Loss 0.01428 Accuracy 1.00000
Validation: Loss 0.01333 Accuracy 1.00000
Epoch [ 17]: Loss 0.01397
Epoch [ 17]: Loss 0.01433
Epoch [ 17]: Loss 0.01279
Epoch [ 17]: Loss 0.01386
Epoch [ 17]: Loss 0.01287
Epoch [ 17]: Loss 0.01318
Epoch [ 17]: Loss 0.01397
Validation: Loss 0.01307 Accuracy 1.00000
Validation: Loss 0.01219 Accuracy 1.00000
Epoch [ 18]: Loss 0.01268
Epoch [ 18]: Loss 0.01247
Epoch [ 18]: Loss 0.01323
Epoch [ 18]: Loss 0.01210
Epoch [ 18]: Loss 0.01172
Epoch [ 18]: Loss 0.01225
Epoch [ 18]: Loss 0.01216
Validation: Loss 0.01203 Accuracy 1.00000
Validation: Loss 0.01122 Accuracy 1.00000
Epoch [ 19]: Loss 0.01192
Epoch [ 19]: Loss 0.01234
Epoch [ 19]: Loss 0.01078
Epoch [ 19]: Loss 0.01153
Epoch [ 19]: Loss 0.01106
Epoch [ 19]: Loss 0.01147
Epoch [ 19]: Loss 0.00963
Validation: Loss 0.01114 Accuracy 1.00000
Validation: Loss 0.01038 Accuracy 1.00000
Epoch [ 20]: Loss 0.01134
Epoch [ 20]: Loss 0.01135
Epoch [ 20]: Loss 0.01027
Epoch [ 20]: Loss 0.01033
Epoch [ 20]: Loss 0.00984
Epoch [ 20]: Loss 0.01070
Epoch [ 20]: Loss 0.00998
Validation: Loss 0.01036 Accuracy 1.00000
Validation: Loss 0.00965 Accuracy 1.00000
Epoch [ 21]: Loss 0.01019
Epoch [ 21]: Loss 0.01034
Epoch [ 21]: Loss 0.01011
Epoch [ 21]: Loss 0.00926
Epoch [ 21]: Loss 0.00977
Epoch [ 21]: Loss 0.00973
Epoch [ 21]: Loss 0.00943
Validation: Loss 0.00968 Accuracy 1.00000
Validation: Loss 0.00901 Accuracy 1.00000
Epoch [ 22]: Loss 0.00969
Epoch [ 22]: Loss 0.00918
Epoch [ 22]: Loss 0.00927
Epoch [ 22]: Loss 0.00923
Epoch [ 22]: Loss 0.00942
Epoch [ 22]: Loss 0.00896
Epoch [ 22]: Loss 0.00799
Validation: Loss 0.00907 Accuracy 1.00000
Validation: Loss 0.00844 Accuracy 1.00000
Epoch [ 23]: Loss 0.00959
Epoch [ 23]: Loss 0.00914
Epoch [ 23]: Loss 0.00806
Epoch [ 23]: Loss 0.00846
Epoch [ 23]: Loss 0.00839
Epoch [ 23]: Loss 0.00888
Epoch [ 23]: Loss 0.00664
Validation: Loss 0.00852 Accuracy 1.00000
Validation: Loss 0.00792 Accuracy 1.00000
Epoch [ 24]: Loss 0.00857
Epoch [ 24]: Loss 0.00856
Epoch [ 24]: Loss 0.00768
Epoch [ 24]: Loss 0.00808
Epoch [ 24]: Loss 0.00834
Epoch [ 24]: Loss 0.00793
Epoch [ 24]: Loss 0.00725
Validation: Loss 0.00804 Accuracy 1.00000
Validation: Loss 0.00747 Accuracy 1.00000
Epoch [ 25]: Loss 0.00778
Epoch [ 25]: Loss 0.00757
Epoch [ 25]: Loss 0.00814
Epoch [ 25]: Loss 0.00787
Epoch [ 25]: Loss 0.00735
Epoch [ 25]: Loss 0.00745
Epoch [ 25]: Loss 0.00770
Validation: Loss 0.00760 Accuracy 1.00000
Validation: Loss 0.00705 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
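Since only the parameters and states were saved, we rebuild the model struct and can run inference with the loaded values (a minimal sketch with a hypothetical input batch):

model = SpiralClassifier(2, 8, 1)
x_sample = randn(Float32, 2, 50, 4)  # hypothetical batch of 4 sequences
ŷ, _ = model(x_sample, ps_trained, Lux.testmode(st_trained))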
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxDeviceUtils)
    if @isdefined(CUDA) && LuxDeviceUtils.functional(LuxCUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && LuxDeviceUtils.functional(LuxAMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 4 default, 0 interactive, 2 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 4
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.484 GiB / 4.750 GiB available)
This page was generated using Literate.jl.