Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and from this data create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
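As a quick sanity check, we can peek at one training batch; the shapes shown in the comments assume the default dataset_size and sequence_length:

# Illustrative check of the dataloader output shapes
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence_length × batch_size
size(y)  # (128,): one binary label per sequence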
Creating a Classifier
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head. We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.
To learn more about container layers, please see the Container Layer documentation.
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch; instead, we will use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
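Since the fieldnames were passed to AbstractLuxContainerLayer, Lux.setup builds the parameters and states keyed by those fields automatically. A minimal illustrative check:

# The parameter NamedTuple mirrors the struct's fields
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # (:lstm_cell, :classifier)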
We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass ourselves, but let's still write it out manually to see how the recurrent API works.
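For reference, here is a sketch of what that built-in variant could look like (it relies on Chain wrapping plain functions such as vec; we won't use this model below):

function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        # Runs the cell over the time dimension and returns the final output
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec)
end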
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
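Before moving on, we can sanity-check the forward pass on random data (an illustrative snippet; the input sizes here are arbitrary choices):

# One forward pass through the freshly initialized model
rng = Xoshiro(0)
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 50, 16)  # features × sequence_length × batch_size
ŷ, st_ = model(x, ps, st)
size(ŷ)  # (16,): one probability per sequence in the batch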
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
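As an aside, the numerically stable setup mentioned above would remove the sigmoid from the Dense head and compute the loss on raw logits instead; a minimal sketch, assuming Lux's logits keyword argument for BinaryCrossEntropyLoss:

# Model head becomes Dense(hidden_dims => out_dims) (no sigmoid), and then:
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))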
Training the Model
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.62124
Epoch [ 1]: Loss 0.59112
Epoch [ 1]: Loss 0.56059
Epoch [ 1]: Loss 0.53948
Epoch [ 1]: Loss 0.52964
Epoch [ 1]: Loss 0.50251
Epoch [ 1]: Loss 0.47554
Validation: Loss 0.46974 Accuracy 1.00000
Validation: Loss 0.45945 Accuracy 1.00000
Epoch [ 2]: Loss 0.46349
Epoch [ 2]: Loss 0.45522
Epoch [ 2]: Loss 0.44148
Epoch [ 2]: Loss 0.42113
Epoch [ 2]: Loss 0.41486
Epoch [ 2]: Loss 0.40422
Epoch [ 2]: Loss 0.39463
Validation: Loss 0.37277 Accuracy 1.00000
Validation: Loss 0.36055 Accuracy 1.00000
Epoch [ 3]: Loss 0.36452
Epoch [ 3]: Loss 0.36135
Epoch [ 3]: Loss 0.35644
Epoch [ 3]: Loss 0.32859
Epoch [ 3]: Loss 0.32668
Epoch [ 3]: Loss 0.30146
Epoch [ 3]: Loss 0.31281
Validation: Loss 0.28783 Accuracy 1.00000
Validation: Loss 0.27462 Accuracy 1.00000
Epoch [ 4]: Loss 0.27613
Epoch [ 4]: Loss 0.27654
Epoch [ 4]: Loss 0.27543
Epoch [ 4]: Loss 0.26748
Epoch [ 4]: Loss 0.23477
Epoch [ 4]: Loss 0.23649
Epoch [ 4]: Loss 0.21447
Validation: Loss 0.21765 Accuracy 1.00000
Validation: Loss 0.20443 Accuracy 1.00000
Epoch [ 5]: Loss 0.23198
Epoch [ 5]: Loss 0.19341
Epoch [ 5]: Loss 0.20363
Epoch [ 5]: Loss 0.19023
Epoch [ 5]: Loss 0.17448
Epoch [ 5]: Loss 0.17787
Epoch [ 5]: Loss 0.17737
Validation: Loss 0.16169 Accuracy 1.00000
Validation: Loss 0.14953 Accuracy 1.00000
Epoch [ 6]: Loss 0.15398
Epoch [ 6]: Loss 0.14600
Epoch [ 6]: Loss 0.14303
Epoch [ 6]: Loss 0.14745
Epoch [ 6]: Loss 0.13438
Epoch [ 6]: Loss 0.13613
Epoch [ 6]: Loss 0.14339
Validation: Loss 0.11852 Accuracy 1.00000
Validation: Loss 0.10852 Accuracy 1.00000
Epoch [ 7]: Loss 0.12618
Epoch [ 7]: Loss 0.11306
Epoch [ 7]: Loss 0.11519
Epoch [ 7]: Loss 0.10155
Epoch [ 7]: Loss 0.09263
Epoch [ 7]: Loss 0.08684
Epoch [ 7]: Loss 0.07856
Validation: Loss 0.08455 Accuracy 1.00000
Validation: Loss 0.07723 Accuracy 1.00000
Epoch [ 8]: Loss 0.08516
Epoch [ 8]: Loss 0.07670
Epoch [ 8]: Loss 0.07334
Epoch [ 8]: Loss 0.07463
Epoch [ 8]: Loss 0.06892
Epoch [ 8]: Loss 0.06460
Epoch [ 8]: Loss 0.06731
Validation: Loss 0.05894 Accuracy 1.00000
Validation: Loss 0.05402 Accuracy 1.00000
Epoch [ 9]: Loss 0.05723
Epoch [ 9]: Loss 0.05550
Epoch [ 9]: Loss 0.05291
Epoch [ 9]: Loss 0.05032
Epoch [ 9]: Loss 0.05052
Epoch [ 9]: Loss 0.04744
Epoch [ 9]: Loss 0.04511
Validation: Loss 0.04383 Accuracy 1.00000
Validation: Loss 0.04037 Accuracy 1.00000
Epoch [ 10]: Loss 0.03820
Epoch [ 10]: Loss 0.04462
Epoch [ 10]: Loss 0.04101
Epoch [ 10]: Loss 0.04476
Epoch [ 10]: Loss 0.03514
Epoch [ 10]: Loss 0.03749
Epoch [ 10]: Loss 0.03838
Validation: Loss 0.03552 Accuracy 1.00000
Validation: Loss 0.03269 Accuracy 1.00000
Epoch [ 11]: Loss 0.03432
Epoch [ 11]: Loss 0.03486
Epoch [ 11]: Loss 0.03209
Epoch [ 11]: Loss 0.03268
Epoch [ 11]: Loss 0.03431
Epoch [ 11]: Loss 0.03226
Epoch [ 11]: Loss 0.02758
Validation: Loss 0.03016 Accuracy 1.00000
Validation: Loss 0.02769 Accuracy 1.00000
Epoch [ 12]: Loss 0.03089
Epoch [ 12]: Loss 0.02854
Epoch [ 12]: Loss 0.03082
Epoch [ 12]: Loss 0.02693
Epoch [ 12]: Loss 0.02546
Epoch [ 12]: Loss 0.02794
Epoch [ 12]: Loss 0.03010
Validation: Loss 0.02626 Accuracy 1.00000
Validation: Loss 0.02406 Accuracy 1.00000
Epoch [ 13]: Loss 0.02713
Epoch [ 13]: Loss 0.02577
Epoch [ 13]: Loss 0.02164
Epoch [ 13]: Loss 0.02613
Epoch [ 13]: Loss 0.02443
Epoch [ 13]: Loss 0.02483
Epoch [ 13]: Loss 0.02423
Validation: Loss 0.02323 Accuracy 1.00000
Validation: Loss 0.02125 Accuracy 1.00000
Epoch [ 14]: Loss 0.02202
Epoch [ 14]: Loss 0.02224
Epoch [ 14]: Loss 0.02442
Epoch [ 14]: Loss 0.02039
Epoch [ 14]: Loss 0.02272
Epoch [ 14]: Loss 0.02162
Epoch [ 14]: Loss 0.02090
Validation: Loss 0.02080 Accuracy 1.00000
Validation: Loss 0.01898 Accuracy 1.00000
Epoch [ 15]: Loss 0.02078
Epoch [ 15]: Loss 0.01922
Epoch [ 15]: Loss 0.02123
Epoch [ 15]: Loss 0.01900
Epoch [ 15]: Loss 0.01976
Epoch [ 15]: Loss 0.01949
Epoch [ 15]: Loss 0.02024
Validation: Loss 0.01878 Accuracy 1.00000
Validation: Loss 0.01711 Accuracy 1.00000
Epoch [ 16]: Loss 0.01894
Epoch [ 16]: Loss 0.01767
Epoch [ 16]: Loss 0.01859
Epoch [ 16]: Loss 0.01766
Epoch [ 16]: Loss 0.01857
Epoch [ 16]: Loss 0.01625
Epoch [ 16]: Loss 0.02031
Validation: Loss 0.01707 Accuracy 1.00000
Validation: Loss 0.01551 Accuracy 1.00000
Epoch [ 17]: Loss 0.01768
Epoch [ 17]: Loss 0.01683
Epoch [ 17]: Loss 0.01661
Epoch [ 17]: Loss 0.01596
Epoch [ 17]: Loss 0.01648
Epoch [ 17]: Loss 0.01597
Epoch [ 17]: Loss 0.01269
Validation: Loss 0.01559 Accuracy 1.00000
Validation: Loss 0.01415 Accuracy 1.00000
Epoch [ 18]: Loss 0.01602
Epoch [ 18]: Loss 0.01554
Epoch [ 18]: Loss 0.01392
Epoch [ 18]: Loss 0.01573
Epoch [ 18]: Loss 0.01499
Epoch [ 18]: Loss 0.01425
Epoch [ 18]: Loss 0.01435
Validation: Loss 0.01434 Accuracy 1.00000
Validation: Loss 0.01298 Accuracy 1.00000
Epoch [ 19]: Loss 0.01529
Epoch [ 19]: Loss 0.01389
Epoch [ 19]: Loss 0.01319
Epoch [ 19]: Loss 0.01424
Epoch [ 19]: Loss 0.01329
Epoch [ 19]: Loss 0.01340
Epoch [ 19]: Loss 0.01322
Validation: Loss 0.01323 Accuracy 1.00000
Validation: Loss 0.01197 Accuracy 1.00000
Epoch [ 20]: Loss 0.01272
Epoch [ 20]: Loss 0.01282
Epoch [ 20]: Loss 0.01330
Epoch [ 20]: Loss 0.01230
Epoch [ 20]: Loss 0.01337
Epoch [ 20]: Loss 0.01275
Epoch [ 20]: Loss 0.01080
Validation: Loss 0.01224 Accuracy 1.00000
Validation: Loss 0.01106 Accuracy 1.00000
Epoch [ 21]: Loss 0.01293
Epoch [ 21]: Loss 0.01173
Epoch [ 21]: Loss 0.01159
Epoch [ 21]: Loss 0.01180
Epoch [ 21]: Loss 0.01177
Epoch [ 21]: Loss 0.01075
Epoch [ 21]: Loss 0.01301
Validation: Loss 0.01125 Accuracy 1.00000
Validation: Loss 0.01017 Accuracy 1.00000
Epoch [ 22]: Loss 0.01014
Epoch [ 22]: Loss 0.01140
Epoch [ 22]: Loss 0.01133
Epoch [ 22]: Loss 0.01069
Epoch [ 22]: Loss 0.01104
Epoch [ 22]: Loss 0.01056
Epoch [ 22]: Loss 0.00901
Validation: Loss 0.01016 Accuracy 1.00000
Validation: Loss 0.00919 Accuracy 1.00000
Epoch [ 23]: Loss 0.01032
Epoch [ 23]: Loss 0.00986
Epoch [ 23]: Loss 0.00953
Epoch [ 23]: Loss 0.00954
Epoch [ 23]: Loss 0.00937
Epoch [ 23]: Loss 0.00946
Epoch [ 23]: Loss 0.00903
Validation: Loss 0.00900 Accuracy 1.00000
Validation: Loss 0.00817 Accuracy 1.00000
Epoch [ 24]: Loss 0.00939
Epoch [ 24]: Loss 0.00881
Epoch [ 24]: Loss 0.00883
Epoch [ 24]: Loss 0.00798
Epoch [ 24]: Loss 0.00831
Epoch [ 24]: Loss 0.00783
Epoch [ 24]: Loss 0.00951
Validation: Loss 0.00804 Accuracy 1.00000
Validation: Loss 0.00734 Accuracy 1.00000
Epoch [ 25]: Loss 0.00859
Epoch [ 25]: Loss 0.00799
Epoch [ 25]: Loss 0.00753
Epoch [ 25]: Loss 0.00765
Epoch [ 25]: Loss 0.00781
Epoch [ 25]: Loss 0.00727
Epoch [ 25]: Loss 0.00611
Validation: Loss 0.00737 Accuracy 1.00000
Validation: Loss 0.00674 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.62745
Epoch [ 1]: Loss 0.59260
Epoch [ 1]: Loss 0.56156
Epoch [ 1]: Loss 0.54179
Epoch [ 1]: Loss 0.51448
Epoch [ 1]: Loss 0.49976
Epoch [ 1]: Loss 0.47605
Validation: Loss 0.47360 Accuracy 1.00000
Validation: Loss 0.47051 Accuracy 1.00000
Epoch [ 2]: Loss 0.46356
Epoch [ 2]: Loss 0.44939
Epoch [ 2]: Loss 0.44393
Epoch [ 2]: Loss 0.42764
Epoch [ 2]: Loss 0.41013
Epoch [ 2]: Loss 0.39318
Epoch [ 2]: Loss 0.39291
Validation: Loss 0.37700 Accuracy 1.00000
Validation: Loss 0.37330 Accuracy 1.00000
Epoch [ 3]: Loss 0.36664
Epoch [ 3]: Loss 0.36796
Epoch [ 3]: Loss 0.34074
Epoch [ 3]: Loss 0.33221
Epoch [ 3]: Loss 0.30405
Epoch [ 3]: Loss 0.31778
Epoch [ 3]: Loss 0.28704
Validation: Loss 0.29212 Accuracy 1.00000
Validation: Loss 0.28802 Accuracy 1.00000
Epoch [ 4]: Loss 0.27207
Epoch [ 4]: Loss 0.27869
Epoch [ 4]: Loss 0.27045
Epoch [ 4]: Loss 0.25777
Epoch [ 4]: Loss 0.22519
Epoch [ 4]: Loss 0.23846
Epoch [ 4]: Loss 0.24046
Validation: Loss 0.22187 Accuracy 1.00000
Validation: Loss 0.21769 Accuracy 1.00000
Epoch [ 5]: Loss 0.21279
Epoch [ 5]: Loss 0.19832
Epoch [ 5]: Loss 0.20248
Epoch [ 5]: Loss 0.18042
Epoch [ 5]: Loss 0.18462
Epoch [ 5]: Loss 0.17985
Epoch [ 5]: Loss 0.15683
Validation: Loss 0.16533 Accuracy 1.00000
Validation: Loss 0.16145 Accuracy 1.00000
Epoch [ 6]: Loss 0.15520
Epoch [ 6]: Loss 0.15812
Epoch [ 6]: Loss 0.13642
Epoch [ 6]: Loss 0.14315
Epoch [ 6]: Loss 0.13897
Epoch [ 6]: Loss 0.11991
Epoch [ 6]: Loss 0.11770
Validation: Loss 0.12127 Accuracy 1.00000
Validation: Loss 0.11802 Accuracy 1.00000
Epoch [ 7]: Loss 0.11105
Epoch [ 7]: Loss 0.11183
Epoch [ 7]: Loss 0.10622
Epoch [ 7]: Loss 0.09932
Epoch [ 7]: Loss 0.09284
Epoch [ 7]: Loss 0.09669
Epoch [ 7]: Loss 0.08107
Validation: Loss 0.08672 Accuracy 1.00000
Validation: Loss 0.08430 Accuracy 1.00000
Epoch [ 8]: Loss 0.08598
Epoch [ 8]: Loss 0.07321
Epoch [ 8]: Loss 0.06965
Epoch [ 8]: Loss 0.07609
Epoch [ 8]: Loss 0.06411
Epoch [ 8]: Loss 0.06593
Epoch [ 8]: Loss 0.06265
Validation: Loss 0.06032 Accuracy 1.00000
Validation: Loss 0.05872 Accuracy 1.00000
Epoch [ 9]: Loss 0.05744
Epoch [ 9]: Loss 0.05601
Epoch [ 9]: Loss 0.05251
Epoch [ 9]: Loss 0.04870
Epoch [ 9]: Loss 0.04794
Epoch [ 9]: Loss 0.04374
Epoch [ 9]: Loss 0.05199
Validation: Loss 0.04490 Accuracy 1.00000
Validation: Loss 0.04378 Accuracy 1.00000
Epoch [ 10]: Loss 0.04110
Epoch [ 10]: Loss 0.03807
Epoch [ 10]: Loss 0.04066
Epoch [ 10]: Loss 0.04300
Epoch [ 10]: Loss 0.03953
Epoch [ 10]: Loss 0.03455
Epoch [ 10]: Loss 0.03701
Validation: Loss 0.03643 Accuracy 1.00000
Validation: Loss 0.03551 Accuracy 1.00000
Epoch [ 11]: Loss 0.03334
Epoch [ 11]: Loss 0.03273
Epoch [ 11]: Loss 0.03163
Epoch [ 11]: Loss 0.03277
Epoch [ 11]: Loss 0.03231
Epoch [ 11]: Loss 0.03184
Epoch [ 11]: Loss 0.03541
Validation: Loss 0.03098 Accuracy 1.00000
Validation: Loss 0.03018 Accuracy 1.00000
Epoch [ 12]: Loss 0.02951
Epoch [ 12]: Loss 0.02788
Epoch [ 12]: Loss 0.02781
Epoch [ 12]: Loss 0.02653
Epoch [ 12]: Loss 0.02880
Epoch [ 12]: Loss 0.02732
Epoch [ 12]: Loss 0.02726
Validation: Loss 0.02697 Accuracy 1.00000
Validation: Loss 0.02626 Accuracy 1.00000
Epoch [ 13]: Loss 0.02719
Epoch [ 13]: Loss 0.02324
Epoch [ 13]: Loss 0.02516
Epoch [ 13]: Loss 0.02386
Epoch [ 13]: Loss 0.02326
Epoch [ 13]: Loss 0.02406
Epoch [ 13]: Loss 0.02489
Validation: Loss 0.02387 Accuracy 1.00000
Validation: Loss 0.02323 Accuracy 1.00000
Epoch [ 14]: Loss 0.02326
Epoch [ 14]: Loss 0.02230
Epoch [ 14]: Loss 0.02079
Epoch [ 14]: Loss 0.02181
Epoch [ 14]: Loss 0.02144
Epoch [ 14]: Loss 0.02038
Epoch [ 14]: Loss 0.02372
Validation: Loss 0.02137 Accuracy 1.00000
Validation: Loss 0.02079 Accuracy 1.00000
Epoch [ 15]: Loss 0.02048
Epoch [ 15]: Loss 0.02123
Epoch [ 15]: Loss 0.02003
Epoch [ 15]: Loss 0.01922
Epoch [ 15]: Loss 0.01688
Epoch [ 15]: Loss 0.01934
Epoch [ 15]: Loss 0.01938
Validation: Loss 0.01930 Accuracy 1.00000
Validation: Loss 0.01876 Accuracy 1.00000
Epoch [ 16]: Loss 0.01831
Epoch [ 16]: Loss 0.01870
Epoch [ 16]: Loss 0.01893
Epoch [ 16]: Loss 0.01642
Epoch [ 16]: Loss 0.01874
Epoch [ 16]: Loss 0.01518
Epoch [ 16]: Loss 0.01637
Validation: Loss 0.01755 Accuracy 1.00000
Validation: Loss 0.01705 Accuracy 1.00000
Epoch [ 17]: Loss 0.01585
Epoch [ 17]: Loss 0.01703
Epoch [ 17]: Loss 0.01605
Epoch [ 17]: Loss 0.01689
Epoch [ 17]: Loss 0.01519
Epoch [ 17]: Loss 0.01632
Epoch [ 17]: Loss 0.01263
Validation: Loss 0.01607 Accuracy 1.00000
Validation: Loss 0.01560 Accuracy 1.00000
Epoch [ 18]: Loss 0.01463
Epoch [ 18]: Loss 0.01474
Epoch [ 18]: Loss 0.01555
Epoch [ 18]: Loss 0.01552
Epoch [ 18]: Loss 0.01323
Epoch [ 18]: Loss 0.01460
Epoch [ 18]: Loss 0.01537
Validation: Loss 0.01481 Accuracy 1.00000
Validation: Loss 0.01437 Accuracy 1.00000
Epoch [ 19]: Loss 0.01359
Epoch [ 19]: Loss 0.01384
Epoch [ 19]: Loss 0.01407
Epoch [ 19]: Loss 0.01471
Epoch [ 19]: Loss 0.01259
Epoch [ 19]: Loss 0.01285
Epoch [ 19]: Loss 0.01320
Validation: Loss 0.01370 Accuracy 1.00000
Validation: Loss 0.01329 Accuracy 1.00000
Epoch [ 20]: Loss 0.01219
Epoch [ 20]: Loss 0.01367
Epoch [ 20]: Loss 0.01242
Epoch [ 20]: Loss 0.01106
Epoch [ 20]: Loss 0.01385
Epoch [ 20]: Loss 0.01226
Epoch [ 20]: Loss 0.01271
Validation: Loss 0.01271 Accuracy 1.00000
Validation: Loss 0.01232 Accuracy 1.00000
Epoch [ 21]: Loss 0.01389
Epoch [ 21]: Loss 0.01082
Epoch [ 21]: Loss 0.01174
Epoch [ 21]: Loss 0.01149
Epoch [ 21]: Loss 0.01196
Epoch [ 21]: Loss 0.01058
Epoch [ 21]: Loss 0.01002
Validation: Loss 0.01178 Accuracy 1.00000
Validation: Loss 0.01141 Accuracy 1.00000
Epoch [ 22]: Loss 0.01154
Epoch [ 22]: Loss 0.01109
Epoch [ 22]: Loss 0.00972
Epoch [ 22]: Loss 0.01085
Epoch [ 22]: Loss 0.01176
Epoch [ 22]: Loss 0.00977
Epoch [ 22]: Loss 0.01088
Validation: Loss 0.01083 Accuracy 1.00000
Validation: Loss 0.01049 Accuracy 1.00000
Epoch [ 23]: Loss 0.00982
Epoch [ 23]: Loss 0.01019
Epoch [ 23]: Loss 0.01073
Epoch [ 23]: Loss 0.01009
Epoch [ 23]: Loss 0.00958
Epoch [ 23]: Loss 0.00944
Epoch [ 23]: Loss 0.00733
Validation: Loss 0.00976 Accuracy 1.00000
Validation: Loss 0.00946 Accuracy 1.00000
Epoch [ 24]: Loss 0.00930
Epoch [ 24]: Loss 0.00864
Epoch [ 24]: Loss 0.00927
Epoch [ 24]: Loss 0.00846
Epoch [ 24]: Loss 0.00888
Epoch [ 24]: Loss 0.00867
Epoch [ 24]: Loss 0.00813
Validation: Loss 0.00868 Accuracy 1.00000
Validation: Loss 0.00842 Accuracy 1.00000
Epoch [ 25]: Loss 0.00837
Epoch [ 25]: Loss 0.00808
Epoch [ 25]: Loss 0.00761
Epoch [ 25]: Loss 0.00734
Epoch [ 25]: Loss 0.00787
Epoch [ 25]: Loss 0.00811
Epoch [ 25]: Loss 0.00782
Validation: Loss 0.00781 Accuracy 1.00000
Validation: Loss 0.00759 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (and any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
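To run inference with the loaded parameters, we can rebuild the (stateless) model structure and pair it with them; a short sketch reusing the helpers defined earlier:

# Evaluate the restored parameters on one validation batch
model = SpiralClassifier(2, 8, 1)
x, y = first(get_dataloaders()[2])  # the second element is the validation loader
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
accuracy(ŷ, y)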
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.4, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.3.3
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
Preferences:
- CUDA_Driver_jll.compat: false
1 device:
0: Quadro RTX 5000 (sm_75, 15.203 GiB / 16.000 GiB available)
This page was generated using Literate.jl.