Training a Simple LSTM
In this tutorial we will train a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and create an MLUtils.DataLoader from this data. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise (see the quick shape check after the function below).
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
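As a quick sanity check (ours, not part of the original tutorial), we can peek at one batch to confirm the shapes described above:
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence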
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates ourselves (see the sketch after the next paragraph for what that would look like).
To understand more about container layers, please look at Container Layer.
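For illustration only, here is a rough sketch of the boilerplate this spares us. The manual_* names are hypothetical stand-ins for overloads of Lux.initialparameters and Lux.initialstates on our type:
using Random: AbstractRNG
# Hypothetical helpers mirroring what the container layer generates for a
# two-field container `s`:
manual_initialparameters(rng::AbstractRNG, s) =
    (lstm_cell=Lux.initialparameters(rng, s.lstm_cell),
     classifier=Lux.initialparameters(rng, s.classifier))
manual_initialstates(rng::AbstractRNG, s) =
    (lstm_cell=Lux.initialstates(rng, s.lstm_cell),
     classifier=Lux.initialstates(rng, s.classifier))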
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather compose it from Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass below, but let's still do it for the sake of illustration.
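For reference, a hedged sketch of that stock-block version (SpiralClassifierStock is our name for it; it assumes Recurrence's default behavior of returning only the final time step's output):
function SpiralClassifierStock(in_dims, hidden_dims, out_dims)
    return Chain(
        # Runs the cell over the sequence dimension and keeps the last output
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec)  # flatten the 1 × batch output to match our custom model
end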
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to the LSTM Cell will create the initial hidden state.
    # Note that the parameters and states are automatically populated into fields called
    # `lstm_cell` and `classifier`. We use `eachslice` to get the elements in the sequence
    # without copying, and `Iterators.peel` to split out the first element for LSTM
    # initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
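Before moving on, a quick smoke test of our own (not in the original tutorial) to confirm the model returns one probability per sequence:
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Float32, 2, 50, 16)  # features × sequence length × batch size
y, st_ = model(x, ps, st)
@assert size(y) == (16,)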
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. It is generally recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy here.
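For reference, a hedged sketch of the stabler logit-based alternative; it assumes BinaryCrossEntropyLoss accepts a logits keyword (check the Lux docs), and it would pair with a classifier head without the final sigmoid. We proceed with the plain version below.
# Sketch only (assumed API): with `Dense(hidden_dims => out_dims)` (no sigmoid)
# as the head, the loss applies the sigmoid internally in a stable fashion.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))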
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
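A tiny worked example of our own: predictions above the 0.5 threshold count as positive, so three of the four predictions below match their labels.
accuracy([0.9f0, 0.2f0, 0.7f0, 0.4f0], [1.0f0, 0.0f0, 0.0f0, 0.0f0])  # 0.75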
Training the Model
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.62418
Epoch [ 1]: Loss 0.58579
Epoch [ 1]: Loss 0.56593
Epoch [ 1]: Loss 0.54239
Epoch [ 1]: Loss 0.51930
Epoch [ 1]: Loss 0.50690
Epoch [ 1]: Loss 0.49584
Validation: Loss 0.46165 Accuracy 1.00000
Validation: Loss 0.46952 Accuracy 1.00000
Epoch [ 2]: Loss 0.47788
Epoch [ 2]: Loss 0.45420
Epoch [ 2]: Loss 0.43943
Epoch [ 2]: Loss 0.42512
Epoch [ 2]: Loss 0.40691
Epoch [ 2]: Loss 0.40282
Epoch [ 2]: Loss 0.38509
Validation: Loss 0.36373 Accuracy 1.00000
Validation: Loss 0.37269 Accuracy 1.00000
Epoch [ 3]: Loss 0.37104
Epoch [ 3]: Loss 0.38009
Epoch [ 3]: Loss 0.34337
Epoch [ 3]: Loss 0.32877
Epoch [ 3]: Loss 0.31390
Epoch [ 3]: Loss 0.31147
Epoch [ 3]: Loss 0.29863
Validation: Loss 0.27834 Accuracy 1.00000
Validation: Loss 0.28801 Accuracy 1.00000
Epoch [ 4]: Loss 0.29062
Epoch [ 4]: Loss 0.27988
Epoch [ 4]: Loss 0.26162
Epoch [ 4]: Loss 0.25111
Epoch [ 4]: Loss 0.25346
Epoch [ 4]: Loss 0.23974
Epoch [ 4]: Loss 0.21767
Validation: Loss 0.20848 Accuracy 1.00000
Validation: Loss 0.21814 Accuracy 1.00000
Epoch [ 5]: Loss 0.20867
Epoch [ 5]: Loss 0.21213
Epoch [ 5]: Loss 0.19058
Epoch [ 5]: Loss 0.19469
Epoch [ 5]: Loss 0.19829
Epoch [ 5]: Loss 0.18660
Epoch [ 5]: Loss 0.15299
Validation: Loss 0.15361 Accuracy 1.00000
Validation: Loss 0.16244 Accuracy 1.00000
Epoch [ 6]: Loss 0.16677
Epoch [ 6]: Loss 0.16597
Epoch [ 6]: Loss 0.14248
Epoch [ 6]: Loss 0.14270
Epoch [ 6]: Loss 0.13129
Epoch [ 6]: Loss 0.13175
Epoch [ 6]: Loss 0.13294
Validation: Loss 0.11233 Accuracy 1.00000
Validation: Loss 0.11952 Accuracy 1.00000
Epoch [ 7]: Loss 0.11907
Epoch [ 7]: Loss 0.10307
Epoch [ 7]: Loss 0.11472
Epoch [ 7]: Loss 0.10291
Epoch [ 7]: Loss 0.10372
Epoch [ 7]: Loss 0.10140
Epoch [ 7]: Loss 0.09050
Validation: Loss 0.08060 Accuracy 1.00000
Validation: Loss 0.08589 Accuracy 1.00000
Epoch [ 8]: Loss 0.08299
Epoch [ 8]: Loss 0.08522
Epoch [ 8]: Loss 0.08371
Epoch [ 8]: Loss 0.07357
Epoch [ 8]: Loss 0.06734
Epoch [ 8]: Loss 0.06675
Epoch [ 8]: Loss 0.06284
Validation: Loss 0.05620 Accuracy 1.00000
Validation: Loss 0.05972 Accuracy 1.00000
Epoch [ 9]: Loss 0.06122
Epoch [ 9]: Loss 0.05772
Epoch [ 9]: Loss 0.05625
Epoch [ 9]: Loss 0.05107
Epoch [ 9]: Loss 0.04973
Epoch [ 9]: Loss 0.04731
Epoch [ 9]: Loss 0.04114
Validation: Loss 0.04140 Accuracy 1.00000
Validation: Loss 0.04387 Accuracy 1.00000
Epoch [ 10]: Loss 0.04597
Epoch [ 10]: Loss 0.04468
Epoch [ 10]: Loss 0.04027
Epoch [ 10]: Loss 0.03766
Epoch [ 10]: Loss 0.03869
Epoch [ 10]: Loss 0.03671
Epoch [ 10]: Loss 0.03978
Validation: Loss 0.03330 Accuracy 1.00000
Validation: Loss 0.03531 Accuracy 1.00000
Epoch [ 11]: Loss 0.03641
Epoch [ 11]: Loss 0.03406
Epoch [ 11]: Loss 0.03407
Epoch [ 11]: Loss 0.03293
Epoch [ 11]: Loss 0.03303
Epoch [ 11]: Loss 0.03120
Epoch [ 11]: Loss 0.02747
Validation: Loss 0.02813 Accuracy 1.00000
Validation: Loss 0.02988 Accuracy 1.00000
Epoch [ 12]: Loss 0.03035
Epoch [ 12]: Loss 0.02948
Epoch [ 12]: Loss 0.02886
Epoch [ 12]: Loss 0.02897
Epoch [ 12]: Loss 0.02867
Epoch [ 12]: Loss 0.02558
Epoch [ 12]: Loss 0.02645
Validation: Loss 0.02441 Accuracy 1.00000
Validation: Loss 0.02597 Accuracy 1.00000
Epoch [ 13]: Loss 0.02640
Epoch [ 13]: Loss 0.02505
Epoch [ 13]: Loss 0.02586
Epoch [ 13]: Loss 0.02413
Epoch [ 13]: Loss 0.02457
Epoch [ 13]: Loss 0.02392
Epoch [ 13]: Loss 0.02495
Validation: Loss 0.02154 Accuracy 1.00000
Validation: Loss 0.02295 Accuracy 1.00000
Epoch [ 14]: Loss 0.02393
Epoch [ 14]: Loss 0.02394
Epoch [ 14]: Loss 0.02140
Epoch [ 14]: Loss 0.02144
Epoch [ 14]: Loss 0.02072
Epoch [ 14]: Loss 0.02172
Epoch [ 14]: Loss 0.02228
Validation: Loss 0.01923 Accuracy 1.00000
Validation: Loss 0.02051 Accuracy 1.00000
Epoch [ 15]: Loss 0.02086
Epoch [ 15]: Loss 0.02128
Epoch [ 15]: Loss 0.01953
Epoch [ 15]: Loss 0.02066
Epoch [ 15]: Loss 0.01842
Epoch [ 15]: Loss 0.01920
Epoch [ 15]: Loss 0.01809
Validation: Loss 0.01732 Accuracy 1.00000
Validation: Loss 0.01850 Accuracy 1.00000
Epoch [ 16]: Loss 0.01951
Epoch [ 16]: Loss 0.01857
Epoch [ 16]: Loss 0.01764
Epoch [ 16]: Loss 0.01806
Epoch [ 16]: Loss 0.01846
Epoch [ 16]: Loss 0.01676
Epoch [ 16]: Loss 0.01449
Validation: Loss 0.01571 Accuracy 1.00000
Validation: Loss 0.01681 Accuracy 1.00000
Epoch [ 17]: Loss 0.01790
Epoch [ 17]: Loss 0.01721
Epoch [ 17]: Loss 0.01653
Epoch [ 17]: Loss 0.01615
Epoch [ 17]: Loss 0.01641
Epoch [ 17]: Loss 0.01455
Epoch [ 17]: Loss 0.01555
Validation: Loss 0.01435 Accuracy 1.00000
Validation: Loss 0.01538 Accuracy 1.00000
Epoch [ 18]: Loss 0.01493
Epoch [ 18]: Loss 0.01504
Epoch [ 18]: Loss 0.01486
Epoch [ 18]: Loss 0.01653
Epoch [ 18]: Loss 0.01422
Epoch [ 18]: Loss 0.01420
Epoch [ 18]: Loss 0.01710
Validation: Loss 0.01319 Accuracy 1.00000
Validation: Loss 0.01415 Accuracy 1.00000
Epoch [ 19]: Loss 0.01415
Epoch [ 19]: Loss 0.01304
Epoch [ 19]: Loss 0.01302
Epoch [ 19]: Loss 0.01433
Epoch [ 19]: Loss 0.01432
Epoch [ 19]: Loss 0.01416
Epoch [ 19]: Loss 0.01478
Validation: Loss 0.01219 Accuracy 1.00000
Validation: Loss 0.01308 Accuracy 1.00000
Epoch [ 20]: Loss 0.01346
Epoch [ 20]: Loss 0.01344
Epoch [ 20]: Loss 0.01306
Epoch [ 20]: Loss 0.01266
Epoch [ 20]: Loss 0.01252
Epoch [ 20]: Loss 0.01251
Epoch [ 20]: Loss 0.01081
Validation: Loss 0.01130 Accuracy 1.00000
Validation: Loss 0.01213 Accuracy 1.00000
Epoch [ 21]: Loss 0.01272
Epoch [ 21]: Loss 0.01260
Epoch [ 21]: Loss 0.01141
Epoch [ 21]: Loss 0.01128
Epoch [ 21]: Loss 0.01175
Epoch [ 21]: Loss 0.01185
Epoch [ 21]: Loss 0.01185
Validation: Loss 0.01049 Accuracy 1.00000
Validation: Loss 0.01126 Accuracy 1.00000
Epoch [ 22]: Loss 0.01238
Epoch [ 22]: Loss 0.01004
Epoch [ 22]: Loss 0.01071
Epoch [ 22]: Loss 0.01079
Epoch [ 22]: Loss 0.01111
Epoch [ 22]: Loss 0.01182
Epoch [ 22]: Loss 0.00929
Validation: Loss 0.00970 Accuracy 1.00000
Validation: Loss 0.01041 Accuracy 1.00000
Epoch [ 23]: Loss 0.01063
Epoch [ 23]: Loss 0.00979
Epoch [ 23]: Loss 0.01155
Epoch [ 23]: Loss 0.01052
Epoch [ 23]: Loss 0.01019
Epoch [ 23]: Loss 0.00873
Epoch [ 23]: Loss 0.00930
Validation: Loss 0.00885 Accuracy 1.00000
Validation: Loss 0.00948 Accuracy 1.00000
Epoch [ 24]: Loss 0.00987
Epoch [ 24]: Loss 0.00944
Epoch [ 24]: Loss 0.00989
Epoch [ 24]: Loss 0.00913
Epoch [ 24]: Loss 0.00871
Epoch [ 24]: Loss 0.00866
Epoch [ 24]: Loss 0.00732
Validation: Loss 0.00790 Accuracy 1.00000
Validation: Loss 0.00844 Accuracy 1.00000
Epoch [ 25]: Loss 0.00924
Epoch [ 25]: Loss 0.00861
Epoch [ 25]: Loss 0.00853
Epoch [ 25]: Loss 0.00779
Epoch [ 25]: Loss 0.00775
Epoch [ 25]: Loss 0.00744
Epoch [ 25]: Loss 0.00710
Validation: Loss 0.00707 Accuracy 1.00000
Validation: Loss 0.00754 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.61648
Epoch [ 1]: Loss 0.60569
Epoch [ 1]: Loss 0.55924
Epoch [ 1]: Loss 0.53742
Epoch [ 1]: Loss 0.51968
Epoch [ 1]: Loss 0.50163
Epoch [ 1]: Loss 0.48797
Validation: Loss 0.46968 Accuracy 1.00000
Validation: Loss 0.46657 Accuracy 1.00000
Epoch [ 2]: Loss 0.47505
Epoch [ 2]: Loss 0.45209
Epoch [ 2]: Loss 0.43445
Epoch [ 2]: Loss 0.43346
Epoch [ 2]: Loss 0.40767
Epoch [ 2]: Loss 0.39600
Epoch [ 2]: Loss 0.38576
Validation: Loss 0.37245 Accuracy 1.00000
Validation: Loss 0.36874 Accuracy 1.00000
Epoch [ 3]: Loss 0.36002
Epoch [ 3]: Loss 0.36527
Epoch [ 3]: Loss 0.33697
Epoch [ 3]: Loss 0.34172
Epoch [ 3]: Loss 0.32651
Epoch [ 3]: Loss 0.30933
Epoch [ 3]: Loss 0.29770
Validation: Loss 0.28741 Accuracy 1.00000
Validation: Loss 0.28346 Accuracy 1.00000
Epoch [ 4]: Loss 0.27948
Epoch [ 4]: Loss 0.27427
Epoch [ 4]: Loss 0.25578
Epoch [ 4]: Loss 0.26128
Epoch [ 4]: Loss 0.26132
Epoch [ 4]: Loss 0.22517
Epoch [ 4]: Loss 0.24097
Validation: Loss 0.21734 Accuracy 1.00000
Validation: Loss 0.21362 Accuracy 1.00000
Epoch [ 5]: Loss 0.21696
Epoch [ 5]: Loss 0.20913
Epoch [ 5]: Loss 0.19235
Epoch [ 5]: Loss 0.19339
Epoch [ 5]: Loss 0.18199
Epoch [ 5]: Loss 0.18324
Epoch [ 5]: Loss 0.14668
Validation: Loss 0.16142 Accuracy 1.00000
Validation: Loss 0.15829 Accuracy 1.00000
Epoch [ 6]: Loss 0.16548
Epoch [ 6]: Loss 0.16349
Epoch [ 6]: Loss 0.14657
Epoch [ 6]: Loss 0.14099
Epoch [ 6]: Loss 0.12364
Epoch [ 6]: Loss 0.13009
Epoch [ 6]: Loss 0.11028
Validation: Loss 0.11837 Accuracy 1.00000
Validation: Loss 0.11600 Accuracy 1.00000
Epoch [ 7]: Loss 0.10952
Epoch [ 7]: Loss 0.11348
Epoch [ 7]: Loss 0.11021
Epoch [ 7]: Loss 0.10278
Epoch [ 7]: Loss 0.09154
Epoch [ 7]: Loss 0.10022
Epoch [ 7]: Loss 0.09825
Validation: Loss 0.08504 Accuracy 1.00000
Validation: Loss 0.08340 Accuracy 1.00000
Epoch [ 8]: Loss 0.08598
Epoch [ 8]: Loss 0.08123
Epoch [ 8]: Loss 0.07999
Epoch [ 8]: Loss 0.06669
Epoch [ 8]: Loss 0.07046
Epoch [ 8]: Loss 0.06645
Epoch [ 8]: Loss 0.05322
Validation: Loss 0.05916 Accuracy 1.00000
Validation: Loss 0.05808 Accuracy 1.00000
Epoch [ 9]: Loss 0.06442
Epoch [ 9]: Loss 0.05365
Epoch [ 9]: Loss 0.05147
Epoch [ 9]: Loss 0.05029
Epoch [ 9]: Loss 0.04758
Epoch [ 9]: Loss 0.04685
Epoch [ 9]: Loss 0.04956
Validation: Loss 0.04375 Accuracy 1.00000
Validation: Loss 0.04287 Accuracy 1.00000
Epoch [ 10]: Loss 0.04564
Epoch [ 10]: Loss 0.04405
Epoch [ 10]: Loss 0.03796
Epoch [ 10]: Loss 0.03963
Epoch [ 10]: Loss 0.03850
Epoch [ 10]: Loss 0.03521
Epoch [ 10]: Loss 0.03610
Validation: Loss 0.03532 Accuracy 1.00000
Validation: Loss 0.03456 Accuracy 1.00000
Epoch [ 11]: Loss 0.03507
Epoch [ 11]: Loss 0.03105
Epoch [ 11]: Loss 0.03418
Epoch [ 11]: Loss 0.03227
Epoch [ 11]: Loss 0.03260
Epoch [ 11]: Loss 0.03136
Epoch [ 11]: Loss 0.03553
Validation: Loss 0.02996 Accuracy 1.00000
Validation: Loss 0.02930 Accuracy 1.00000
Epoch [ 12]: Loss 0.03098
Epoch [ 12]: Loss 0.03034
Epoch [ 12]: Loss 0.02851
Epoch [ 12]: Loss 0.02817
Epoch [ 12]: Loss 0.02666
Epoch [ 12]: Loss 0.02503
Epoch [ 12]: Loss 0.02634
Validation: Loss 0.02603 Accuracy 1.00000
Validation: Loss 0.02544 Accuracy 1.00000
Epoch [ 13]: Loss 0.02473
Epoch [ 13]: Loss 0.02426
Epoch [ 13]: Loss 0.02448
Epoch [ 13]: Loss 0.02658
Epoch [ 13]: Loss 0.02259
Epoch [ 13]: Loss 0.02529
Epoch [ 13]: Loss 0.02406
Validation: Loss 0.02302 Accuracy 1.00000
Validation: Loss 0.02249 Accuracy 1.00000
Epoch [ 14]: Loss 0.02143
Epoch [ 14]: Loss 0.02288
Epoch [ 14]: Loss 0.02464
Epoch [ 14]: Loss 0.02272
Epoch [ 14]: Loss 0.01895
Epoch [ 14]: Loss 0.02051
Epoch [ 14]: Loss 0.02293
Validation: Loss 0.02060 Accuracy 1.00000
Validation: Loss 0.02011 Accuracy 1.00000
Epoch [ 15]: Loss 0.02189
Epoch [ 15]: Loss 0.01995
Epoch [ 15]: Loss 0.01945
Epoch [ 15]: Loss 0.02023
Epoch [ 15]: Loss 0.01856
Epoch [ 15]: Loss 0.01759
Epoch [ 15]: Loss 0.02073
Validation: Loss 0.01858 Accuracy 1.00000
Validation: Loss 0.01813 Accuracy 1.00000
Epoch [ 16]: Loss 0.01992
Epoch [ 16]: Loss 0.01900
Epoch [ 16]: Loss 0.01808
Epoch [ 16]: Loss 0.01701
Epoch [ 16]: Loss 0.01582
Epoch [ 16]: Loss 0.01750
Epoch [ 16]: Loss 0.01509
Validation: Loss 0.01688 Accuracy 1.00000
Validation: Loss 0.01646 Accuracy 1.00000
Epoch [ 17]: Loss 0.01581
Epoch [ 17]: Loss 0.01652
Epoch [ 17]: Loss 0.01631
Epoch [ 17]: Loss 0.01684
Epoch [ 17]: Loss 0.01589
Epoch [ 17]: Loss 0.01615
Epoch [ 17]: Loss 0.01434
Validation: Loss 0.01544 Accuracy 1.00000
Validation: Loss 0.01506 Accuracy 1.00000
Epoch [ 18]: Loss 0.01574
Epoch [ 18]: Loss 0.01368
Epoch [ 18]: Loss 0.01452
Epoch [ 18]: Loss 0.01589
Epoch [ 18]: Loss 0.01487
Epoch [ 18]: Loss 0.01538
Epoch [ 18]: Loss 0.01043
Validation: Loss 0.01421 Accuracy 1.00000
Validation: Loss 0.01385 Accuracy 1.00000
Epoch [ 19]: Loss 0.01328
Epoch [ 19]: Loss 0.01446
Epoch [ 19]: Loss 0.01391
Epoch [ 19]: Loss 0.01376
Epoch [ 19]: Loss 0.01254
Epoch [ 19]: Loss 0.01384
Epoch [ 19]: Loss 0.01473
Validation: Loss 0.01314 Accuracy 1.00000
Validation: Loss 0.01281 Accuracy 1.00000
Epoch [ 20]: Loss 0.01226
Epoch [ 20]: Loss 0.01289
Epoch [ 20]: Loss 0.01313
Epoch [ 20]: Loss 0.01236
Epoch [ 20]: Loss 0.01370
Epoch [ 20]: Loss 0.01173
Epoch [ 20]: Loss 0.01218
Validation: Loss 0.01217 Accuracy 1.00000
Validation: Loss 0.01187 Accuracy 1.00000
Epoch [ 21]: Loss 0.01195
Epoch [ 21]: Loss 0.01182
Epoch [ 21]: Loss 0.01193
Epoch [ 21]: Loss 0.01133
Epoch [ 21]: Loss 0.01177
Epoch [ 21]: Loss 0.01174
Epoch [ 21]: Loss 0.01086
Validation: Loss 0.01125 Accuracy 1.00000
Validation: Loss 0.01097 Accuracy 1.00000
Epoch [ 22]: Loss 0.01191
Epoch [ 22]: Loss 0.01156
Epoch [ 22]: Loss 0.01099
Epoch [ 22]: Loss 0.01076
Epoch [ 22]: Loss 0.01030
Epoch [ 22]: Loss 0.00967
Epoch [ 22]: Loss 0.00951
Validation: Loss 0.01027 Accuracy 1.00000
Validation: Loss 0.01003 Accuracy 1.00000
Epoch [ 23]: Loss 0.00990
Epoch [ 23]: Loss 0.01014
Epoch [ 23]: Loss 0.00961
Epoch [ 23]: Loss 0.00965
Epoch [ 23]: Loss 0.00968
Epoch [ 23]: Loss 0.00996
Epoch [ 23]: Loss 0.00890
Validation: Loss 0.00918 Accuracy 1.00000
Validation: Loss 0.00899 Accuracy 1.00000
Epoch [ 24]: Loss 0.00876
Epoch [ 24]: Loss 0.00921
Epoch [ 24]: Loss 0.00889
Epoch [ 24]: Loss 0.00794
Epoch [ 24]: Loss 0.00889
Epoch [ 24]: Loss 0.00845
Epoch [ 24]: Loss 0.00914
Validation: Loss 0.00816 Accuracy 1.00000
Validation: Loss 0.00799 Accuracy 1.00000
Epoch [ 25]: Loss 0.00803
Epoch [ 25]: Loss 0.00762
Epoch [ 25]: Loss 0.00793
Epoch [ 25]: Loss 0.00754
Epoch [ 25]: Loss 0.00764
Epoch [ 25]: Loss 0.00767
Epoch [ 25]: Loss 0.00899
Validation: Loss 0.00740 Accuracy 1.00000
Validation: Loss 0.00725 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.
@save "trained_model.jld2" ps_trained st_trainedLet's try loading the model
@load "trained_model.jld2" ps_trained st_trained2-element Vector{Symbol}:
:ps_trained
:st_trained
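As a hypothetical follow-up (not part of the original tutorial), the loaded parameters and states can be attached to a freshly constructed model for inference:
model = SpiralClassifier(2, 8, 1)
st_inference = Lux.testmode(st_trained)  # switch the states to test mode
x = rand(Float32, 2, 50, 4)              # a dummy batch of four sequences
ŷ, _ = model(x, ps_trained, st_inference)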
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)
This page was generated using Literate.jl.