Training a Simple LSTM
In this tutorial we will train a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, AMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and wrap the data in a MLUtils.DataLoader. Our dataloader will yield sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
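As a quick sanity check (a hypothetical snippet, not part of the training script), we can pull one batch and confirm the shapes match the description above:

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): 2 coordinates × 50 time steps × 128 sequences
size(y)  # (128,): one binary label per sequence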
Creating a Classifier
We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the recurrence by hand (see the sketch below), but let's still do it for demonstration purposes.
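For reference, here is a sketch of what that built-in version could look like (SpiralClassifierBuiltin is a hypothetical name; by default Recurrence returns only the output at the last time step, which is exactly what the classifier head needs):

function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end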
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to LSTM Cell will create the initial hidden state.
    # See that the parameters and states are automatically populated into fields called
    # `lstm_cell` and `classifier`. We use `eachslice` to get the elements in the
    # sequence without copying, and `Iterators.peel` to split out the first element
    # for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input
    # and `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
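Before training, we can smoke-test the forward pass (a hypothetical snippet; the input shape follows the dataset description above):

rng = Xoshiro(0)
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 50, 4)  # (features, seq_len, batch)
ŷ, st_ = model(x, ps, st)
size(ŷ)  # (4,): one probability per sequence in the batch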
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy; a sketch of the logit variant follows the code below.
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
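For reference, the numerically stable variant mentioned above would drop the sigmoid from the classifier head (Dense(hidden_dims => out_dims) with no activation) and compare raw logits instead. A sketch, assuming the logits keyword of BinaryCrossEntropyLoss:

const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))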
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
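For instance (hypothetical values), predictions above the 0.5 threshold count as class 1:

accuracy([0.9f0, 0.2f0, 0.7f0], [1.0f0, 0.0f0, 0.0f0])  # 2 of 3 match ≈ 0.66667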
Training the Model
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()
    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    dev = gpu_device()

    train_state = Training.TrainState(rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)

        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.55994
Epoch [ 1]: Loss 0.51047
Epoch [ 1]: Loss 0.48322
Epoch [ 1]: Loss 0.44852
Epoch [ 1]: Loss 0.41908
Epoch [ 1]: Loss 0.40029
Epoch [ 1]: Loss 0.40570
Validation: Loss 0.36805 Accuracy 1.00000
Validation: Loss 0.36347 Accuracy 1.00000
Epoch [ 2]: Loss 0.36024
Epoch [ 2]: Loss 0.36134
Epoch [ 2]: Loss 0.33199
Epoch [ 2]: Loss 0.31666
Epoch [ 2]: Loss 0.30262
Epoch [ 2]: Loss 0.28749
Epoch [ 2]: Loss 0.27449
Validation: Loss 0.25828 Accuracy 1.00000
Validation: Loss 0.25599 Accuracy 1.00000
Epoch [ 3]: Loss 0.25749
Epoch [ 3]: Loss 0.24880
Epoch [ 3]: Loss 0.23370
Epoch [ 3]: Loss 0.22128
Epoch [ 3]: Loss 0.21029
Epoch [ 3]: Loss 0.19966
Epoch [ 3]: Loss 0.19183
Validation: Loss 0.18054 Accuracy 1.00000
Validation: Loss 0.17970 Accuracy 1.00000
Epoch [ 4]: Loss 0.18303
Epoch [ 4]: Loss 0.17026
Epoch [ 4]: Loss 0.16181
Epoch [ 4]: Loss 0.15543
Epoch [ 4]: Loss 0.14966
Epoch [ 4]: Loss 0.14363
Epoch [ 4]: Loss 0.14015
Validation: Loss 0.12906 Accuracy 1.00000
Validation: Loss 0.12858 Accuracy 1.00000
Epoch [ 5]: Loss 0.12910
Epoch [ 5]: Loss 0.12285
Epoch [ 5]: Loss 0.11793
Epoch [ 5]: Loss 0.11315
Epoch [ 5]: Loss 0.10810
Epoch [ 5]: Loss 0.10394
Epoch [ 5]: Loss 0.09901
Validation: Loss 0.09403 Accuracy 1.00000
Validation: Loss 0.09352 Accuracy 1.00000
Epoch [ 6]: Loss 0.09613
Epoch [ 6]: Loss 0.09010
Epoch [ 6]: Loss 0.08369
Epoch [ 6]: Loss 0.08278
Epoch [ 6]: Loss 0.08052
Epoch [ 6]: Loss 0.07641
Epoch [ 6]: Loss 0.07043
Validation: Loss 0.06942 Accuracy 1.00000
Validation: Loss 0.06877 Accuracy 1.00000
Epoch [ 7]: Loss 0.06958
Epoch [ 7]: Loss 0.06656
Epoch [ 7]: Loss 0.06450
Epoch [ 7]: Loss 0.06088
Epoch [ 7]: Loss 0.05907
Epoch [ 7]: Loss 0.05713
Epoch [ 7]: Loss 0.05341
Validation: Loss 0.05180 Accuracy 1.00000
Validation: Loss 0.05102 Accuracy 1.00000
Epoch [ 8]: Loss 0.05416
Epoch [ 8]: Loss 0.05140
Epoch [ 8]: Loss 0.04721
Epoch [ 8]: Loss 0.04507
Epoch [ 8]: Loss 0.04158
Epoch [ 8]: Loss 0.04248
Epoch [ 8]: Loss 0.04365
Validation: Loss 0.03900 Accuracy 1.00000
Validation: Loss 0.03818 Accuracy 1.00000
Epoch [ 9]: Loss 0.03816
Epoch [ 9]: Loss 0.03725
Epoch [ 9]: Loss 0.03583
Epoch [ 9]: Loss 0.03687
Epoch [ 9]: Loss 0.03341
Epoch [ 9]: Loss 0.03238
Epoch [ 9]: Loss 0.03037
Validation: Loss 0.02984 Accuracy 1.00000
Validation: Loss 0.02904 Accuracy 1.00000
Epoch [ 10]: Loss 0.02955
Epoch [ 10]: Loss 0.02806
Epoch [ 10]: Loss 0.02834
Epoch [ 10]: Loss 0.02977
Epoch [ 10]: Loss 0.02615
Epoch [ 10]: Loss 0.02378
Epoch [ 10]: Loss 0.02252
Validation: Loss 0.02353 Accuracy 1.00000
Validation: Loss 0.02280 Accuracy 1.00000
Epoch [ 11]: Loss 0.02311
Epoch [ 11]: Loss 0.02404
Epoch [ 11]: Loss 0.02285
Epoch [ 11]: Loss 0.02123
Epoch [ 11]: Loss 0.02148
Epoch [ 11]: Loss 0.01949
Epoch [ 11]: Loss 0.01811
Validation: Loss 0.01918 Accuracy 1.00000
Validation: Loss 0.01853 Accuracy 1.00000
Epoch [ 12]: Loss 0.02028
Epoch [ 12]: Loss 0.01769
Epoch [ 12]: Loss 0.01807
Epoch [ 12]: Loss 0.01881
Epoch [ 12]: Loss 0.01758
Epoch [ 12]: Loss 0.01637
Epoch [ 12]: Loss 0.01608
Validation: Loss 0.01615 Accuracy 1.00000
Validation: Loss 0.01558 Accuracy 1.00000
Epoch [ 13]: Loss 0.01741
Epoch [ 13]: Loss 0.01555
Epoch [ 13]: Loss 0.01513
Epoch [ 13]: Loss 0.01544
Epoch [ 13]: Loss 0.01442
Epoch [ 13]: Loss 0.01464
Epoch [ 13]: Loss 0.01343
Validation: Loss 0.01396 Accuracy 1.00000
Validation: Loss 0.01346 Accuracy 1.00000
Epoch [ 14]: Loss 0.01450
Epoch [ 14]: Loss 0.01401
Epoch [ 14]: Loss 0.01375
Epoch [ 14]: Loss 0.01333
Epoch [ 14]: Loss 0.01325
Epoch [ 14]: Loss 0.01204
Epoch [ 14]: Loss 0.01107
Validation: Loss 0.01231 Accuracy 1.00000
Validation: Loss 0.01187 Accuracy 1.00000
Epoch [ 15]: Loss 0.01335
Epoch [ 15]: Loss 0.01233
Epoch [ 15]: Loss 0.01220
Epoch [ 15]: Loss 0.01182
Epoch [ 15]: Loss 0.01112
Epoch [ 15]: Loss 0.01102
Epoch [ 15]: Loss 0.00963
Validation: Loss 0.01103 Accuracy 1.00000
Validation: Loss 0.01063 Accuracy 1.00000
Epoch [ 16]: Loss 0.01145
Epoch [ 16]: Loss 0.01024
Epoch [ 16]: Loss 0.01194
Epoch [ 16]: Loss 0.00970
Epoch [ 16]: Loss 0.00988
Epoch [ 16]: Loss 0.01076
Epoch [ 16]: Loss 0.01113
Validation: Loss 0.00999 Accuracy 1.00000
Validation: Loss 0.00962 Accuracy 1.00000
Epoch [ 17]: Loss 0.01009
Epoch [ 17]: Loss 0.01004
Epoch [ 17]: Loss 0.01018
Epoch [ 17]: Loss 0.00968
Epoch [ 17]: Loss 0.00900
Epoch [ 17]: Loss 0.00924
Epoch [ 17]: Loss 0.00983
Validation: Loss 0.00912 Accuracy 1.00000
Validation: Loss 0.00878 Accuracy 1.00000
Epoch [ 18]: Loss 0.00939
Epoch [ 18]: Loss 0.00865
Epoch [ 18]: Loss 0.00864
Epoch [ 18]: Loss 0.00885
Epoch [ 18]: Loss 0.00837
Epoch [ 18]: Loss 0.00934
Epoch [ 18]: Loss 0.00899
Validation: Loss 0.00838 Accuracy 1.00000
Validation: Loss 0.00806 Accuracy 1.00000
Epoch [ 19]: Loss 0.00903
Epoch [ 19]: Loss 0.00843
Epoch [ 19]: Loss 0.00825
Epoch [ 19]: Loss 0.00742
Epoch [ 19]: Loss 0.00789
Epoch [ 19]: Loss 0.00803
Epoch [ 19]: Loss 0.00830
Validation: Loss 0.00774 Accuracy 1.00000
Validation: Loss 0.00744 Accuracy 1.00000
Epoch [ 20]: Loss 0.00799
Epoch [ 20]: Loss 0.00714
Epoch [ 20]: Loss 0.00799
Epoch [ 20]: Loss 0.00782
Epoch [ 20]: Loss 0.00714
Epoch [ 20]: Loss 0.00745
Epoch [ 20]: Loss 0.00709
Validation: Loss 0.00718 Accuracy 1.00000
Validation: Loss 0.00691 Accuracy 1.00000
Epoch [ 21]: Loss 0.00720
Epoch [ 21]: Loss 0.00719
Epoch [ 21]: Loss 0.00723
Epoch [ 21]: Loss 0.00691
Epoch [ 21]: Loss 0.00706
Epoch [ 21]: Loss 0.00679
Epoch [ 21]: Loss 0.00635
Validation: Loss 0.00669 Accuracy 1.00000
Validation: Loss 0.00644 Accuracy 1.00000
Epoch [ 22]: Loss 0.00648
Epoch [ 22]: Loss 0.00668
Epoch [ 22]: Loss 0.00643
Epoch [ 22]: Loss 0.00689
Epoch [ 22]: Loss 0.00650
Epoch [ 22]: Loss 0.00641
Epoch [ 22]: Loss 0.00651
Validation: Loss 0.00626 Accuracy 1.00000
Validation: Loss 0.00602 Accuracy 1.00000
Epoch [ 23]: Loss 0.00622
Epoch [ 23]: Loss 0.00612
Epoch [ 23]: Loss 0.00632
Epoch [ 23]: Loss 0.00596
Epoch [ 23]: Loss 0.00632
Epoch [ 23]: Loss 0.00594
Epoch [ 23]: Loss 0.00615
Validation: Loss 0.00588 Accuracy 1.00000
Validation: Loss 0.00565 Accuracy 1.00000
Epoch [ 24]: Loss 0.00574
Epoch [ 24]: Loss 0.00573
Epoch [ 24]: Loss 0.00579
Epoch [ 24]: Loss 0.00580
Epoch [ 24]: Loss 0.00586
Epoch [ 24]: Loss 0.00591
Epoch [ 24]: Loss 0.00507
Validation: Loss 0.00553 Accuracy 1.00000
Validation: Loss 0.00531 Accuracy 1.00000
Epoch [ 25]: Loss 0.00511
Epoch [ 25]: Loss 0.00567
Epoch [ 25]: Loss 0.00562
Epoch [ 25]: Loss 0.00537
Epoch [ 25]: Loss 0.00535
Epoch [ 25]: Loss 0.00552
Epoch [ 25]: Loss 0.00541
Validation: Loss 0.00522 Accuracy 1.00000
Validation: Loss 0.00501 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [ 1]: Loss 0.56316
Epoch [ 1]: Loss 0.51362
Epoch [ 1]: Loss 0.46587
Epoch [ 1]: Loss 0.44222
Epoch [ 1]: Loss 0.42381
Epoch [ 1]: Loss 0.40441
Epoch [ 1]: Loss 0.37289
Validation: Loss 0.36375 Accuracy 1.00000
Validation: Loss 0.37596 Accuracy 1.00000
Epoch [ 2]: Loss 0.36591
Epoch [ 2]: Loss 0.35794
Epoch [ 2]: Loss 0.33095
Epoch [ 2]: Loss 0.31272
Epoch [ 2]: Loss 0.29893
Epoch [ 2]: Loss 0.28205
Epoch [ 2]: Loss 0.26940
Validation: Loss 0.25448 Accuracy 1.00000
Validation: Loss 0.26255 Accuracy 1.00000
Epoch [ 3]: Loss 0.26247
Epoch [ 3]: Loss 0.24280
Epoch [ 3]: Loss 0.22428
Epoch [ 3]: Loss 0.22090
Epoch [ 3]: Loss 0.20664
Epoch [ 3]: Loss 0.19857
Epoch [ 3]: Loss 0.18688
Validation: Loss 0.17701 Accuracy 1.00000
Validation: Loss 0.18136 Accuracy 1.00000
Epoch [ 4]: Loss 0.17990
Epoch [ 4]: Loss 0.16881
Epoch [ 4]: Loss 0.15792
Epoch [ 4]: Loss 0.15347
Epoch [ 4]: Loss 0.14881
Epoch [ 4]: Loss 0.13767
Epoch [ 4]: Loss 0.13352
Validation: Loss 0.12562 Accuracy 1.00000
Validation: Loss 0.12842 Accuracy 1.00000
Epoch [ 5]: Loss 0.12649
Epoch [ 5]: Loss 0.12160
Epoch [ 5]: Loss 0.11529
Epoch [ 5]: Loss 0.10803
Epoch [ 5]: Loss 0.10416
Epoch [ 5]: Loss 0.10049
Epoch [ 5]: Loss 0.09833
Validation: Loss 0.09088 Accuracy 1.00000
Validation: Loss 0.09327 Accuracy 1.00000
Epoch [ 6]: Loss 0.09061
Epoch [ 6]: Loss 0.08727
Epoch [ 6]: Loss 0.08393
Epoch [ 6]: Loss 0.08035
Epoch [ 6]: Loss 0.07851
Epoch [ 6]: Loss 0.07324
Epoch [ 6]: Loss 0.06634
Validation: Loss 0.06677 Accuracy 1.00000
Validation: Loss 0.06906 Accuracy 1.00000
Epoch [ 7]: Loss 0.06702
Epoch [ 7]: Loss 0.06306
Epoch [ 7]: Loss 0.06158
Epoch [ 7]: Loss 0.05855
Epoch [ 7]: Loss 0.05752
Epoch [ 7]: Loss 0.05597
Epoch [ 7]: Loss 0.05555
Validation: Loss 0.04964 Accuracy 1.00000
Validation: Loss 0.05192 Accuracy 1.00000
Epoch [ 8]: Loss 0.05171
Epoch [ 8]: Loss 0.04802
Epoch [ 8]: Loss 0.04646
Epoch [ 8]: Loss 0.04366
Epoch [ 8]: Loss 0.04213
Epoch [ 8]: Loss 0.04038
Epoch [ 8]: Loss 0.04080
Validation: Loss 0.03730 Accuracy 1.00000
Validation: Loss 0.03938 Accuracy 1.00000
Epoch [ 9]: Loss 0.03813
Epoch [ 9]: Loss 0.03492
Epoch [ 9]: Loss 0.03649
Epoch [ 9]: Loss 0.03446
Epoch [ 9]: Loss 0.03215
Epoch [ 9]: Loss 0.03066
Epoch [ 9]: Loss 0.02812
Validation: Loss 0.02848 Accuracy 1.00000
Validation: Loss 0.03037 Accuracy 1.00000
Epoch [ 10]: Loss 0.02945
Epoch [ 10]: Loss 0.02742
Epoch [ 10]: Loss 0.02668
Epoch [ 10]: Loss 0.02618
Epoch [ 10]: Loss 0.02480
Epoch [ 10]: Loss 0.02522
Epoch [ 10]: Loss 0.02184
Validation: Loss 0.02240 Accuracy 1.00000
Validation: Loss 0.02410 Accuracy 1.00000
Epoch [ 11]: Loss 0.02170
Epoch [ 11]: Loss 0.02274
Epoch [ 11]: Loss 0.02153
Epoch [ 11]: Loss 0.02086
Epoch [ 11]: Loss 0.02064
Epoch [ 11]: Loss 0.01961
Epoch [ 11]: Loss 0.01914
Validation: Loss 0.01826 Accuracy 1.00000
Validation: Loss 0.01976 Accuracy 1.00000
Epoch [ 12]: Loss 0.01989
Epoch [ 12]: Loss 0.01881
Epoch [ 12]: Loss 0.01743
Epoch [ 12]: Loss 0.01728
Epoch [ 12]: Loss 0.01607
Epoch [ 12]: Loss 0.01578
Epoch [ 12]: Loss 0.01522
Validation: Loss 0.01537 Accuracy 1.00000
Validation: Loss 0.01665 Accuracy 1.00000
Epoch [ 13]: Loss 0.01595
Epoch [ 13]: Loss 0.01666
Epoch [ 13]: Loss 0.01449
Epoch [ 13]: Loss 0.01410
Epoch [ 13]: Loss 0.01408
Epoch [ 13]: Loss 0.01405
Epoch [ 13]: Loss 0.01344
Validation: Loss 0.01329 Accuracy 1.00000
Validation: Loss 0.01441 Accuracy 1.00000
Epoch [ 14]: Loss 0.01378
Epoch [ 14]: Loss 0.01252
Epoch [ 14]: Loss 0.01302
Epoch [ 14]: Loss 0.01416
Epoch [ 14]: Loss 0.01236
Epoch [ 14]: Loss 0.01199
Epoch [ 14]: Loss 0.01187
Validation: Loss 0.01173 Accuracy 1.00000
Validation: Loss 0.01274 Accuracy 1.00000
Epoch [ 15]: Loss 0.01182
Epoch [ 15]: Loss 0.01199
Epoch [ 15]: Loss 0.01163
Epoch [ 15]: Loss 0.01164
Epoch [ 15]: Loss 0.01108
Epoch [ 15]: Loss 0.01065
Epoch [ 15]: Loss 0.01178
Validation: Loss 0.01051 Accuracy 1.00000
Validation: Loss 0.01142 Accuracy 1.00000
Epoch [ 16]: Loss 0.01028
Epoch [ 16]: Loss 0.01061
Epoch [ 16]: Loss 0.01089
Epoch [ 16]: Loss 0.01032
Epoch [ 16]: Loss 0.01043
Epoch [ 16]: Loss 0.00973
Epoch [ 16]: Loss 0.00928
Validation: Loss 0.00952 Accuracy 1.00000
Validation: Loss 0.01035 Accuracy 1.00000
Epoch [ 17]: Loss 0.01009
Epoch [ 17]: Loss 0.00878
Epoch [ 17]: Loss 0.00914
Epoch [ 17]: Loss 0.00925
Epoch [ 17]: Loss 0.00957
Epoch [ 17]: Loss 0.00959
Epoch [ 17]: Loss 0.00901
Validation: Loss 0.00870 Accuracy 1.00000
Validation: Loss 0.00946 Accuracy 1.00000
Epoch [ 18]: Loss 0.00903
Epoch [ 18]: Loss 0.00911
Epoch [ 18]: Loss 0.00830
Epoch [ 18]: Loss 0.00894
Epoch [ 18]: Loss 0.00794
Epoch [ 18]: Loss 0.00844
Epoch [ 18]: Loss 0.00805
Validation: Loss 0.00800 Accuracy 1.00000
Validation: Loss 0.00870 Accuracy 1.00000
Epoch [ 19]: Loss 0.00835
Epoch [ 19]: Loss 0.00813
Epoch [ 19]: Loss 0.00747
Epoch [ 19]: Loss 0.00764
Epoch [ 19]: Loss 0.00796
Epoch [ 19]: Loss 0.00796
Epoch [ 19]: Loss 0.00811
Validation: Loss 0.00739 Accuracy 1.00000
Validation: Loss 0.00805 Accuracy 1.00000
Epoch [ 20]: Loss 0.00702
Epoch [ 20]: Loss 0.00693
Epoch [ 20]: Loss 0.00786
Epoch [ 20]: Loss 0.00733
Epoch [ 20]: Loss 0.00757
Epoch [ 20]: Loss 0.00738
Epoch [ 20]: Loss 0.00711
Validation: Loss 0.00687 Accuracy 1.00000
Validation: Loss 0.00748 Accuracy 1.00000
Epoch [ 21]: Loss 0.00703
Epoch [ 21]: Loss 0.00702
Epoch [ 21]: Loss 0.00724
Epoch [ 21]: Loss 0.00727
Epoch [ 21]: Loss 0.00655
Epoch [ 21]: Loss 0.00613
Epoch [ 21]: Loss 0.00582
Validation: Loss 0.00640 Accuracy 1.00000
Validation: Loss 0.00698 Accuracy 1.00000
Epoch [ 22]: Loss 0.00627
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00663
Epoch [ 22]: Loss 0.00662
Epoch [ 22]: Loss 0.00634
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00671
Validation: Loss 0.00599 Accuracy 1.00000
Validation: Loss 0.00654 Accuracy 1.00000
Epoch [ 23]: Loss 0.00605
Epoch [ 23]: Loss 0.00592
Epoch [ 23]: Loss 0.00615
Epoch [ 23]: Loss 0.00598
Epoch [ 23]: Loss 0.00615
Epoch [ 23]: Loss 0.00559
Epoch [ 23]: Loss 0.00600
Validation: Loss 0.00562 Accuracy 1.00000
Validation: Loss 0.00614 Accuracy 1.00000
Epoch [ 24]: Loss 0.00575
Epoch [ 24]: Loss 0.00558
Epoch [ 24]: Loss 0.00621
Epoch [ 24]: Loss 0.00531
Epoch [ 24]: Loss 0.00548
Epoch [ 24]: Loss 0.00552
Epoch [ 24]: Loss 0.00496
Validation: Loss 0.00529 Accuracy 1.00000
Validation: Loss 0.00578 Accuracy 1.00000
Epoch [ 25]: Loss 0.00526
Epoch [ 25]: Loss 0.00550
Epoch [ 25]: Loss 0.00500
Epoch [ 25]: Loss 0.00541
Epoch [ 25]: Loss 0.00527
Epoch [ 25]: Loss 0.00532
Epoch [ 25]: Loss 0.00514
Validation: Loss 0.00499 Accuracy 1.00000
Validation: Loss 0.00545 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself and only save the parameters and states.
@save "trained_model.jld2" {compress = true} ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
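Since only the parameters and states were saved, running inference after loading requires rebuilding the model struct first (a hypothetical snippet):

model = SpiralClassifier(2, 8, 1)
x = rand(Xoshiro(0), Float32, 2, 50, 1)
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))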
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxDeviceUtils)
    if @isdefined(CUDA) && LuxDeviceUtils.functional(LuxCUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && LuxDeviceUtils.functional(LuxAMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 4 default, 0 interactive, 2 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 4
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0
Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.484 GiB / 4.750 GiB available)
This page was generated using Literate.jl.