Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and our task is to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
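
As a quick sanity check (the shapes below assume the default dataset_size=1000 and sequence_length=50; the last minibatch of an epoch may be smaller than 128):

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x), size(y)  # ((2, 50, 128), (128,))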

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please see the Container Layer documentation.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the built-in Lux blocks, namely Recurrence(LSTMCell(in_dims => hidden_dims)), instead of defining the forward pass ourselves, but let's still do it manually for the sake of illustration, as sketched below.
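
A minimal sketch of that built-in alternative (the helper name SpiralClassifierRecurrence is ours, not part of the tutorial):

julia
# `Recurrence` iterates the wrapped cell over the time dimension and, by
# default, returns only the final hidden state, which we feed to the head.
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end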

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
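
As a quick smoke test (a sketch; the shapes here are illustrative), we can set up and invoke the model on random data:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # (:lstm_cell, :classifier), auto-populated by the container layer

x = rand(Float32, 2, 50, 4)  # 2 features, 50 time steps, batch of 4
y, _ = model(x, ps, st)
size(y)  # (4,), one probability per sequence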

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
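
For reference, a minimal sketch of the numerically stable variant (assuming the logits keyword of BinaryCrossEntropyLoss; not used below) would drop sigmoid from the classifier head and pass raw logits to the loss:

julia
# Sketch only: pair this with Dense(hidden_dims => out_dims) (no sigmoid);
# the loss then applies the sigmoid internally in a numerically stable way.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))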

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
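
For example, with the 0.5 threshold above:

julia
accuracy(Float32[0.9, 0.3, 0.7, 0.2], Float32[1, 0, 0, 0])  # 0.75 (3 of 4 correct)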

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62971
Epoch [  1]: Loss 0.59257
Epoch [  1]: Loss 0.57339
Epoch [  1]: Loss 0.53836
Epoch [  1]: Loss 0.51698
Epoch [  1]: Loss 0.49490
Epoch [  1]: Loss 0.48043
Validation: Loss 0.46506 Accuracy 1.00000
Validation: Loss 0.47192 Accuracy 1.00000
Epoch [  2]: Loss 0.47233
Epoch [  2]: Loss 0.45569
Epoch [  2]: Loss 0.43496
Epoch [  2]: Loss 0.41765
Epoch [  2]: Loss 0.41143
Epoch [  2]: Loss 0.40128
Epoch [  2]: Loss 0.40212
Validation: Loss 0.36659 Accuracy 1.00000
Validation: Loss 0.37494 Accuracy 1.00000
Epoch [  3]: Loss 0.36967
Epoch [  3]: Loss 0.36134
Epoch [  3]: Loss 0.36210
Epoch [  3]: Loss 0.31757
Epoch [  3]: Loss 0.31068
Epoch [  3]: Loss 0.31522
Epoch [  3]: Loss 0.29133
Validation: Loss 0.28104 Accuracy 1.00000
Validation: Loss 0.29017 Accuracy 1.00000
Epoch [  4]: Loss 0.29200
Epoch [  4]: Loss 0.27371
Epoch [  4]: Loss 0.26770
Epoch [  4]: Loss 0.25882
Epoch [  4]: Loss 0.23682
Epoch [  4]: Loss 0.22155
Epoch [  4]: Loss 0.24595
Validation: Loss 0.21095 Accuracy 1.00000
Validation: Loss 0.22010 Accuracy 1.00000
Epoch [  5]: Loss 0.20887
Epoch [  5]: Loss 0.20335
Epoch [  5]: Loss 0.19391
Epoch [  5]: Loss 0.19752
Epoch [  5]: Loss 0.17516
Epoch [  5]: Loss 0.17962
Epoch [  5]: Loss 0.19653
Validation: Loss 0.15569 Accuracy 1.00000
Validation: Loss 0.16394 Accuracy 1.00000
Epoch [  6]: Loss 0.15946
Epoch [  6]: Loss 0.15349
Epoch [  6]: Loss 0.15102
Epoch [  6]: Loss 0.12940
Epoch [  6]: Loss 0.14021
Epoch [  6]: Loss 0.12752
Epoch [  6]: Loss 0.11737
Validation: Loss 0.11334 Accuracy 1.00000
Validation: Loss 0.11994 Accuracy 1.00000
Epoch [  7]: Loss 0.12843
Epoch [  7]: Loss 0.11030
Epoch [  7]: Loss 0.09183
Epoch [  7]: Loss 0.09374
Epoch [  7]: Loss 0.10322
Epoch [  7]: Loss 0.09652
Epoch [  7]: Loss 0.08432
Validation: Loss 0.08088 Accuracy 1.00000
Validation: Loss 0.08567 Accuracy 1.00000
Epoch [  8]: Loss 0.08357
Epoch [  8]: Loss 0.07563
Epoch [  8]: Loss 0.07285
Epoch [  8]: Loss 0.07308
Epoch [  8]: Loss 0.06987
Epoch [  8]: Loss 0.06369
Epoch [  8]: Loss 0.06499
Validation: Loss 0.05648 Accuracy 1.00000
Validation: Loss 0.05968 Accuracy 1.00000
Epoch [  9]: Loss 0.06047
Epoch [  9]: Loss 0.05518
Epoch [  9]: Loss 0.05308
Epoch [  9]: Loss 0.04792
Epoch [  9]: Loss 0.05168
Epoch [  9]: Loss 0.04425
Epoch [  9]: Loss 0.04354
Validation: Loss 0.04230 Accuracy 1.00000
Validation: Loss 0.04463 Accuracy 1.00000
Epoch [ 10]: Loss 0.04263
Epoch [ 10]: Loss 0.04192
Epoch [ 10]: Loss 0.03763
Epoch [ 10]: Loss 0.04138
Epoch [ 10]: Loss 0.04109
Epoch [ 10]: Loss 0.03844
Epoch [ 10]: Loss 0.02794
Validation: Loss 0.03440 Accuracy 1.00000
Validation: Loss 0.03633 Accuracy 1.00000
Epoch [ 11]: Loss 0.03736
Epoch [ 11]: Loss 0.03441
Epoch [ 11]: Loss 0.03303
Epoch [ 11]: Loss 0.03283
Epoch [ 11]: Loss 0.03098
Epoch [ 11]: Loss 0.03135
Epoch [ 11]: Loss 0.03072
Validation: Loss 0.02928 Accuracy 1.00000
Validation: Loss 0.03097 Accuracy 1.00000
Epoch [ 12]: Loss 0.03057
Epoch [ 12]: Loss 0.03019
Epoch [ 12]: Loss 0.02785
Epoch [ 12]: Loss 0.03096
Epoch [ 12]: Loss 0.02820
Epoch [ 12]: Loss 0.02510
Epoch [ 12]: Loss 0.02299
Validation: Loss 0.02552 Accuracy 1.00000
Validation: Loss 0.02703 Accuracy 1.00000
Epoch [ 13]: Loss 0.02736
Epoch [ 13]: Loss 0.02528
Epoch [ 13]: Loss 0.02305
Epoch [ 13]: Loss 0.02419
Epoch [ 13]: Loss 0.02659
Epoch [ 13]: Loss 0.02531
Epoch [ 13]: Loss 0.01943
Validation: Loss 0.02262 Accuracy 1.00000
Validation: Loss 0.02398 Accuracy 1.00000
Epoch [ 14]: Loss 0.02304
Epoch [ 14]: Loss 0.02327
Epoch [ 14]: Loss 0.02130
Epoch [ 14]: Loss 0.02117
Epoch [ 14]: Loss 0.02221
Epoch [ 14]: Loss 0.02344
Epoch [ 14]: Loss 0.02041
Validation: Loss 0.02029 Accuracy 1.00000
Validation: Loss 0.02153 Accuracy 1.00000
Epoch [ 15]: Loss 0.02125
Epoch [ 15]: Loss 0.01979
Epoch [ 15]: Loss 0.02100
Epoch [ 15]: Loss 0.01977
Epoch [ 15]: Loss 0.02040
Epoch [ 15]: Loss 0.01939
Epoch [ 15]: Loss 0.01631
Validation: Loss 0.01833 Accuracy 1.00000
Validation: Loss 0.01949 Accuracy 1.00000
Epoch [ 16]: Loss 0.01886
Epoch [ 16]: Loss 0.01843
Epoch [ 16]: Loss 0.01784
Epoch [ 16]: Loss 0.01743
Epoch [ 16]: Loss 0.01837
Epoch [ 16]: Loss 0.01897
Epoch [ 16]: Loss 0.01607
Validation: Loss 0.01668 Accuracy 1.00000
Validation: Loss 0.01775 Accuracy 1.00000
Epoch [ 17]: Loss 0.01815
Epoch [ 17]: Loss 0.01744
Epoch [ 17]: Loss 0.01513
Epoch [ 17]: Loss 0.01683
Epoch [ 17]: Loss 0.01694
Epoch [ 17]: Loss 0.01578
Epoch [ 17]: Loss 0.01466
Validation: Loss 0.01525 Accuracy 1.00000
Validation: Loss 0.01625 Accuracy 1.00000
Epoch [ 18]: Loss 0.01646
Epoch [ 18]: Loss 0.01507
Epoch [ 18]: Loss 0.01578
Epoch [ 18]: Loss 0.01484
Epoch [ 18]: Loss 0.01481
Epoch [ 18]: Loss 0.01464
Epoch [ 18]: Loss 0.01462
Validation: Loss 0.01401 Accuracy 1.00000
Validation: Loss 0.01494 Accuracy 1.00000
Epoch [ 19]: Loss 0.01330
Epoch [ 19]: Loss 0.01596
Epoch [ 19]: Loss 0.01493
Epoch [ 19]: Loss 0.01435
Epoch [ 19]: Loss 0.01391
Epoch [ 19]: Loss 0.01227
Epoch [ 19]: Loss 0.01185
Validation: Loss 0.01292 Accuracy 1.00000
Validation: Loss 0.01378 Accuracy 1.00000
Epoch [ 20]: Loss 0.01224
Epoch [ 20]: Loss 0.01354
Epoch [ 20]: Loss 0.01341
Epoch [ 20]: Loss 0.01274
Epoch [ 20]: Loss 0.01209
Epoch [ 20]: Loss 0.01357
Epoch [ 20]: Loss 0.01296
Validation: Loss 0.01193 Accuracy 1.00000
Validation: Loss 0.01273 Accuracy 1.00000
Epoch [ 21]: Loss 0.01229
Epoch [ 21]: Loss 0.01166
Epoch [ 21]: Loss 0.01127
Epoch [ 21]: Loss 0.01350
Epoch [ 21]: Loss 0.01101
Epoch [ 21]: Loss 0.01225
Epoch [ 21]: Loss 0.01006
Validation: Loss 0.01094 Accuracy 1.00000
Validation: Loss 0.01167 Accuracy 1.00000
Epoch [ 22]: Loss 0.01238
Epoch [ 22]: Loss 0.01066
Epoch [ 22]: Loss 0.01061
Epoch [ 22]: Loss 0.00984
Epoch [ 22]: Loss 0.01075
Epoch [ 22]: Loss 0.01099
Epoch [ 22]: Loss 0.01069
Validation: Loss 0.00986 Accuracy 1.00000
Validation: Loss 0.01050 Accuracy 1.00000
Epoch [ 23]: Loss 0.01043
Epoch [ 23]: Loss 0.01024
Epoch [ 23]: Loss 0.01026
Epoch [ 23]: Loss 0.00960
Epoch [ 23]: Loss 0.00819
Epoch [ 23]: Loss 0.00946
Epoch [ 23]: Loss 0.00979
Validation: Loss 0.00872 Accuracy 1.00000
Validation: Loss 0.00927 Accuracy 1.00000
Epoch [ 24]: Loss 0.00895
Epoch [ 24]: Loss 0.00905
Epoch [ 24]: Loss 0.00819
Epoch [ 24]: Loss 0.00869
Epoch [ 24]: Loss 0.00837
Epoch [ 24]: Loss 0.00819
Epoch [ 24]: Loss 0.00906
Validation: Loss 0.00782 Accuracy 1.00000
Validation: Loss 0.00830 Accuracy 1.00000
Epoch [ 25]: Loss 0.00774
Epoch [ 25]: Loss 0.00714
Epoch [ 25]: Loss 0.00818
Epoch [ 25]: Loss 0.00807
Epoch [ 25]: Loss 0.00790
Epoch [ 25]: Loss 0.00774
Epoch [ 25]: Loss 0.00719
Validation: Loss 0.00718 Accuracy 1.00000
Validation: Loss 0.00761 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.63452
Epoch [  1]: Loss 0.58475
Epoch [  1]: Loss 0.56772
Epoch [  1]: Loss 0.55285
Epoch [  1]: Loss 0.51931
Epoch [  1]: Loss 0.49952
Epoch [  1]: Loss 0.48290
Validation: Loss 0.46231 Accuracy 1.00000
Validation: Loss 0.46178 Accuracy 1.00000
Epoch [  2]: Loss 0.47137
Epoch [  2]: Loss 0.45927
Epoch [  2]: Loss 0.44311
Epoch [  2]: Loss 0.41728
Epoch [  2]: Loss 0.41874
Epoch [  2]: Loss 0.40626
Epoch [  2]: Loss 0.38031
Validation: Loss 0.36362 Accuracy 1.00000
Validation: Loss 0.36289 Accuracy 1.00000
Epoch [  3]: Loss 0.38143
Epoch [  3]: Loss 0.36440
Epoch [  3]: Loss 0.35107
Epoch [  3]: Loss 0.33287
Epoch [  3]: Loss 0.32499
Epoch [  3]: Loss 0.30182
Epoch [  3]: Loss 0.28698
Validation: Loss 0.27742 Accuracy 1.00000
Validation: Loss 0.27669 Accuracy 1.00000
Epoch [  4]: Loss 0.29334
Epoch [  4]: Loss 0.28473
Epoch [  4]: Loss 0.26456
Epoch [  4]: Loss 0.25983
Epoch [  4]: Loss 0.24079
Epoch [  4]: Loss 0.23357
Epoch [  4]: Loss 0.21942
Validation: Loss 0.20696 Accuracy 1.00000
Validation: Loss 0.20640 Accuracy 1.00000
Epoch [  5]: Loss 0.21394
Epoch [  5]: Loss 0.20307
Epoch [  5]: Loss 0.20392
Epoch [  5]: Loss 0.19862
Epoch [  5]: Loss 0.18408
Epoch [  5]: Loss 0.17642
Epoch [  5]: Loss 0.18840
Validation: Loss 0.15176 Accuracy 1.00000
Validation: Loss 0.15144 Accuracy 1.00000
Epoch [  6]: Loss 0.16435
Epoch [  6]: Loss 0.14796
Epoch [  6]: Loss 0.15781
Epoch [  6]: Loss 0.13945
Epoch [  6]: Loss 0.14118
Epoch [  6]: Loss 0.12911
Epoch [  6]: Loss 0.11438
Validation: Loss 0.11011 Accuracy 1.00000
Validation: Loss 0.11000 Accuracy 1.00000
Epoch [  7]: Loss 0.12858
Epoch [  7]: Loss 0.11691
Epoch [  7]: Loss 0.11066
Epoch [  7]: Loss 0.09406
Epoch [  7]: Loss 0.09806
Epoch [  7]: Loss 0.09013
Epoch [  7]: Loss 0.09469
Validation: Loss 0.07858 Accuracy 1.00000
Validation: Loss 0.07857 Accuracy 1.00000
Epoch [  8]: Loss 0.08926
Epoch [  8]: Loss 0.08194
Epoch [  8]: Loss 0.07845
Epoch [  8]: Loss 0.07024
Epoch [  8]: Loss 0.06963
Epoch [  8]: Loss 0.06261
Epoch [  8]: Loss 0.06569
Validation: Loss 0.05492 Accuracy 1.00000
Validation: Loss 0.05490 Accuracy 1.00000
Epoch [  9]: Loss 0.05588
Epoch [  9]: Loss 0.06106
Epoch [  9]: Loss 0.05314
Epoch [  9]: Loss 0.05496
Epoch [  9]: Loss 0.04832
Epoch [  9]: Loss 0.04531
Epoch [  9]: Loss 0.04504
Validation: Loss 0.04101 Accuracy 1.00000
Validation: Loss 0.04096 Accuracy 1.00000
Epoch [ 10]: Loss 0.04245
Epoch [ 10]: Loss 0.04303
Epoch [ 10]: Loss 0.04039
Epoch [ 10]: Loss 0.04299
Epoch [ 10]: Loss 0.03709
Epoch [ 10]: Loss 0.03847
Epoch [ 10]: Loss 0.03740
Validation: Loss 0.03325 Accuracy 1.00000
Validation: Loss 0.03321 Accuracy 1.00000
Epoch [ 11]: Loss 0.03502
Epoch [ 11]: Loss 0.03638
Epoch [ 11]: Loss 0.03548
Epoch [ 11]: Loss 0.03220
Epoch [ 11]: Loss 0.03175
Epoch [ 11]: Loss 0.03207
Epoch [ 11]: Loss 0.02855
Validation: Loss 0.02820 Accuracy 1.00000
Validation: Loss 0.02816 Accuracy 1.00000
Epoch [ 12]: Loss 0.03196
Epoch [ 12]: Loss 0.03134
Epoch [ 12]: Loss 0.02825
Epoch [ 12]: Loss 0.03009
Epoch [ 12]: Loss 0.02764
Epoch [ 12]: Loss 0.02594
Epoch [ 12]: Loss 0.02112
Validation: Loss 0.02453 Accuracy 1.00000
Validation: Loss 0.02450 Accuracy 1.00000
Epoch [ 13]: Loss 0.02542
Epoch [ 13]: Loss 0.02675
Epoch [ 13]: Loss 0.02781
Epoch [ 13]: Loss 0.02498
Epoch [ 13]: Loss 0.02497
Epoch [ 13]: Loss 0.02328
Epoch [ 13]: Loss 0.02023
Validation: Loss 0.02172 Accuracy 1.00000
Validation: Loss 0.02169 Accuracy 1.00000
Epoch [ 14]: Loss 0.02173
Epoch [ 14]: Loss 0.02518
Epoch [ 14]: Loss 0.02282
Epoch [ 14]: Loss 0.02205
Epoch [ 14]: Loss 0.02231
Epoch [ 14]: Loss 0.02158
Epoch [ 14]: Loss 0.02110
Validation: Loss 0.01946 Accuracy 1.00000
Validation: Loss 0.01943 Accuracy 1.00000
Epoch [ 15]: Loss 0.02132
Epoch [ 15]: Loss 0.02250
Epoch [ 15]: Loss 0.01980
Epoch [ 15]: Loss 0.02040
Epoch [ 15]: Loss 0.01835
Epoch [ 15]: Loss 0.01931
Epoch [ 15]: Loss 0.02113
Validation: Loss 0.01757 Accuracy 1.00000
Validation: Loss 0.01755 Accuracy 1.00000
Epoch [ 16]: Loss 0.01860
Epoch [ 16]: Loss 0.01819
Epoch [ 16]: Loss 0.02129
Epoch [ 16]: Loss 0.01852
Epoch [ 16]: Loss 0.01811
Epoch [ 16]: Loss 0.01542
Epoch [ 16]: Loss 0.01988
Validation: Loss 0.01596 Accuracy 1.00000
Validation: Loss 0.01595 Accuracy 1.00000
Epoch [ 17]: Loss 0.01596
Epoch [ 17]: Loss 0.01624
Epoch [ 17]: Loss 0.01804
Epoch [ 17]: Loss 0.01797
Epoch [ 17]: Loss 0.01623
Epoch [ 17]: Loss 0.01677
Epoch [ 17]: Loss 0.01463
Validation: Loss 0.01458 Accuracy 1.00000
Validation: Loss 0.01456 Accuracy 1.00000
Epoch [ 18]: Loss 0.01502
Epoch [ 18]: Loss 0.01573
Epoch [ 18]: Loss 0.01539
Epoch [ 18]: Loss 0.01505
Epoch [ 18]: Loss 0.01517
Epoch [ 18]: Loss 0.01612
Epoch [ 18]: Loss 0.01453
Validation: Loss 0.01338 Accuracy 1.00000
Validation: Loss 0.01337 Accuracy 1.00000
Epoch [ 19]: Loss 0.01422
Epoch [ 19]: Loss 0.01446
Epoch [ 19]: Loss 0.01388
Epoch [ 19]: Loss 0.01454
Epoch [ 19]: Loss 0.01439
Epoch [ 19]: Loss 0.01395
Epoch [ 19]: Loss 0.01237
Validation: Loss 0.01233 Accuracy 1.00000
Validation: Loss 0.01232 Accuracy 1.00000
Epoch [ 20]: Loss 0.01246
Epoch [ 20]: Loss 0.01311
Epoch [ 20]: Loss 0.01372
Epoch [ 20]: Loss 0.01245
Epoch [ 20]: Loss 0.01429
Epoch [ 20]: Loss 0.01278
Epoch [ 20]: Loss 0.01194
Validation: Loss 0.01141 Accuracy 1.00000
Validation: Loss 0.01140 Accuracy 1.00000
Epoch [ 21]: Loss 0.01203
Epoch [ 21]: Loss 0.01358
Epoch [ 21]: Loss 0.01105
Epoch [ 21]: Loss 0.01152
Epoch [ 21]: Loss 0.01150
Epoch [ 21]: Loss 0.01194
Epoch [ 21]: Loss 0.01649
Validation: Loss 0.01055 Accuracy 1.00000
Validation: Loss 0.01055 Accuracy 1.00000
Epoch [ 22]: Loss 0.01181
Epoch [ 22]: Loss 0.01111
Epoch [ 22]: Loss 0.01194
Epoch [ 22]: Loss 0.01073
Epoch [ 22]: Loss 0.01155
Epoch [ 22]: Loss 0.01064
Epoch [ 22]: Loss 0.00839
Validation: Loss 0.00965 Accuracy 1.00000
Validation: Loss 0.00965 Accuracy 1.00000
Epoch [ 23]: Loss 0.01037
Epoch [ 23]: Loss 0.01053
Epoch [ 23]: Loss 0.01029
Epoch [ 23]: Loss 0.00974
Epoch [ 23]: Loss 0.01024
Epoch [ 23]: Loss 0.01018
Epoch [ 23]: Loss 0.00839
Validation: Loss 0.00867 Accuracy 1.00000
Validation: Loss 0.00868 Accuracy 1.00000
Epoch [ 24]: Loss 0.00984
Epoch [ 24]: Loss 0.00872
Epoch [ 24]: Loss 0.00950
Epoch [ 24]: Loss 0.00874
Epoch [ 24]: Loss 0.00870
Epoch [ 24]: Loss 0.00868
Epoch [ 24]: Loss 0.00934
Validation: Loss 0.00771 Accuracy 1.00000
Validation: Loss 0.00771 Accuracy 1.00000
Epoch [ 25]: Loss 0.00840
Epoch [ 25]: Loss 0.00835
Epoch [ 25]: Loss 0.00836
Epoch [ 25]: Loss 0.00908
Epoch [ 25]: Loss 0.00748
Epoch [ 25]: Loss 0.00703
Epoch [ 25]: Loss 0.00672
Validation: Loss 0.00697 Accuracy 1.00000
Validation: Loss 0.00697 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
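
As a quick check (a sketch with illustrative shapes), the reloaded parameters and states can be used directly for CPU inference:

julia
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 4)  # 2 features, 50 time steps, batch of 4
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))
size(ŷ)  # (4,)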

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)

This page was generated using Literate.jl.