Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary label indicating whether it is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
# Create the spirals
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
for d in data[1:(dataset_size ÷ 2)]]
anticlockwise_spirals = [reshape(
d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
for d in data[((dataset_size ÷ 2) + 1):end]]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
# Create DataLoaders
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
# Don't shuffle the validation data
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

get_dataloaders (generic function with 1 method)
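Before building the model, it can be helpful to sanity-check the shapes coming out of the loader. The following is an illustrative check (the last batch may be smaller than 128):

# With the defaults above, a full training batch is a 2 × 50 × 128 array
# of coordinates and a 128-element vector of labels.
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128)
size(y)  # (128,)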
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
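To see this automatic population in action, here is a hypothetical check (names and outputs are illustrative):

# The parameters and states are nested under the declared fieldnames.
rng = Xoshiro(0)
ps, st = Lux.setup(rng, SpiralClassifier(LSTMCell(2 => 8), Dense(8 => 1, sigmoid)))
keys(ps)  # (:lstm_cell, :classifier)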
We won't define the model from scratch but rather use Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
return SpiralClassifier(
LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end

Main.var"##225".SpiralClassifier

We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass ourselves, but let's still write it out for the sake of the tutorial.
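For reference, the built-in alternative would look roughly like the sketch below (Recurrence by default returns only the output at the last time step, so the final output would still need a vec):

# Sketch only: the stock-blocks equivalent of the custom layer defined below.
model = Chain(
    Recurrence(LSTMCell(2 => 8)),   # hidden state at the last time step
    Dense(8 => 1, sigmoid))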
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
# First we will have to run the sequence through the LSTM Cell
# The first call to LSTM Cell will create the initial hidden state
# See that the parameters and states are automatically populated into a field called
# `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
# and `Iterators.peel` to split out the first element for LSTM initialization.
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
# Now that we have the hidden state and memory in `carry` we will pass the input and
# `carry` jointly
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
# After running through the sequence we will pass the output through the classifier
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
# Finally remember to create the updated state
st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
return vec(y), st
end
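Before moving on, here is an illustrative forward pass through the layer we just defined (dummy data; shapes follow the dataset section):

# Run a dummy batch of 3 sequences through the classifier.
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Float32, 2, 50, 3)  # (features, seq_len, batch)
ŷ, st_ = model(x, ps, st)
size(ŷ)  # (3,) one probability per sequence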
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, so we recommend this way of defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
@return vec(classifier(y))
end
end

SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
ŷ, st_ = model(x, ps, st)
loss = lossfn(ŷ, y)
return loss, st_, (; y_pred=ŷ)
end
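As an aside, the more stable logit formulation mentioned above would be a small change. A sketch, assuming the logits keyword of BinaryCrossEntropyLoss: drop sigmoid from the Dense head and let the loss apply it internally.

# Sketch: logit-space loss; the classifier head would then be
# Dense(hidden_dims => out_dims) without the sigmoid.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))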
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

accuracy (generic function with 1 method)

Training the Model
function main(model_type)
dev = gpu_device()
# Get the dataloaders
train_loader, val_loader = get_dataloaders() .|> dev
# Create the model
model = model_type(2, 8, 1)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) |> dev
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
for epoch in 1:25
# Train the model
for (x, y) in train_loader
(_, loss, _, train_state) = Training.single_train_step!(
AutoZygote(), lossfn, (x, y), train_state)
@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
end
# Validate the model
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
ŷ, st_ = model(x, train_state.parameters, st_)
loss = lossfn(ŷ, y)
acc = accuracy(ŷ, y)
@printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
end
end
return (train_state.parameters, train_state.states) |> cpu_device()
end
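For intuition, each call to Training.single_train_step! above roughly corresponds to the manual step sketched below (illustrative only, assuming model, ps, st and a batch (x, y) are in scope; the real implementation also threads states and statistics through the TrainState):

# Manual equivalent of one training step (not Lux's actual implementation).
opt_state = Optimisers.setup(Adam(0.01f0), ps)
gs = only(Zygote.gradient(p -> first(compute_loss(model, p, st, (x, y))), ps))
opt_state, ps = Optimisers.update!(opt_state, ps, gs)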
ps_trained, st_trained = main(SpiralClassifier)

Epoch [ 1]: Loss 0.63025
Epoch [ 1]: Loss 0.59252
Epoch [ 1]: Loss 0.57088
Epoch [ 1]: Loss 0.52932
Epoch [ 1]: Loss 0.51400
Epoch [ 1]: Loss 0.50206
Epoch [ 1]: Loss 0.48769
Validation: Loss 0.47018 Accuracy 1.00000
Validation: Loss 0.47064 Accuracy 1.00000
Epoch [ 2]: Loss 0.46553
Epoch [ 2]: Loss 0.45533
Epoch [ 2]: Loss 0.43983
Epoch [ 2]: Loss 0.42447
Epoch [ 2]: Loss 0.41724
Epoch [ 2]: Loss 0.39576
Epoch [ 2]: Loss 0.36962
Validation: Loss 0.37282 Accuracy 1.00000
Validation: Loss 0.37357 Accuracy 1.00000
Epoch [ 3]: Loss 0.37650
Epoch [ 3]: Loss 0.37466
Epoch [ 3]: Loss 0.34152
Epoch [ 3]: Loss 0.34363
Epoch [ 3]: Loss 0.30602
Epoch [ 3]: Loss 0.29191
Epoch [ 3]: Loss 0.29158
Validation: Loss 0.28819 Accuracy 1.00000
Validation: Loss 0.28900 Accuracy 1.00000
Epoch [ 4]: Loss 0.27525
Epoch [ 4]: Loss 0.27721
Epoch [ 4]: Loss 0.26404
Epoch [ 4]: Loss 0.25525
Epoch [ 4]: Loss 0.24492
Epoch [ 4]: Loss 0.23775
Epoch [ 4]: Loss 0.24253
Validation: Loss 0.21888 Accuracy 1.00000
Validation: Loss 0.21954 Accuracy 1.00000
Epoch [ 5]: Loss 0.20728
Epoch [ 5]: Loss 0.20676
Epoch [ 5]: Loss 0.20946
Epoch [ 5]: Loss 0.20387
Epoch [ 5]: Loss 0.18062
Epoch [ 5]: Loss 0.17109
Epoch [ 5]: Loss 0.14328
Validation: Loss 0.16350 Accuracy 1.00000
Validation: Loss 0.16388 Accuracy 1.00000
Epoch [ 6]: Loss 0.16040
Epoch [ 6]: Loss 0.16014
Epoch [ 6]: Loss 0.14114
Epoch [ 6]: Loss 0.13844
Epoch [ 6]: Loss 0.13959
Epoch [ 6]: Loss 0.12868
Epoch [ 6]: Loss 0.12854
Validation: Loss 0.12064 Accuracy 1.00000
Validation: Loss 0.12080 Accuracy 1.00000
Epoch [ 7]: Loss 0.11828
Epoch [ 7]: Loss 0.11069
Epoch [ 7]: Loss 0.11143
Epoch [ 7]: Loss 0.10270
Epoch [ 7]: Loss 0.09966
Epoch [ 7]: Loss 0.09235
Epoch [ 7]: Loss 0.08941
Validation: Loss 0.08659 Accuracy 1.00000
Validation: Loss 0.08664 Accuracy 1.00000
Epoch [ 8]: Loss 0.08308
Epoch [ 8]: Loss 0.08126
Epoch [ 8]: Loss 0.07705
Epoch [ 8]: Loss 0.07120
Epoch [ 8]: Loss 0.06869
Epoch [ 8]: Loss 0.06998
Epoch [ 8]: Loss 0.05573
Validation: Loss 0.06021 Accuracy 1.00000
Validation: Loss 0.06021 Accuracy 1.00000
Epoch [ 9]: Loss 0.05737
Epoch [ 9]: Loss 0.05734
Epoch [ 9]: Loss 0.05609
Epoch [ 9]: Loss 0.05445
Epoch [ 9]: Loss 0.04675
Epoch [ 9]: Loss 0.04521
Epoch [ 9]: Loss 0.04244
Validation: Loss 0.04441 Accuracy 1.00000
Validation: Loss 0.04440 Accuracy 1.00000
Epoch [ 10]: Loss 0.04378
Epoch [ 10]: Loss 0.03780
Epoch [ 10]: Loss 0.04223
Epoch [ 10]: Loss 0.04091
Epoch [ 10]: Loss 0.03903
Epoch [ 10]: Loss 0.03821
Epoch [ 10]: Loss 0.03259
Validation: Loss 0.03584 Accuracy 1.00000
Validation: Loss 0.03582 Accuracy 1.00000
Epoch [ 11]: Loss 0.03564
Epoch [ 11]: Loss 0.03466
Epoch [ 11]: Loss 0.03382
Epoch [ 11]: Loss 0.03248
Epoch [ 11]: Loss 0.03093
Epoch [ 11]: Loss 0.03074
Epoch [ 11]: Loss 0.03161
Validation: Loss 0.03040 Accuracy 1.00000
Validation: Loss 0.03038 Accuracy 1.00000
Epoch [ 12]: Loss 0.02863
Epoch [ 12]: Loss 0.03059
Epoch [ 12]: Loss 0.02746
Epoch [ 12]: Loss 0.02981
Epoch [ 12]: Loss 0.02637
Epoch [ 12]: Loss 0.02821
Epoch [ 12]: Loss 0.02189
Validation: Loss 0.02644 Accuracy 1.00000
Validation: Loss 0.02642 Accuracy 1.00000
Epoch [ 13]: Loss 0.02332
Epoch [ 13]: Loss 0.02555
Epoch [ 13]: Loss 0.02465
Epoch [ 13]: Loss 0.02553
Epoch [ 13]: Loss 0.02551
Epoch [ 13]: Loss 0.02360
Epoch [ 13]: Loss 0.02509
Validation: Loss 0.02341 Accuracy 1.00000
Validation: Loss 0.02338 Accuracy 1.00000
Epoch [ 14]: Loss 0.02241
Epoch [ 14]: Loss 0.02304
Epoch [ 14]: Loss 0.02248
Epoch [ 14]: Loss 0.02308
Epoch [ 14]: Loss 0.02177
Epoch [ 14]: Loss 0.01971
Epoch [ 14]: Loss 0.01938
Validation: Loss 0.02094 Accuracy 1.00000
Validation: Loss 0.02092 Accuracy 1.00000
Epoch [ 15]: Loss 0.02039
Epoch [ 15]: Loss 0.02046
Epoch [ 15]: Loss 0.02162
Epoch [ 15]: Loss 0.01877
Epoch [ 15]: Loss 0.02046
Epoch [ 15]: Loss 0.01700
Epoch [ 15]: Loss 0.01815
Validation: Loss 0.01890 Accuracy 1.00000
Validation: Loss 0.01888 Accuracy 1.00000
Epoch [ 16]: Loss 0.01942
Epoch [ 16]: Loss 0.01780
Epoch [ 16]: Loss 0.01756
Epoch [ 16]: Loss 0.01782
Epoch [ 16]: Loss 0.01808
Epoch [ 16]: Loss 0.01651
Epoch [ 16]: Loss 0.01712
Validation: Loss 0.01718 Accuracy 1.00000
Validation: Loss 0.01717 Accuracy 1.00000
Epoch [ 17]: Loss 0.01608
Epoch [ 17]: Loss 0.01730
Epoch [ 17]: Loss 0.01582
Epoch [ 17]: Loss 0.01500
Epoch [ 17]: Loss 0.01773
Epoch [ 17]: Loss 0.01523
Epoch [ 17]: Loss 0.01745
Validation: Loss 0.01573 Accuracy 1.00000
Validation: Loss 0.01571 Accuracy 1.00000
Epoch [ 18]: Loss 0.01560
Epoch [ 18]: Loss 0.01450
Epoch [ 18]: Loss 0.01537
Epoch [ 18]: Loss 0.01561
Epoch [ 18]: Loss 0.01471
Epoch [ 18]: Loss 0.01412
Epoch [ 18]: Loss 0.01268
Validation: Loss 0.01447 Accuracy 1.00000
Validation: Loss 0.01445 Accuracy 1.00000
Epoch [ 19]: Loss 0.01393
Epoch [ 19]: Loss 0.01405
Epoch [ 19]: Loss 0.01335
Epoch [ 19]: Loss 0.01373
Epoch [ 19]: Loss 0.01399
Epoch [ 19]: Loss 0.01357
Epoch [ 19]: Loss 0.01277
Validation: Loss 0.01340 Accuracy 1.00000
Validation: Loss 0.01338 Accuracy 1.00000
Epoch [ 20]: Loss 0.01249
Epoch [ 20]: Loss 0.01340
Epoch [ 20]: Loss 0.01339
Epoch [ 20]: Loss 0.01211
Epoch [ 20]: Loss 0.01323
Epoch [ 20]: Loss 0.01197
Epoch [ 20]: Loss 0.01188
Validation: Loss 0.01245 Accuracy 1.00000
Validation: Loss 0.01243 Accuracy 1.00000
Epoch [ 21]: Loss 0.01186
Epoch [ 21]: Loss 0.01197
Epoch [ 21]: Loss 0.01257
Epoch [ 21]: Loss 0.01055
Epoch [ 21]: Loss 0.01220
Epoch [ 21]: Loss 0.01243
Epoch [ 21]: Loss 0.00968
Validation: Loss 0.01160 Accuracy 1.00000
Validation: Loss 0.01159 Accuracy 1.00000
Epoch [ 22]: Loss 0.01115
Epoch [ 22]: Loss 0.01105
Epoch [ 22]: Loss 0.01128
Epoch [ 22]: Loss 0.01072
Epoch [ 22]: Loss 0.01182
Epoch [ 22]: Loss 0.01040
Epoch [ 22]: Loss 0.01029
Validation: Loss 0.01081 Accuracy 1.00000
Validation: Loss 0.01080 Accuracy 1.00000
Epoch [ 23]: Loss 0.01030
Epoch [ 23]: Loss 0.01114
Epoch [ 23]: Loss 0.01071
Epoch [ 23]: Loss 0.01069
Epoch [ 23]: Loss 0.00957
Epoch [ 23]: Loss 0.00966
Epoch [ 23]: Loss 0.00875
Validation: Loss 0.01002 Accuracy 1.00000
Validation: Loss 0.01001 Accuracy 1.00000
Epoch [ 24]: Loss 0.01013
Epoch [ 24]: Loss 0.00904
Epoch [ 24]: Loss 0.00941
Epoch [ 24]: Loss 0.00929
Epoch [ 24]: Loss 0.00982
Epoch [ 24]: Loss 0.00862
Epoch [ 24]: Loss 0.01167
Validation: Loss 0.00914 Accuracy 1.00000
Validation: Loss 0.00913 Accuracy 1.00000
Epoch [ 25]: Loss 0.00862
Epoch [ 25]: Loss 0.00834
Epoch [ 25]: Loss 0.00848
Epoch [ 25]: Loss 0.00881
Epoch [ 25]: Loss 0.00826
Epoch [ 25]: Loss 0.00885
Epoch [ 25]: Loss 0.00872
Validation: Loss 0.00815 Accuracy 1.00000
Validation: Loss 0.00814 Accuracy 1.00000

We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)

Epoch [ 1]: Loss 0.62511
Epoch [ 1]: Loss 0.58580
Epoch [ 1]: Loss 0.57456
Epoch [ 1]: Loss 0.53250
Epoch [ 1]: Loss 0.52901
Epoch [ 1]: Loss 0.50098
Epoch [ 1]: Loss 0.48846
Validation: Loss 0.47042 Accuracy 1.00000
Validation: Loss 0.45265 Accuracy 1.00000
Epoch [ 2]: Loss 0.46477
Epoch [ 2]: Loss 0.45467
Epoch [ 2]: Loss 0.44332
Epoch [ 2]: Loss 0.43183
Epoch [ 2]: Loss 0.41632
Epoch [ 2]: Loss 0.39695
Epoch [ 2]: Loss 0.39054
Validation: Loss 0.37357 Accuracy 1.00000
Validation: Loss 0.35305 Accuracy 1.00000
Epoch [ 3]: Loss 0.36736
Epoch [ 3]: Loss 0.37108
Epoch [ 3]: Loss 0.34727
Epoch [ 3]: Loss 0.32543
Epoch [ 3]: Loss 0.32855
Epoch [ 3]: Loss 0.30059
Epoch [ 3]: Loss 0.33959
Validation: Loss 0.28889 Accuracy 1.00000
Validation: Loss 0.26646 Accuracy 1.00000
Epoch [ 4]: Loss 0.27500
Epoch [ 4]: Loss 0.28522
Epoch [ 4]: Loss 0.27535
Epoch [ 4]: Loss 0.24723
Epoch [ 4]: Loss 0.25147
Epoch [ 4]: Loss 0.24482
Epoch [ 4]: Loss 0.21234
Validation: Loss 0.21940 Accuracy 1.00000
Validation: Loss 0.19670 Accuracy 1.00000
Epoch [ 5]: Loss 0.21717
Epoch [ 5]: Loss 0.21005
Epoch [ 5]: Loss 0.20097
Epoch [ 5]: Loss 0.18944
Epoch [ 5]: Loss 0.19837
Epoch [ 5]: Loss 0.17691
Epoch [ 5]: Loss 0.15471
Validation: Loss 0.16378 Accuracy 1.00000
Validation: Loss 0.14238 Accuracy 1.00000
Epoch [ 6]: Loss 0.14907
Epoch [ 6]: Loss 0.17096
Epoch [ 6]: Loss 0.14491
Epoch [ 6]: Loss 0.15415
Epoch [ 6]: Loss 0.12732
Epoch [ 6]: Loss 0.13610
Epoch [ 6]: Loss 0.13404
Validation: Loss 0.12086 Accuracy 1.00000
Validation: Loss 0.10280 Accuracy 1.00000
Epoch [ 7]: Loss 0.11218
Epoch [ 7]: Loss 0.12446
Epoch [ 7]: Loss 0.10434
Epoch [ 7]: Loss 0.10600
Epoch [ 7]: Loss 0.10753
Epoch [ 7]: Loss 0.09145
Epoch [ 7]: Loss 0.10018
Validation: Loss 0.08695 Accuracy 1.00000
Validation: Loss 0.07342 Accuracy 1.00000
Epoch [ 8]: Loss 0.08309
Epoch [ 8]: Loss 0.08118
Epoch [ 8]: Loss 0.08338
Epoch [ 8]: Loss 0.07706
Epoch [ 8]: Loss 0.07309
Epoch [ 8]: Loss 0.06513
Epoch [ 8]: Loss 0.05652
Validation: Loss 0.06052 Accuracy 1.00000
Validation: Loss 0.05150 Accuracy 1.00000
Epoch [ 9]: Loss 0.06397
Epoch [ 9]: Loss 0.05313
Epoch [ 9]: Loss 0.05611
Epoch [ 9]: Loss 0.05228
Epoch [ 9]: Loss 0.04761
Epoch [ 9]: Loss 0.05207
Epoch [ 9]: Loss 0.04120
Validation: Loss 0.04462 Accuracy 1.00000
Validation: Loss 0.03834 Accuracy 1.00000
Epoch [ 10]: Loss 0.04387
Epoch [ 10]: Loss 0.04480
Epoch [ 10]: Loss 0.04296
Epoch [ 10]: Loss 0.04029
Epoch [ 10]: Loss 0.03845
Epoch [ 10]: Loss 0.03702
Epoch [ 10]: Loss 0.03665
Validation: Loss 0.03597 Accuracy 1.00000
Validation: Loss 0.03090 Accuracy 1.00000
Epoch [ 11]: Loss 0.03239
Epoch [ 11]: Loss 0.03502
Epoch [ 11]: Loss 0.03643
Epoch [ 11]: Loss 0.03282
Epoch [ 11]: Loss 0.03230
Epoch [ 11]: Loss 0.03393
Epoch [ 11]: Loss 0.03191
Validation: Loss 0.03050 Accuracy 1.00000
Validation: Loss 0.02608 Accuracy 1.00000
Epoch [ 12]: Loss 0.03072
Epoch [ 12]: Loss 0.02933
Epoch [ 12]: Loss 0.02973
Epoch [ 12]: Loss 0.02618
Epoch [ 12]: Loss 0.03022
Epoch [ 12]: Loss 0.02776
Epoch [ 12]: Loss 0.02768
Validation: Loss 0.02652 Accuracy 1.00000
Validation: Loss 0.02260 Accuracy 1.00000
Epoch [ 13]: Loss 0.02927
Epoch [ 13]: Loss 0.02735
Epoch [ 13]: Loss 0.02444
Epoch [ 13]: Loss 0.02418
Epoch [ 13]: Loss 0.02486
Epoch [ 13]: Loss 0.02272
Epoch [ 13]: Loss 0.02274
Validation: Loss 0.02343 Accuracy 1.00000
Validation: Loss 0.01990 Accuracy 1.00000
Epoch [ 14]: Loss 0.02363
Epoch [ 14]: Loss 0.02282
Epoch [ 14]: Loss 0.02408
Epoch [ 14]: Loss 0.02229
Epoch [ 14]: Loss 0.02117
Epoch [ 14]: Loss 0.02103
Epoch [ 14]: Loss 0.02261
Validation: Loss 0.02097 Accuracy 1.00000
Validation: Loss 0.01775 Accuracy 1.00000
Epoch [ 15]: Loss 0.02039
Epoch [ 15]: Loss 0.02010
Epoch [ 15]: Loss 0.01996
Epoch [ 15]: Loss 0.02132
Epoch [ 15]: Loss 0.02028
Epoch [ 15]: Loss 0.01912
Epoch [ 15]: Loss 0.02045
Validation: Loss 0.01894 Accuracy 1.00000
Validation: Loss 0.01597 Accuracy 1.00000
Epoch [ 16]: Loss 0.01857
Epoch [ 16]: Loss 0.01935
Epoch [ 16]: Loss 0.01823
Epoch [ 16]: Loss 0.01933
Epoch [ 16]: Loss 0.01691
Epoch [ 16]: Loss 0.01766
Epoch [ 16]: Loss 0.01731
Validation: Loss 0.01721 Accuracy 1.00000
Validation: Loss 0.01446 Accuracy 1.00000
Epoch [ 17]: Loss 0.01735
Epoch [ 17]: Loss 0.01699
Epoch [ 17]: Loss 0.01565
Epoch [ 17]: Loss 0.01639
Epoch [ 17]: Loss 0.01649
Epoch [ 17]: Loss 0.01720
Epoch [ 17]: Loss 0.01640
Validation: Loss 0.01574 Accuracy 1.00000
Validation: Loss 0.01318 Accuracy 1.00000
Epoch [ 18]: Loss 0.01693
Epoch [ 18]: Loss 0.01508
Epoch [ 18]: Loss 0.01462
Epoch [ 18]: Loss 0.01488
Epoch [ 18]: Loss 0.01659
Epoch [ 18]: Loss 0.01439
Epoch [ 18]: Loss 0.01220
Validation: Loss 0.01448 Accuracy 1.00000
Validation: Loss 0.01208 Accuracy 1.00000
Epoch [ 19]: Loss 0.01602
Epoch [ 19]: Loss 0.01445
Epoch [ 19]: Loss 0.01387
Epoch [ 19]: Loss 0.01343
Epoch [ 19]: Loss 0.01328
Epoch [ 19]: Loss 0.01381
Epoch [ 19]: Loss 0.01286
Validation: Loss 0.01339 Accuracy 1.00000
Validation: Loss 0.01114 Accuracy 1.00000
Epoch [ 20]: Loss 0.01352
Epoch [ 20]: Loss 0.01324
Epoch [ 20]: Loss 0.01399
Epoch [ 20]: Loss 0.01270
Epoch [ 20]: Loss 0.01193
Epoch [ 20]: Loss 0.01294
Epoch [ 20]: Loss 0.01291
Validation: Loss 0.01242 Accuracy 1.00000
Validation: Loss 0.01032 Accuracy 1.00000
Epoch [ 21]: Loss 0.01290
Epoch [ 21]: Loss 0.01180
Epoch [ 21]: Loss 0.01179
Epoch [ 21]: Loss 0.01227
Epoch [ 21]: Loss 0.01156
Epoch [ 21]: Loss 0.01248
Epoch [ 21]: Loss 0.01153
Validation: Loss 0.01153 Accuracy 1.00000
Validation: Loss 0.00958 Accuracy 1.00000
Epoch [ 22]: Loss 0.01096
Epoch [ 22]: Loss 0.01038
Epoch [ 22]: Loss 0.01275
Epoch [ 22]: Loss 0.01049
Epoch [ 22]: Loss 0.01147
Epoch [ 22]: Loss 0.01167
Epoch [ 22]: Loss 0.00984
Validation: Loss 0.01065 Accuracy 1.00000
Validation: Loss 0.00885 Accuracy 1.00000
Epoch [ 23]: Loss 0.01022
Epoch [ 23]: Loss 0.01032
Epoch [ 23]: Loss 0.01019
Epoch [ 23]: Loss 0.01107
Epoch [ 23]: Loss 0.01110
Epoch [ 23]: Loss 0.00922
Epoch [ 23]: Loss 0.00974
Validation: Loss 0.00968 Accuracy 1.00000
Validation: Loss 0.00807 Accuracy 1.00000
Epoch [ 24]: Loss 0.00985
Epoch [ 24]: Loss 0.00972
Epoch [ 24]: Loss 0.00937
Epoch [ 24]: Loss 0.00896
Epoch [ 24]: Loss 0.00885
Epoch [ 24]: Loss 0.00926
Epoch [ 24]: Loss 0.00871
Validation: Loss 0.00861 Accuracy 1.00000
Validation: Loss 0.00722 Accuracy 1.00000
Epoch [ 25]: Loss 0.00871
Epoch [ 25]: Loss 0.00854
Epoch [ 25]: Loss 0.00778
Epoch [ 25]: Loss 0.00872
Epoch [ 25]: Loss 0.00788
Epoch [ 25]: Loss 0.00799
Epoch [ 25]: Loss 0.00801
Validation: Loss 0.00768 Accuracy 1.00000
Validation: Loss 0.00649 Accuracy 1.00000

Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained

2-element Vector{Symbol}:
:ps_trained
 :st_trained
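Since only the parameters and states were saved, inference requires rebuilding the model struct first. A sketch (x is a hypothetical input batch):

# Rebuild the model and reuse the loaded parameters/states on the CPU.
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # hypothetical single spiral
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
ŷ[1] > 0.5f0  # true would mean anticlockwise under our labeling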
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end

Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.4+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)

This page was generated using Literate.jl.