Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a single binary value indicating whether the sequence is clockwise or anticlockwise.
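To make that layout concrete, here is a quick sanity check (a sketch; it uses the get_dataloaders function defined next and its default batch size of 128):
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one binary label per sequence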
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
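Had we not subtyped AbstractLuxContainerLayer, we would need to write that boilerplate ourselves. A rough sketch of what it would look like for the SpiralClassifier defined below (shown only for comparison; the container layer makes it unnecessary):
# Manual parameter/state initialization that AbstractLuxContainerLayer generates for us.
Lux.initialparameters(rng::AbstractRNG, s::SpiralClassifier) = (
    lstm_cell=Lux.initialparameters(rng, s.lstm_cell),
    classifier=Lux.initialparameters(rng, s.classifier))
Lux.initialstates(rng::AbstractRNG, s::SpiralClassifier) = (
    lstm_cell=Lux.initialstates(rng, s.lstm_cell),
    classifier=Lux.initialstates(rng, s.classifier))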
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch; instead, we will use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model manually, but we will write it out here for illustration.
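A sketch of that stock-layer alternative (the EquivalentClassifier name is ours, and it assumes Recurrence's default behavior of consuming features × time × batch input and returning only the last time step's output):
EquivalentClassifier(in_dims, hidden_dims, out_dims) = Chain(
    Recurrence(LSTMCell(in_dims => hidden_dims)),  # run the cell over the whole sequence, keep the final output
    Dense(hidden_dims => out_dims, sigmoid))       # same classifier head as the manual model
# Its output has size out_dims × batch_size; `vec` it to compare with the manual model below.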
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to LSTM Cell will create the initial hidden state.
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
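For reference, the numerically stabler setup would drop the sigmoid from the classifier head and evaluate the loss on raw logits. A sketch (the SpiralClassifierLogits name is ours, and it assumes BinaryCrossEntropyLoss accepts a logits option in your Lux version):
function SpiralClassifierLogits(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims))  # no sigmoid here
end
const lossfn_logits = BinaryCrossEntropyLoss(; logits=Val(true))  # loss computed on raw logits
# With this pairing, apply `sigmoid` to the model output (or threshold the logits at 0) before computing accuracy.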
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
Training the Model
function main(model_type)
    dev = gpu_device()
    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev
    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end
        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end
    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.63363
Epoch [ 1]: Loss 0.58107
Epoch [ 1]: Loss 0.56740
Epoch [ 1]: Loss 0.54022
Epoch [ 1]: Loss 0.51598
Epoch [ 1]: Loss 0.50952
Epoch [ 1]: Loss 0.47641
Validation: Loss 0.46745 Accuracy 1.00000
Validation: Loss 0.47185 Accuracy 1.00000
Epoch [ 2]: Loss 0.46529
Epoch [ 2]: Loss 0.46127
Epoch [ 2]: Loss 0.45251
Epoch [ 2]: Loss 0.42456
Epoch [ 2]: Loss 0.40951
Epoch [ 2]: Loss 0.39444
Epoch [ 2]: Loss 0.37673
Validation: Loss 0.37011 Accuracy 1.00000
Validation: Loss 0.37545 Accuracy 1.00000
Epoch [ 3]: Loss 0.36862
Epoch [ 3]: Loss 0.36597
Epoch [ 3]: Loss 0.34827
Epoch [ 3]: Loss 0.34519
Epoch [ 3]: Loss 0.31592
Epoch [ 3]: Loss 0.30139
Epoch [ 3]: Loss 0.31022
Validation: Loss 0.28488 Accuracy 1.00000
Validation: Loss 0.29073 Accuracy 1.00000
Epoch [ 4]: Loss 0.28752
Epoch [ 4]: Loss 0.26677
Epoch [ 4]: Loss 0.27206
Epoch [ 4]: Loss 0.25701
Epoch [ 4]: Loss 0.25002
Epoch [ 4]: Loss 0.23381
Epoch [ 4]: Loss 0.23789
Validation: Loss 0.21455 Accuracy 1.00000
Validation: Loss 0.22029 Accuracy 1.00000
Epoch [ 5]: Loss 0.21389
Epoch [ 5]: Loss 0.20979
Epoch [ 5]: Loss 0.19148
Epoch [ 5]: Loss 0.19602
Epoch [ 5]: Loss 0.19112
Epoch [ 5]: Loss 0.17698
Epoch [ 5]: Loss 0.16302
Validation: Loss 0.15852 Accuracy 1.00000
Validation: Loss 0.16358 Accuracy 1.00000
Epoch [ 6]: Loss 0.16230
Epoch [ 6]: Loss 0.14705
Epoch [ 6]: Loss 0.14988
Epoch [ 6]: Loss 0.13514
Epoch [ 6]: Loss 0.14176
Epoch [ 6]: Loss 0.13110
Epoch [ 6]: Loss 0.12789
Validation: Loss 0.11559 Accuracy 1.00000
Validation: Loss 0.11961 Accuracy 1.00000
Epoch [ 7]: Loss 0.11763
Epoch [ 7]: Loss 0.10926
Epoch [ 7]: Loss 0.10829
Epoch [ 7]: Loss 0.10551
Epoch [ 7]: Loss 0.09851
Epoch [ 7]: Loss 0.08757
Epoch [ 7]: Loss 0.10091
Validation: Loss 0.08221 Accuracy 1.00000
Validation: Loss 0.08509 Accuracy 1.00000
Epoch [ 8]: Loss 0.09275
Epoch [ 8]: Loss 0.08028
Epoch [ 8]: Loss 0.07877
Epoch [ 8]: Loss 0.07199
Epoch [ 8]: Loss 0.06543
Epoch [ 8]: Loss 0.05869
Epoch [ 8]: Loss 0.05268
Validation: Loss 0.05708 Accuracy 1.00000
Validation: Loss 0.05897 Accuracy 1.00000
Epoch [ 9]: Loss 0.05457
Epoch [ 9]: Loss 0.05363
Epoch [ 9]: Loss 0.04984
Epoch [ 9]: Loss 0.05663
Epoch [ 9]: Loss 0.04950
Epoch [ 9]: Loss 0.04771
Epoch [ 9]: Loss 0.04671
Validation: Loss 0.04292 Accuracy 1.00000
Validation: Loss 0.04429 Accuracy 1.00000
Epoch [ 10]: Loss 0.04077
Epoch [ 10]: Loss 0.04255
Epoch [ 10]: Loss 0.04058
Epoch [ 10]: Loss 0.03803
Epoch [ 10]: Loss 0.03853
Epoch [ 10]: Loss 0.04149
Epoch [ 10]: Loss 0.03777
Validation: Loss 0.03493 Accuracy 1.00000
Validation: Loss 0.03606 Accuracy 1.00000
Epoch [ 11]: Loss 0.03434
Epoch [ 11]: Loss 0.03377
Epoch [ 11]: Loss 0.03296
Epoch [ 11]: Loss 0.03385
Epoch [ 11]: Loss 0.03444
Epoch [ 11]: Loss 0.03208
Epoch [ 11]: Loss 0.02881
Validation: Loss 0.02969 Accuracy 1.00000
Validation: Loss 0.03068 Accuracy 1.00000
Epoch [ 12]: Loss 0.03047
Epoch [ 12]: Loss 0.03008
Epoch [ 12]: Loss 0.02824
Epoch [ 12]: Loss 0.02843
Epoch [ 12]: Loss 0.02883
Epoch [ 12]: Loss 0.02710
Epoch [ 12]: Loss 0.02491
Validation: Loss 0.02586 Accuracy 1.00000
Validation: Loss 0.02674 Accuracy 1.00000
Epoch [ 13]: Loss 0.02587
Epoch [ 13]: Loss 0.02448
Epoch [ 13]: Loss 0.02584
Epoch [ 13]: Loss 0.02408
Epoch [ 13]: Loss 0.02560
Epoch [ 13]: Loss 0.02540
Epoch [ 13]: Loss 0.02372
Validation: Loss 0.02290 Accuracy 1.00000
Validation: Loss 0.02369 Accuracy 1.00000
Epoch [ 14]: Loss 0.02355
Epoch [ 14]: Loss 0.02294
Epoch [ 14]: Loss 0.02439
Epoch [ 14]: Loss 0.02041
Epoch [ 14]: Loss 0.02226
Epoch [ 14]: Loss 0.02133
Epoch [ 14]: Loss 0.02041
Validation: Loss 0.02050 Accuracy 1.00000
Validation: Loss 0.02123 Accuracy 1.00000
Epoch [ 15]: Loss 0.02183
Epoch [ 15]: Loss 0.02021
Epoch [ 15]: Loss 0.02045
Epoch [ 15]: Loss 0.01869
Epoch [ 15]: Loss 0.01932
Epoch [ 15]: Loss 0.01994
Epoch [ 15]: Loss 0.02126
Validation: Loss 0.01851 Accuracy 1.00000
Validation: Loss 0.01918 Accuracy 1.00000
Epoch [ 16]: Loss 0.01871
Epoch [ 16]: Loss 0.01825
Epoch [ 16]: Loss 0.01783
Epoch [ 16]: Loss 0.01827
Epoch [ 16]: Loss 0.01952
Epoch [ 16]: Loss 0.01644
Epoch [ 16]: Loss 0.01935
Validation: Loss 0.01682 Accuracy 1.00000
Validation: Loss 0.01744 Accuracy 1.00000
Epoch [ 17]: Loss 0.01635
Epoch [ 17]: Loss 0.01644
Epoch [ 17]: Loss 0.01703
Epoch [ 17]: Loss 0.01601
Epoch [ 17]: Loss 0.01643
Epoch [ 17]: Loss 0.01731
Epoch [ 17]: Loss 0.01647
Validation: Loss 0.01538 Accuracy 1.00000
Validation: Loss 0.01595 Accuracy 1.00000
Epoch [ 18]: Loss 0.01629
Epoch [ 18]: Loss 0.01398
Epoch [ 18]: Loss 0.01678
Epoch [ 18]: Loss 0.01599
Epoch [ 18]: Loss 0.01377
Epoch [ 18]: Loss 0.01432
Epoch [ 18]: Loss 0.01569
Validation: Loss 0.01413 Accuracy 1.00000
Validation: Loss 0.01466 Accuracy 1.00000
Epoch [ 19]: Loss 0.01679
Epoch [ 19]: Loss 0.01354
Epoch [ 19]: Loss 0.01367
Epoch [ 19]: Loss 0.01282
Epoch [ 19]: Loss 0.01308
Epoch [ 19]: Loss 0.01442
Epoch [ 19]: Loss 0.01291
Validation: Loss 0.01304 Accuracy 1.00000
Validation: Loss 0.01354 Accuracy 1.00000
Epoch [ 20]: Loss 0.01317
Epoch [ 20]: Loss 0.01342
Epoch [ 20]: Loss 0.01212
Epoch [ 20]: Loss 0.01389
Epoch [ 20]: Loss 0.01248
Epoch [ 20]: Loss 0.01274
Epoch [ 20]: Loss 0.01246
Validation: Loss 0.01208 Accuracy 1.00000
Validation: Loss 0.01254 Accuracy 1.00000
Epoch [ 21]: Loss 0.01304
Epoch [ 21]: Loss 0.01256
Epoch [ 21]: Loss 0.01182
Epoch [ 21]: Loss 0.01122
Epoch [ 21]: Loss 0.01133
Epoch [ 21]: Loss 0.01236
Epoch [ 21]: Loss 0.01061
Validation: Loss 0.01118 Accuracy 1.00000
Validation: Loss 0.01160 Accuracy 1.00000
Epoch [ 22]: Loss 0.01230
Epoch [ 22]: Loss 0.01060
Epoch [ 22]: Loss 0.01113
Epoch [ 22]: Loss 0.00969
Epoch [ 22]: Loss 0.01174
Epoch [ 22]: Loss 0.01126
Epoch [ 22]: Loss 0.00975
Validation: Loss 0.01024 Accuracy 1.00000
Validation: Loss 0.01062 Accuracy 1.00000
Epoch [ 23]: Loss 0.01086
Epoch [ 23]: Loss 0.01038
Epoch [ 23]: Loss 0.00994
Epoch [ 23]: Loss 0.01040
Epoch [ 23]: Loss 0.00955
Epoch [ 23]: Loss 0.00970
Epoch [ 23]: Loss 0.00844
Validation: Loss 0.00919 Accuracy 1.00000
Validation: Loss 0.00952 Accuracy 1.00000
Epoch [ 24]: Loss 0.00972
Epoch [ 24]: Loss 0.00930
Epoch [ 24]: Loss 0.00914
Epoch [ 24]: Loss 0.00864
Epoch [ 24]: Loss 0.00850
Epoch [ 24]: Loss 0.00908
Epoch [ 24]: Loss 0.00689
Validation: Loss 0.00817 Accuracy 1.00000
Validation: Loss 0.00845 Accuracy 1.00000
Epoch [ 25]: Loss 0.00788
Epoch [ 25]: Loss 0.00835
Epoch [ 25]: Loss 0.00837
Epoch [ 25]: Loss 0.00780
Epoch [ 25]: Loss 0.00798
Epoch [ 25]: Loss 0.00822
Epoch [ 25]: Loss 0.00598
Validation: Loss 0.00740 Accuracy 1.00000
Validation: Loss 0.00766 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.60294
Epoch [ 1]: Loss 0.58756
Epoch [ 1]: Loss 0.57470
Epoch [ 1]: Loss 0.54895
Epoch [ 1]: Loss 0.52199
Epoch [ 1]: Loss 0.49149
Epoch [ 1]: Loss 0.48506
Validation: Loss 0.48276 Accuracy 1.00000
Validation: Loss 0.47895 Accuracy 1.00000
Epoch [ 2]: Loss 0.47011
Epoch [ 2]: Loss 0.45964
Epoch [ 2]: Loss 0.43511
Epoch [ 2]: Loss 0.42398
Epoch [ 2]: Loss 0.41160
Epoch [ 2]: Loss 0.38680
Epoch [ 2]: Loss 0.40415
Validation: Loss 0.38807 Accuracy 1.00000
Validation: Loss 0.38384 Accuracy 1.00000
Epoch [ 3]: Loss 0.36732
Epoch [ 3]: Loss 0.34996
Epoch [ 3]: Loss 0.34527
Epoch [ 3]: Loss 0.33681
Epoch [ 3]: Loss 0.31621
Epoch [ 3]: Loss 0.32185
Epoch [ 3]: Loss 0.30775
Validation: Loss 0.30521 Accuracy 1.00000
Validation: Loss 0.30071 Accuracy 1.00000
Epoch [ 4]: Loss 0.29681
Epoch [ 4]: Loss 0.27822
Epoch [ 4]: Loss 0.26056
Epoch [ 4]: Loss 0.25291
Epoch [ 4]: Loss 0.23603
Epoch [ 4]: Loss 0.24225
Epoch [ 4]: Loss 0.22994
Validation: Loss 0.23551 Accuracy 1.00000
Validation: Loss 0.23110 Accuracy 1.00000
Epoch [ 5]: Loss 0.21795
Epoch [ 5]: Loss 0.21422
Epoch [ 5]: Loss 0.20792
Epoch [ 5]: Loss 0.19499
Epoch [ 5]: Loss 0.17770
Epoch [ 5]: Loss 0.17183
Epoch [ 5]: Loss 0.16068
Validation: Loss 0.17835 Accuracy 1.00000
Validation: Loss 0.17435 Accuracy 1.00000
Epoch [ 6]: Loss 0.16007
Epoch [ 6]: Loss 0.15957
Epoch [ 6]: Loss 0.14442
Epoch [ 6]: Loss 0.14187
Epoch [ 6]: Loss 0.13934
Epoch [ 6]: Loss 0.12529
Epoch [ 6]: Loss 0.14995
Validation: Loss 0.13267 Accuracy 1.00000
Validation: Loss 0.12941 Accuracy 1.00000
Epoch [ 7]: Loss 0.11240
Epoch [ 7]: Loss 0.10309
Epoch [ 7]: Loss 0.10744
Epoch [ 7]: Loss 0.11145
Epoch [ 7]: Loss 0.10243
Epoch [ 7]: Loss 0.10379
Epoch [ 7]: Loss 0.08507
Validation: Loss 0.09567 Accuracy 1.00000
Validation: Loss 0.09324 Accuracy 1.00000
Epoch [ 8]: Loss 0.08379
Epoch [ 8]: Loss 0.08044
Epoch [ 8]: Loss 0.08166
Epoch [ 8]: Loss 0.07579
Epoch [ 8]: Loss 0.06885
Epoch [ 8]: Loss 0.06613
Epoch [ 8]: Loss 0.05985
Validation: Loss 0.06640 Accuracy 1.00000
Validation: Loss 0.06474 Accuracy 1.00000
Epoch [ 9]: Loss 0.06068
Epoch [ 9]: Loss 0.05652
Epoch [ 9]: Loss 0.05485
Epoch [ 9]: Loss 0.05233
Epoch [ 9]: Loss 0.04971
Epoch [ 9]: Loss 0.04684
Epoch [ 9]: Loss 0.04336
Validation: Loss 0.04867 Accuracy 1.00000
Validation: Loss 0.04754 Accuracy 1.00000
Epoch [ 10]: Loss 0.04424
Epoch [ 10]: Loss 0.04479
Epoch [ 10]: Loss 0.03874
Epoch [ 10]: Loss 0.03961
Epoch [ 10]: Loss 0.04071
Epoch [ 10]: Loss 0.03695
Epoch [ 10]: Loss 0.03219
Validation: Loss 0.03922 Accuracy 1.00000
Validation: Loss 0.03833 Accuracy 1.00000
Epoch [ 11]: Loss 0.03558
Epoch [ 11]: Loss 0.03376
Epoch [ 11]: Loss 0.03523
Epoch [ 11]: Loss 0.03155
Epoch [ 11]: Loss 0.03110
Epoch [ 11]: Loss 0.03231
Epoch [ 11]: Loss 0.03356
Validation: Loss 0.03330 Accuracy 1.00000
Validation: Loss 0.03253 Accuracy 1.00000
Epoch [ 12]: Loss 0.03067
Epoch [ 12]: Loss 0.02975
Epoch [ 12]: Loss 0.03041
Epoch [ 12]: Loss 0.02751
Epoch [ 12]: Loss 0.02625
Epoch [ 12]: Loss 0.02595
Epoch [ 12]: Loss 0.03005
Validation: Loss 0.02897 Accuracy 1.00000
Validation: Loss 0.02829 Accuracy 1.00000
Epoch [ 13]: Loss 0.02238
Epoch [ 13]: Loss 0.02684
Epoch [ 13]: Loss 0.02452
Epoch [ 13]: Loss 0.02701
Epoch [ 13]: Loss 0.02408
Epoch [ 13]: Loss 0.02390
Epoch [ 13]: Loss 0.02650
Validation: Loss 0.02564 Accuracy 1.00000
Validation: Loss 0.02502 Accuracy 1.00000
Epoch [ 14]: Loss 0.02371
Epoch [ 14]: Loss 0.02208
Epoch [ 14]: Loss 0.02316
Epoch [ 14]: Loss 0.02074
Epoch [ 14]: Loss 0.02283
Epoch [ 14]: Loss 0.02029
Epoch [ 14]: Loss 0.02055
Validation: Loss 0.02294 Accuracy 1.00000
Validation: Loss 0.02238 Accuracy 1.00000
Epoch [ 15]: Loss 0.02121
Epoch [ 15]: Loss 0.02013
Epoch [ 15]: Loss 0.02057
Epoch [ 15]: Loss 0.02009
Epoch [ 15]: Loss 0.01830
Epoch [ 15]: Loss 0.01877
Epoch [ 15]: Loss 0.01847
Validation: Loss 0.02073 Accuracy 1.00000
Validation: Loss 0.02021 Accuracy 1.00000
Epoch [ 16]: Loss 0.01681
Epoch [ 16]: Loss 0.01782
Epoch [ 16]: Loss 0.01825
Epoch [ 16]: Loss 0.01853
Epoch [ 16]: Loss 0.01726
Epoch [ 16]: Loss 0.01804
Epoch [ 16]: Loss 0.02012
Validation: Loss 0.01888 Accuracy 1.00000
Validation: Loss 0.01840 Accuracy 1.00000
Epoch [ 17]: Loss 0.01707
Epoch [ 17]: Loss 0.01651
Epoch [ 17]: Loss 0.01774
Epoch [ 17]: Loss 0.01463
Epoch [ 17]: Loss 0.01460
Epoch [ 17]: Loss 0.01724
Epoch [ 17]: Loss 0.01593
Validation: Loss 0.01728 Accuracy 1.00000
Validation: Loss 0.01683 Accuracy 1.00000
Epoch [ 18]: Loss 0.01529
Epoch [ 18]: Loss 0.01508
Epoch [ 18]: Loss 0.01551
Epoch [ 18]: Loss 0.01548
Epoch [ 18]: Loss 0.01382
Epoch [ 18]: Loss 0.01480
Epoch [ 18]: Loss 0.01289
Validation: Loss 0.01590 Accuracy 1.00000
Validation: Loss 0.01549 Accuracy 1.00000
Epoch [ 19]: Loss 0.01389
Epoch [ 19]: Loss 0.01402
Epoch [ 19]: Loss 0.01486
Epoch [ 19]: Loss 0.01357
Epoch [ 19]: Loss 0.01424
Epoch [ 19]: Loss 0.01162
Epoch [ 19]: Loss 0.01461
Validation: Loss 0.01471 Accuracy 1.00000
Validation: Loss 0.01432 Accuracy 1.00000
Epoch [ 20]: Loss 0.01377
Epoch [ 20]: Loss 0.01363
Epoch [ 20]: Loss 0.01268
Epoch [ 20]: Loss 0.01151
Epoch [ 20]: Loss 0.01216
Epoch [ 20]: Loss 0.01291
Epoch [ 20]: Loss 0.01101
Validation: Loss 0.01363 Accuracy 1.00000
Validation: Loss 0.01327 Accuracy 1.00000
Epoch [ 21]: Loss 0.01165
Epoch [ 21]: Loss 0.01349
Epoch [ 21]: Loss 0.01187
Epoch [ 21]: Loss 0.01127
Epoch [ 21]: Loss 0.01174
Epoch [ 21]: Loss 0.01055
Epoch [ 21]: Loss 0.01207
Validation: Loss 0.01263 Accuracy 1.00000
Validation: Loss 0.01229 Accuracy 1.00000
Epoch [ 22]: Loss 0.01091
Epoch [ 22]: Loss 0.01088
Epoch [ 22]: Loss 0.01154
Epoch [ 22]: Loss 0.00984
Epoch [ 22]: Loss 0.01073
Epoch [ 22]: Loss 0.01092
Epoch [ 22]: Loss 0.01233
Validation: Loss 0.01158 Accuracy 1.00000
Validation: Loss 0.01127 Accuracy 1.00000
Epoch [ 23]: Loss 0.01030
Epoch [ 23]: Loss 0.00935
Epoch [ 23]: Loss 0.00922
Epoch [ 23]: Loss 0.01168
Epoch [ 23]: Loss 0.00993
Epoch [ 23]: Loss 0.00929
Epoch [ 23]: Loss 0.00844
Validation: Loss 0.01037 Accuracy 1.00000
Validation: Loss 0.01009 Accuracy 1.00000
Epoch [ 24]: Loss 0.00898
Epoch [ 24]: Loss 0.00892
Epoch [ 24]: Loss 0.00931
Epoch [ 24]: Loss 0.00795
Epoch [ 24]: Loss 0.00940
Epoch [ 24]: Loss 0.00882
Epoch [ 24]: Loss 0.00722
Validation: Loss 0.00916 Accuracy 1.00000
Validation: Loss 0.00893 Accuracy 1.00000
Epoch [ 25]: Loss 0.00806
Epoch [ 25]: Loss 0.00746
Epoch [ 25]: Loss 0.00796
Epoch [ 25]: Loss 0.00845
Epoch [ 25]: Loss 0.00788
Epoch [ 25]: Loss 0.00741
Epoch [ 25]: Loss 0.00793
Validation: Loss 0.00824 Accuracy 1.00000
Validation: Loss 0.00804 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself and only save the parameters and states.
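Since the model struct is rebuilt from code, the saved parameters and states are all that is needed for inference. A minimal sketch (the dummy input and the 0.5 threshold below are illustrative):
model = SpiralClassifier(2, 8, 1)     # reconstruct the architecture from code
st_infer = Lux.testmode(st_trained)   # switch the states to inference mode
x = randn(Float32, 2, 50, 1)          # a single dummy sequence: features × time × batch
ŷ, _ = model(x, ps_trained, st_infer)
ŷ[1] > 0.5f0                          # predicted class for that sequence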
@save "trained_model.jld2" ps_trained st_trainedLet's try loading the model
@load "trained_model.jld2" ps_trained st_trained2-element Vector{Symbol}:
:ps_trained
:st_trainedAppendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)
This page was generated using Literate.jl.