
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals and wrap them in an MLUtils.DataLoader. The dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
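
As a quick sanity check of the 2 × seq_len × batch_size layout described above, we could inspect the first batch. This is a sketch only and not part of the training pipeline:

julia
# Sketch only: peek at the first training batch to verify the layout described above.
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128) -- features × sequence length × batch size
size(y)  # (128,) -- one binary label per sequence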

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
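
Since we passed the field names to AbstractLuxContainerLayer, Lux.setup will nest the parameters and states under :lstm_cell and :classifier. A quick sketch of what that looks like (the shapes of the individual arrays are omitted):

julia
# Sketch: parameters and states are nested under the container's field names.
rng = Xoshiro(0)
ps, st = Lux.setup(rng, SpiralClassifier(2, 8, 1))
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)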

We could instead use the built-in Lux blocks, Recurrence(LSTMCell(in_dims => hidden_dims)), rather than defining the forward pass ourselves (see the sketch below), but we will write it out for illustration.

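For reference, here is a sketch of how the same architecture could be assembled from those built-in blocks. The exact composition, including the trailing vec, is an assumption and is not used elsewhere in this tutorial:

julia
# Sketch only: an equivalent classifier assembled from built-in blocks.
# Recurrence runs the cell over the time dimension (dim 2 here) and, by default,
# returns only the output at the final time step.
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec)  # flatten the 1 × batch_size output, mirroring `vec(y)` in the forward pass below
end
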
Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
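
If we wanted the numerically stabler logit formulation mentioned above, the changes would be: drop the sigmoid from the classifier's Dense layer so the model outputs raw logits, construct the loss with logits=Val(true), and apply a sigmoid before thresholding for accuracy. A sketch of that variant, which is not used in the rest of the tutorial:

julia
# Sketch: logit-based loss (assumes the classifier head is Dense(hidden_dims => out_dims)
# without an activation, so the model outputs raw logits).
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))
logit_matches(y_pred, y_true) = sum((sigmoid.(y_pred) .> 0.5f0) .== y_true)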

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62119
Epoch [  1]: Loss 0.59574
Epoch [  1]: Loss 0.55560
Epoch [  1]: Loss 0.54413
Epoch [  1]: Loss 0.52242
Epoch [  1]: Loss 0.49230
Epoch [  1]: Loss 0.50040
Validation: Loss 0.47404 Accuracy 1.00000
Validation: Loss 0.46642 Accuracy 1.00000
Epoch [  2]: Loss 0.47304
Epoch [  2]: Loss 0.44008
Epoch [  2]: Loss 0.43130
Epoch [  2]: Loss 0.43225
Epoch [  2]: Loss 0.42277
Epoch [  2]: Loss 0.39227
Epoch [  2]: Loss 0.40997
Validation: Loss 0.37889 Accuracy 1.00000
Validation: Loss 0.36971 Accuracy 1.00000
Epoch [  3]: Loss 0.37536
Epoch [  3]: Loss 0.36507
Epoch [  3]: Loss 0.34466
Epoch [  3]: Loss 0.32613
Epoch [  3]: Loss 0.32422
Epoch [  3]: Loss 0.30613
Epoch [  3]: Loss 0.30507
Validation: Loss 0.29538 Accuracy 1.00000
Validation: Loss 0.28527 Accuracy 1.00000
Epoch [  4]: Loss 0.28244
Epoch [  4]: Loss 0.27678
Epoch [  4]: Loss 0.27830
Epoch [  4]: Loss 0.26213
Epoch [  4]: Loss 0.23831
Epoch [  4]: Loss 0.23514
Epoch [  4]: Loss 0.21314
Validation: Loss 0.22583 Accuracy 1.00000
Validation: Loss 0.21560 Accuracy 1.00000
Epoch [  5]: Loss 0.21782
Epoch [  5]: Loss 0.21827
Epoch [  5]: Loss 0.20122
Epoch [  5]: Loss 0.18831
Epoch [  5]: Loss 0.19369
Epoch [  5]: Loss 0.16890
Epoch [  5]: Loss 0.14980
Validation: Loss 0.16984 Accuracy 1.00000
Validation: Loss 0.16043 Accuracy 1.00000
Epoch [  6]: Loss 0.16845
Epoch [  6]: Loss 0.15413
Epoch [  6]: Loss 0.15137
Epoch [  6]: Loss 0.14031
Epoch [  6]: Loss 0.12981
Epoch [  6]: Loss 0.13301
Epoch [  6]: Loss 0.13494
Validation: Loss 0.12594 Accuracy 1.00000
Validation: Loss 0.11818 Accuracy 1.00000
Epoch [  7]: Loss 0.11420
Epoch [  7]: Loss 0.11860
Epoch [  7]: Loss 0.10908
Epoch [  7]: Loss 0.10439
Epoch [  7]: Loss 0.10327
Epoch [  7]: Loss 0.09212
Epoch [  7]: Loss 0.09798
Validation: Loss 0.09058 Accuracy 1.00000
Validation: Loss 0.08486 Accuracy 1.00000
Epoch [  8]: Loss 0.08829
Epoch [  8]: Loss 0.08023
Epoch [  8]: Loss 0.07346
Epoch [  8]: Loss 0.07633
Epoch [  8]: Loss 0.07117
Epoch [  8]: Loss 0.06883
Epoch [  8]: Loss 0.05656
Validation: Loss 0.06288 Accuracy 1.00000
Validation: Loss 0.05908 Accuracy 1.00000
Epoch [  9]: Loss 0.05933
Epoch [  9]: Loss 0.05994
Epoch [  9]: Loss 0.05380
Epoch [  9]: Loss 0.05270
Epoch [  9]: Loss 0.05178
Epoch [  9]: Loss 0.04477
Epoch [  9]: Loss 0.04302
Validation: Loss 0.04644 Accuracy 1.00000
Validation: Loss 0.04378 Accuracy 1.00000
Epoch [ 10]: Loss 0.04801
Epoch [ 10]: Loss 0.04085
Epoch [ 10]: Loss 0.03921
Epoch [ 10]: Loss 0.04068
Epoch [ 10]: Loss 0.03650
Epoch [ 10]: Loss 0.04000
Epoch [ 10]: Loss 0.03830
Validation: Loss 0.03753 Accuracy 1.00000
Validation: Loss 0.03539 Accuracy 1.00000
Epoch [ 11]: Loss 0.03724
Epoch [ 11]: Loss 0.03387
Epoch [ 11]: Loss 0.03543
Epoch [ 11]: Loss 0.03485
Epoch [ 11]: Loss 0.03291
Epoch [ 11]: Loss 0.02831
Epoch [ 11]: Loss 0.03010
Validation: Loss 0.03185 Accuracy 1.00000
Validation: Loss 0.03000 Accuracy 1.00000
Epoch [ 12]: Loss 0.02944
Epoch [ 12]: Loss 0.03070
Epoch [ 12]: Loss 0.02904
Epoch [ 12]: Loss 0.02895
Epoch [ 12]: Loss 0.02837
Epoch [ 12]: Loss 0.02685
Epoch [ 12]: Loss 0.02608
Validation: Loss 0.02774 Accuracy 1.00000
Validation: Loss 0.02609 Accuracy 1.00000
Epoch [ 13]: Loss 0.02728
Epoch [ 13]: Loss 0.02725
Epoch [ 13]: Loss 0.02421
Epoch [ 13]: Loss 0.02515
Epoch [ 13]: Loss 0.02518
Epoch [ 13]: Loss 0.02340
Epoch [ 13]: Loss 0.02039
Validation: Loss 0.02455 Accuracy 1.00000
Validation: Loss 0.02307 Accuracy 1.00000
Epoch [ 14]: Loss 0.02285
Epoch [ 14]: Loss 0.02289
Epoch [ 14]: Loss 0.02128
Epoch [ 14]: Loss 0.02238
Epoch [ 14]: Loss 0.02223
Epoch [ 14]: Loss 0.02189
Epoch [ 14]: Loss 0.02535
Validation: Loss 0.02201 Accuracy 1.00000
Validation: Loss 0.02065 Accuracy 1.00000
Epoch [ 15]: Loss 0.02104
Epoch [ 15]: Loss 0.02071
Epoch [ 15]: Loss 0.02060
Epoch [ 15]: Loss 0.01973
Epoch [ 15]: Loss 0.01851
Epoch [ 15]: Loss 0.02028
Epoch [ 15]: Loss 0.01927
Validation: Loss 0.01987 Accuracy 1.00000
Validation: Loss 0.01863 Accuracy 1.00000
Epoch [ 16]: Loss 0.01791
Epoch [ 16]: Loss 0.01925
Epoch [ 16]: Loss 0.01882
Epoch [ 16]: Loss 0.01698
Epoch [ 16]: Loss 0.01830
Epoch [ 16]: Loss 0.01788
Epoch [ 16]: Loss 0.01814
Validation: Loss 0.01808 Accuracy 1.00000
Validation: Loss 0.01692 Accuracy 1.00000
Epoch [ 17]: Loss 0.01725
Epoch [ 17]: Loss 0.01620
Epoch [ 17]: Loss 0.01844
Epoch [ 17]: Loss 0.01692
Epoch [ 17]: Loss 0.01524
Epoch [ 17]: Loss 0.01521
Epoch [ 17]: Loss 0.01724
Validation: Loss 0.01654 Accuracy 1.00000
Validation: Loss 0.01547 Accuracy 1.00000
Epoch [ 18]: Loss 0.01577
Epoch [ 18]: Loss 0.01542
Epoch [ 18]: Loss 0.01538
Epoch [ 18]: Loss 0.01565
Epoch [ 18]: Loss 0.01547
Epoch [ 18]: Loss 0.01366
Epoch [ 18]: Loss 0.01408
Validation: Loss 0.01521 Accuracy 1.00000
Validation: Loss 0.01422 Accuracy 1.00000
Epoch [ 19]: Loss 0.01494
Epoch [ 19]: Loss 0.01476
Epoch [ 19]: Loss 0.01374
Epoch [ 19]: Loss 0.01302
Epoch [ 19]: Loss 0.01413
Epoch [ 19]: Loss 0.01314
Epoch [ 19]: Loss 0.01465
Validation: Loss 0.01406 Accuracy 1.00000
Validation: Loss 0.01313 Accuracy 1.00000
Epoch [ 20]: Loss 0.01434
Epoch [ 20]: Loss 0.01379
Epoch [ 20]: Loss 0.01151
Epoch [ 20]: Loss 0.01322
Epoch [ 20]: Loss 0.01218
Epoch [ 20]: Loss 0.01230
Epoch [ 20]: Loss 0.01407
Validation: Loss 0.01303 Accuracy 1.00000
Validation: Loss 0.01217 Accuracy 1.00000
Epoch [ 21]: Loss 0.01468
Epoch [ 21]: Loss 0.01191
Epoch [ 21]: Loss 0.01214
Epoch [ 21]: Loss 0.01105
Epoch [ 21]: Loss 0.01138
Epoch [ 21]: Loss 0.01079
Epoch [ 21]: Loss 0.01207
Validation: Loss 0.01205 Accuracy 1.00000
Validation: Loss 0.01126 Accuracy 1.00000
Epoch [ 22]: Loss 0.01220
Epoch [ 22]: Loss 0.01123
Epoch [ 22]: Loss 0.01043
Epoch [ 22]: Loss 0.01132
Epoch [ 22]: Loss 0.01083
Epoch [ 22]: Loss 0.01062
Epoch [ 22]: Loss 0.01023
Validation: Loss 0.01105 Accuracy 1.00000
Validation: Loss 0.01033 Accuracy 1.00000
Epoch [ 23]: Loss 0.00990
Epoch [ 23]: Loss 0.01136
Epoch [ 23]: Loss 0.01006
Epoch [ 23]: Loss 0.01063
Epoch [ 23]: Loss 0.00927
Epoch [ 23]: Loss 0.00940
Epoch [ 23]: Loss 0.00999
Validation: Loss 0.00992 Accuracy 1.00000
Validation: Loss 0.00929 Accuracy 1.00000
Epoch [ 24]: Loss 0.00930
Epoch [ 24]: Loss 0.00879
Epoch [ 24]: Loss 0.00926
Epoch [ 24]: Loss 0.00857
Epoch [ 24]: Loss 0.00953
Epoch [ 24]: Loss 0.00894
Epoch [ 24]: Loss 0.00753
Validation: Loss 0.00878 Accuracy 1.00000
Validation: Loss 0.00824 Accuracy 1.00000
Epoch [ 25]: Loss 0.00921
Epoch [ 25]: Loss 0.00741
Epoch [ 25]: Loss 0.00867
Epoch [ 25]: Loss 0.00794
Epoch [ 25]: Loss 0.00745
Epoch [ 25]: Loss 0.00742
Epoch [ 25]: Loss 0.00824
Validation: Loss 0.00790 Accuracy 1.00000
Validation: Loss 0.00743 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.62722
Epoch [  1]: Loss 0.59413
Epoch [  1]: Loss 0.55655
Epoch [  1]: Loss 0.54395
Epoch [  1]: Loss 0.51816
Epoch [  1]: Loss 0.51010
Epoch [  1]: Loss 0.47961
Validation: Loss 0.46258 Accuracy 1.00000
Validation: Loss 0.46519 Accuracy 1.00000
Epoch [  2]: Loss 0.47281
Epoch [  2]: Loss 0.45069
Epoch [  2]: Loss 0.44288
Epoch [  2]: Loss 0.42149
Epoch [  2]: Loss 0.41961
Epoch [  2]: Loss 0.40148
Epoch [  2]: Loss 0.37439
Validation: Loss 0.36427 Accuracy 1.00000
Validation: Loss 0.36735 Accuracy 1.00000
Epoch [  3]: Loss 0.37005
Epoch [  3]: Loss 0.36556
Epoch [  3]: Loss 0.35175
Epoch [  3]: Loss 0.32889
Epoch [  3]: Loss 0.31650
Epoch [  3]: Loss 0.31270
Epoch [  3]: Loss 0.30203
Validation: Loss 0.27841 Accuracy 1.00000
Validation: Loss 0.28163 Accuracy 1.00000
Epoch [  4]: Loss 0.28210
Epoch [  4]: Loss 0.29894
Epoch [  4]: Loss 0.25652
Epoch [  4]: Loss 0.26337
Epoch [  4]: Loss 0.23870
Epoch [  4]: Loss 0.22994
Epoch [  4]: Loss 0.21905
Validation: Loss 0.20832 Accuracy 1.00000
Validation: Loss 0.21135 Accuracy 1.00000
Epoch [  5]: Loss 0.21729
Epoch [  5]: Loss 0.20556
Epoch [  5]: Loss 0.18958
Epoch [  5]: Loss 0.18994
Epoch [  5]: Loss 0.18850
Epoch [  5]: Loss 0.18489
Epoch [  5]: Loss 0.17659
Validation: Loss 0.15314 Accuracy 1.00000
Validation: Loss 0.15580 Accuracy 1.00000
Epoch [  6]: Loss 0.16271
Epoch [  6]: Loss 0.15928
Epoch [  6]: Loss 0.13819
Epoch [  6]: Loss 0.13685
Epoch [  6]: Loss 0.13191
Epoch [  6]: Loss 0.13796
Epoch [  6]: Loss 0.13965
Validation: Loss 0.11145 Accuracy 1.00000
Validation: Loss 0.11353 Accuracy 1.00000
Epoch [  7]: Loss 0.12304
Epoch [  7]: Loss 0.11471
Epoch [  7]: Loss 0.10947
Epoch [  7]: Loss 0.10706
Epoch [  7]: Loss 0.09683
Epoch [  7]: Loss 0.08594
Epoch [  7]: Loss 0.08188
Validation: Loss 0.07944 Accuracy 1.00000
Validation: Loss 0.08091 Accuracy 1.00000
Epoch [  8]: Loss 0.07766
Epoch [  8]: Loss 0.07953
Epoch [  8]: Loss 0.08340
Epoch [  8]: Loss 0.07306
Epoch [  8]: Loss 0.06519
Epoch [  8]: Loss 0.06961
Epoch [  8]: Loss 0.05725
Validation: Loss 0.05554 Accuracy 1.00000
Validation: Loss 0.05651 Accuracy 1.00000
Epoch [  9]: Loss 0.05919
Epoch [  9]: Loss 0.05915
Epoch [  9]: Loss 0.04997
Epoch [  9]: Loss 0.05109
Epoch [  9]: Loss 0.05020
Epoch [  9]: Loss 0.04629
Epoch [  9]: Loss 0.04627
Validation: Loss 0.04135 Accuracy 1.00000
Validation: Loss 0.04209 Accuracy 1.00000
Epoch [ 10]: Loss 0.04544
Epoch [ 10]: Loss 0.03921
Epoch [ 10]: Loss 0.04085
Epoch [ 10]: Loss 0.04055
Epoch [ 10]: Loss 0.03785
Epoch [ 10]: Loss 0.03856
Epoch [ 10]: Loss 0.03724
Validation: Loss 0.03345 Accuracy 1.00000
Validation: Loss 0.03409 Accuracy 1.00000
Epoch [ 11]: Loss 0.03828
Epoch [ 11]: Loss 0.03170
Epoch [ 11]: Loss 0.03526
Epoch [ 11]: Loss 0.03154
Epoch [ 11]: Loss 0.03262
Epoch [ 11]: Loss 0.03114
Epoch [ 11]: Loss 0.03046
Validation: Loss 0.02835 Accuracy 1.00000
Validation: Loss 0.02892 Accuracy 1.00000
Epoch [ 12]: Loss 0.02955
Epoch [ 12]: Loss 0.03087
Epoch [ 12]: Loss 0.02767
Epoch [ 12]: Loss 0.02730
Epoch [ 12]: Loss 0.02769
Epoch [ 12]: Loss 0.02786
Epoch [ 12]: Loss 0.03004
Validation: Loss 0.02464 Accuracy 1.00000
Validation: Loss 0.02515 Accuracy 1.00000
Epoch [ 13]: Loss 0.02595
Epoch [ 13]: Loss 0.02547
Epoch [ 13]: Loss 0.02543
Epoch [ 13]: Loss 0.02584
Epoch [ 13]: Loss 0.02205
Epoch [ 13]: Loss 0.02475
Epoch [ 13]: Loss 0.02738
Validation: Loss 0.02176 Accuracy 1.00000
Validation: Loss 0.02223 Accuracy 1.00000
Epoch [ 14]: Loss 0.02314
Epoch [ 14]: Loss 0.02237
Epoch [ 14]: Loss 0.02317
Epoch [ 14]: Loss 0.02203
Epoch [ 14]: Loss 0.02092
Epoch [ 14]: Loss 0.02244
Epoch [ 14]: Loss 0.01936
Validation: Loss 0.01944 Accuracy 1.00000
Validation: Loss 0.01987 Accuracy 1.00000
Epoch [ 15]: Loss 0.02071
Epoch [ 15]: Loss 0.02146
Epoch [ 15]: Loss 0.01903
Epoch [ 15]: Loss 0.01873
Epoch [ 15]: Loss 0.01967
Epoch [ 15]: Loss 0.02077
Epoch [ 15]: Loss 0.01741
Validation: Loss 0.01753 Accuracy 1.00000
Validation: Loss 0.01793 Accuracy 1.00000
Epoch [ 16]: Loss 0.01718
Epoch [ 16]: Loss 0.01893
Epoch [ 16]: Loss 0.01770
Epoch [ 16]: Loss 0.01792
Epoch [ 16]: Loss 0.01866
Epoch [ 16]: Loss 0.01826
Epoch [ 16]: Loss 0.01704
Validation: Loss 0.01592 Accuracy 1.00000
Validation: Loss 0.01629 Accuracy 1.00000
Epoch [ 17]: Loss 0.01653
Epoch [ 17]: Loss 0.01631
Epoch [ 17]: Loss 0.01671
Epoch [ 17]: Loss 0.01832
Epoch [ 17]: Loss 0.01495
Epoch [ 17]: Loss 0.01549
Epoch [ 17]: Loss 0.01853
Validation: Loss 0.01454 Accuracy 1.00000
Validation: Loss 0.01488 Accuracy 1.00000
Epoch [ 18]: Loss 0.01510
Epoch [ 18]: Loss 0.01550
Epoch [ 18]: Loss 0.01345
Epoch [ 18]: Loss 0.01668
Epoch [ 18]: Loss 0.01537
Epoch [ 18]: Loss 0.01448
Epoch [ 18]: Loss 0.01487
Validation: Loss 0.01335 Accuracy 1.00000
Validation: Loss 0.01367 Accuracy 1.00000
Epoch [ 19]: Loss 0.01456
Epoch [ 19]: Loss 0.01398
Epoch [ 19]: Loss 0.01408
Epoch [ 19]: Loss 0.01261
Epoch [ 19]: Loss 0.01495
Epoch [ 19]: Loss 0.01322
Epoch [ 19]: Loss 0.01373
Validation: Loss 0.01231 Accuracy 1.00000
Validation: Loss 0.01261 Accuracy 1.00000
Epoch [ 20]: Loss 0.01293
Epoch [ 20]: Loss 0.01379
Epoch [ 20]: Loss 0.01300
Epoch [ 20]: Loss 0.01287
Epoch [ 20]: Loss 0.01152
Epoch [ 20]: Loss 0.01283
Epoch [ 20]: Loss 0.01317
Validation: Loss 0.01138 Accuracy 1.00000
Validation: Loss 0.01166 Accuracy 1.00000
Epoch [ 21]: Loss 0.01338
Epoch [ 21]: Loss 0.01216
Epoch [ 21]: Loss 0.01169
Epoch [ 21]: Loss 0.01177
Epoch [ 21]: Loss 0.01186
Epoch [ 21]: Loss 0.01105
Epoch [ 21]: Loss 0.00907
Validation: Loss 0.01048 Accuracy 1.00000
Validation: Loss 0.01074 Accuracy 1.00000
Epoch [ 22]: Loss 0.01132
Epoch [ 22]: Loss 0.01110
Epoch [ 22]: Loss 0.01074
Epoch [ 22]: Loss 0.01151
Epoch [ 22]: Loss 0.01105
Epoch [ 22]: Loss 0.01019
Epoch [ 22]: Loss 0.00826
Validation: Loss 0.00953 Accuracy 1.00000
Validation: Loss 0.00976 Accuracy 1.00000
Epoch [ 23]: Loss 0.01066
Epoch [ 23]: Loss 0.01069
Epoch [ 23]: Loss 0.00968
Epoch [ 23]: Loss 0.00992
Epoch [ 23]: Loss 0.00893
Epoch [ 23]: Loss 0.00948
Epoch [ 23]: Loss 0.00774
Validation: Loss 0.00850 Accuracy 1.00000
Validation: Loss 0.00869 Accuracy 1.00000
Epoch [ 24]: Loss 0.00814
Epoch [ 24]: Loss 0.00956
Epoch [ 24]: Loss 0.00821
Epoch [ 24]: Loss 0.00847
Epoch [ 24]: Loss 0.00960
Epoch [ 24]: Loss 0.00843
Epoch [ 24]: Loss 0.00757
Validation: Loss 0.00761 Accuracy 1.00000
Validation: Loss 0.00777 Accuracy 1.00000
Epoch [ 25]: Loss 0.00751
Epoch [ 25]: Loss 0.00819
Epoch [ 25]: Loss 0.00836
Epoch [ 25]: Loss 0.00796
Epoch [ 25]: Loss 0.00748
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00768
Validation: Loss 0.00696 Accuracy 1.00000
Validation: Loss 0.00710 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
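
With the parameters and states reloaded, inference only needs the model definition and test mode. A minimal sketch, where x_sample is a hypothetical batch of four sequences:

julia
# Sketch: run the reloaded model on a hypothetical input batch.
model = SpiralClassifier(2, 8, 1)
st_test = Lux.testmode(st_trained)
x_sample = randn(Float32, 2, 50, 4)
ŷ, _ = model(x_sample, ps_trained, st_test)
ŷ .> 0.5f0  # predicted class (1 = anticlockwise, per the labels above) for each sequence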

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)

This page was generated using Literate.jl.