Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
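
As a quick sanity check (this snippet is not part of the original tutorial output), we can pull one batch from the training loader and confirm the shapes match the description above. With the defaults (`dataset_size=1000`, `sequence_length=50`, 80/20 split, `batchsize=128`), the first training batch should look like:

julia
train_loader, val_loader = get_dataloaders()

x, y = first(train_loader)
size(x)  # expected: (2, 50, 128) — features × sequence_length × batch_size
size(y)  # expected: (128,) — one binary label per sequence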

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will compose it from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the built-in Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the recurrence loop ourselves, but we will define it manually below for illustration.
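
For reference, a sketch of that built-in alternative might look as follows (this assumes Recurrence's default behavior of returning only the output at the final time step, and that Chain wraps plain functions like vec):

julia
# Sketch (not used in this tutorial): let Recurrence drive the sequence loop
model = Chain(
    Recurrence(LSTMCell(2 => 8)),  # iterates over the sequence dimension
    Dense(8 => 1, sigmoid),        # classifier head on the last hidden state
    vec,                           # flatten the 1 × batch output to a vector
)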

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell` We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
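
If numerical stability matters, one possible sketch of the logit-based variant mentioned above (assuming BinaryCrossEntropyLoss supports a `logits` keyword) would drop the sigmoid from the classifier head and let the loss apply it internally:

julia
# Hypothetical numerically stable setup (not used in this tutorial):
# the Dense layer outputs raw logits, i.e. no `sigmoid` activation...
classifier = Dense(8 => 1)
# ...and the loss fuses the sigmoid with the cross-entropy computation.
logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))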

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62697
Epoch [  1]: Loss 0.59398
Epoch [  1]: Loss 0.56516
Epoch [  1]: Loss 0.53496
Epoch [  1]: Loss 0.50859
Epoch [  1]: Loss 0.50693
Epoch [  1]: Loss 0.47860
Validation: Loss 0.47768 Accuracy 1.00000
Validation: Loss 0.46095 Accuracy 1.00000
Epoch [  2]: Loss 0.47312
Epoch [  2]: Loss 0.44743
Epoch [  2]: Loss 0.42440
Epoch [  2]: Loss 0.42461
Epoch [  2]: Loss 0.41525
Epoch [  2]: Loss 0.40489
Epoch [  2]: Loss 0.40008
Validation: Loss 0.38270 Accuracy 1.00000
Validation: Loss 0.36255 Accuracy 1.00000
Epoch [  3]: Loss 0.36150
Epoch [  3]: Loss 0.36675
Epoch [  3]: Loss 0.34973
Epoch [  3]: Loss 0.33199
Epoch [  3]: Loss 0.32029
Epoch [  3]: Loss 0.30015
Epoch [  3]: Loss 0.31455
Validation: Loss 0.29887 Accuracy 1.00000
Validation: Loss 0.27710 Accuracy 1.00000
Epoch [  4]: Loss 0.28120
Epoch [  4]: Loss 0.27205
Epoch [  4]: Loss 0.27875
Epoch [  4]: Loss 0.24838
Epoch [  4]: Loss 0.23269
Epoch [  4]: Loss 0.23989
Epoch [  4]: Loss 0.24375
Validation: Loss 0.22886 Accuracy 1.00000
Validation: Loss 0.20708 Accuracy 1.00000
Epoch [  5]: Loss 0.21448
Epoch [  5]: Loss 0.21254
Epoch [  5]: Loss 0.18882
Epoch [  5]: Loss 0.18588
Epoch [  5]: Loss 0.18461
Epoch [  5]: Loss 0.17709
Epoch [  5]: Loss 0.18849
Validation: Loss 0.17205 Accuracy 1.00000
Validation: Loss 0.15199 Accuracy 1.00000
Epoch [  6]: Loss 0.17380
Epoch [  6]: Loss 0.14316
Epoch [  6]: Loss 0.14159
Epoch [  6]: Loss 0.15160
Epoch [  6]: Loss 0.12746
Epoch [  6]: Loss 0.12037
Epoch [  6]: Loss 0.14508
Validation: Loss 0.12684 Accuracy 1.00000
Validation: Loss 0.11035 Accuracy 1.00000
Epoch [  7]: Loss 0.11594
Epoch [  7]: Loss 0.10788
Epoch [  7]: Loss 0.11033
Epoch [  7]: Loss 0.10059
Epoch [  7]: Loss 0.10187
Epoch [  7]: Loss 0.09224
Epoch [  7]: Loss 0.07907
Validation: Loss 0.09069 Accuracy 1.00000
Validation: Loss 0.07851 Accuracy 1.00000
Epoch [  8]: Loss 0.07172
Epoch [  8]: Loss 0.07587
Epoch [  8]: Loss 0.08157
Epoch [  8]: Loss 0.07087
Epoch [  8]: Loss 0.07289
Epoch [  8]: Loss 0.06797
Epoch [  8]: Loss 0.05956
Validation: Loss 0.06309 Accuracy 1.00000
Validation: Loss 0.05488 Accuracy 1.00000
Epoch [  9]: Loss 0.05936
Epoch [  9]: Loss 0.05396
Epoch [  9]: Loss 0.05191
Epoch [  9]: Loss 0.05135
Epoch [  9]: Loss 0.04774
Epoch [  9]: Loss 0.04779
Epoch [  9]: Loss 0.04686
Validation: Loss 0.04689 Accuracy 1.00000
Validation: Loss 0.04105 Accuracy 1.00000
Epoch [ 10]: Loss 0.04626
Epoch [ 10]: Loss 0.04220
Epoch [ 10]: Loss 0.03982
Epoch [ 10]: Loss 0.03832
Epoch [ 10]: Loss 0.03782
Epoch [ 10]: Loss 0.03671
Epoch [ 10]: Loss 0.03514
Validation: Loss 0.03801 Accuracy 1.00000
Validation: Loss 0.03326 Accuracy 1.00000
Epoch [ 11]: Loss 0.03595
Epoch [ 11]: Loss 0.03364
Epoch [ 11]: Loss 0.03255
Epoch [ 11]: Loss 0.03230
Epoch [ 11]: Loss 0.03105
Epoch [ 11]: Loss 0.03117
Epoch [ 11]: Loss 0.03813
Validation: Loss 0.03236 Accuracy 1.00000
Validation: Loss 0.02822 Accuracy 1.00000
Epoch [ 12]: Loss 0.03086
Epoch [ 12]: Loss 0.02945
Epoch [ 12]: Loss 0.02782
Epoch [ 12]: Loss 0.02872
Epoch [ 12]: Loss 0.02506
Epoch [ 12]: Loss 0.02741
Epoch [ 12]: Loss 0.03039
Validation: Loss 0.02818 Accuracy 1.00000
Validation: Loss 0.02451 Accuracy 1.00000
Epoch [ 13]: Loss 0.02362
Epoch [ 13]: Loss 0.02790
Epoch [ 13]: Loss 0.02369
Epoch [ 13]: Loss 0.02477
Epoch [ 13]: Loss 0.02390
Epoch [ 13]: Loss 0.02561
Epoch [ 13]: Loss 0.02100
Validation: Loss 0.02495 Accuracy 1.00000
Validation: Loss 0.02164 Accuracy 1.00000
Epoch [ 14]: Loss 0.02353
Epoch [ 14]: Loss 0.02275
Epoch [ 14]: Loss 0.02317
Epoch [ 14]: Loss 0.02103
Epoch [ 14]: Loss 0.02148
Epoch [ 14]: Loss 0.01979
Epoch [ 14]: Loss 0.02330
Validation: Loss 0.02237 Accuracy 1.00000
Validation: Loss 0.01934 Accuracy 1.00000
Epoch [ 15]: Loss 0.02015
Epoch [ 15]: Loss 0.02168
Epoch [ 15]: Loss 0.02007
Epoch [ 15]: Loss 0.01978
Epoch [ 15]: Loss 0.01967
Epoch [ 15]: Loss 0.01728
Epoch [ 15]: Loss 0.01964
Validation: Loss 0.02022 Accuracy 1.00000
Validation: Loss 0.01743 Accuracy 1.00000
Epoch [ 16]: Loss 0.01868
Epoch [ 16]: Loss 0.01775
Epoch [ 16]: Loss 0.01732
Epoch [ 16]: Loss 0.01741
Epoch [ 16]: Loss 0.01884
Epoch [ 16]: Loss 0.01776
Epoch [ 16]: Loss 0.01590
Validation: Loss 0.01842 Accuracy 1.00000
Validation: Loss 0.01582 Accuracy 1.00000
Epoch [ 17]: Loss 0.01537
Epoch [ 17]: Loss 0.01613
Epoch [ 17]: Loss 0.01608
Epoch [ 17]: Loss 0.01666
Epoch [ 17]: Loss 0.01691
Epoch [ 17]: Loss 0.01686
Epoch [ 17]: Loss 0.01515
Validation: Loss 0.01689 Accuracy 1.00000
Validation: Loss 0.01446 Accuracy 1.00000
Epoch [ 18]: Loss 0.01585
Epoch [ 18]: Loss 0.01639
Epoch [ 18]: Loss 0.01385
Epoch [ 18]: Loss 0.01543
Epoch [ 18]: Loss 0.01423
Epoch [ 18]: Loss 0.01436
Epoch [ 18]: Loss 0.01335
Validation: Loss 0.01556 Accuracy 1.00000
Validation: Loss 0.01329 Accuracy 1.00000
Epoch [ 19]: Loss 0.01473
Epoch [ 19]: Loss 0.01373
Epoch [ 19]: Loss 0.01372
Epoch [ 19]: Loss 0.01328
Epoch [ 19]: Loss 0.01436
Epoch [ 19]: Loss 0.01300
Epoch [ 19]: Loss 0.01338
Validation: Loss 0.01441 Accuracy 1.00000
Validation: Loss 0.01229 Accuracy 1.00000
Epoch [ 20]: Loss 0.01298
Epoch [ 20]: Loss 0.01270
Epoch [ 20]: Loss 0.01347
Epoch [ 20]: Loss 0.01206
Epoch [ 20]: Loss 0.01257
Epoch [ 20]: Loss 0.01293
Epoch [ 20]: Loss 0.01241
Validation: Loss 0.01339 Accuracy 1.00000
Validation: Loss 0.01140 Accuracy 1.00000
Epoch [ 21]: Loss 0.01303
Epoch [ 21]: Loss 0.01290
Epoch [ 21]: Loss 0.01071
Epoch [ 21]: Loss 0.01154
Epoch [ 21]: Loss 0.01143
Epoch [ 21]: Loss 0.01166
Epoch [ 21]: Loss 0.01155
Validation: Loss 0.01243 Accuracy 1.00000
Validation: Loss 0.01058 Accuracy 1.00000
Epoch [ 22]: Loss 0.01153
Epoch [ 22]: Loss 0.01141
Epoch [ 22]: Loss 0.01106
Epoch [ 22]: Loss 0.01093
Epoch [ 22]: Loss 0.01039
Epoch [ 22]: Loss 0.01079
Epoch [ 22]: Loss 0.01068
Validation: Loss 0.01148 Accuracy 1.00000
Validation: Loss 0.00977 Accuracy 1.00000
Epoch [ 23]: Loss 0.01117
Epoch [ 23]: Loss 0.01093
Epoch [ 23]: Loss 0.01028
Epoch [ 23]: Loss 0.00814
Epoch [ 23]: Loss 0.00982
Epoch [ 23]: Loss 0.01027
Epoch [ 23]: Loss 0.01047
Validation: Loss 0.01041 Accuracy 1.00000
Validation: Loss 0.00888 Accuracy 1.00000
Epoch [ 24]: Loss 0.00975
Epoch [ 24]: Loss 0.00850
Epoch [ 24]: Loss 0.00974
Epoch [ 24]: Loss 0.00886
Epoch [ 24]: Loss 0.00915
Epoch [ 24]: Loss 0.00867
Epoch [ 24]: Loss 0.00904
Validation: Loss 0.00923 Accuracy 1.00000
Validation: Loss 0.00792 Accuracy 1.00000
Epoch [ 25]: Loss 0.00883
Epoch [ 25]: Loss 0.00846
Epoch [ 25]: Loss 0.00794
Epoch [ 25]: Loss 0.00773
Epoch [ 25]: Loss 0.00762
Epoch [ 25]: Loss 0.00793
Epoch [ 25]: Loss 0.00859
Validation: Loss 0.00824 Accuracy 1.00000
Validation: Loss 0.00711 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.61677
Epoch [  1]: Loss 0.59664
Epoch [  1]: Loss 0.56696
Epoch [  1]: Loss 0.54556
Epoch [  1]: Loss 0.52115
Epoch [  1]: Loss 0.49877
Epoch [  1]: Loss 0.47923
Validation: Loss 0.46639 Accuracy 1.00000
Validation: Loss 0.46451 Accuracy 1.00000
Epoch [  2]: Loss 0.46974
Epoch [  2]: Loss 0.46087
Epoch [  2]: Loss 0.44181
Epoch [  2]: Loss 0.41903
Epoch [  2]: Loss 0.41014
Epoch [  2]: Loss 0.40046
Epoch [  2]: Loss 0.37173
Validation: Loss 0.36831 Accuracy 1.00000
Validation: Loss 0.36584 Accuracy 1.00000
Epoch [  3]: Loss 0.37839
Epoch [  3]: Loss 0.34011
Epoch [  3]: Loss 0.34844
Epoch [  3]: Loss 0.33601
Epoch [  3]: Loss 0.32274
Epoch [  3]: Loss 0.30881
Epoch [  3]: Loss 0.31293
Validation: Loss 0.28307 Accuracy 1.00000
Validation: Loss 0.28024 Accuracy 1.00000
Epoch [  4]: Loss 0.29027
Epoch [  4]: Loss 0.27727
Epoch [  4]: Loss 0.26092
Epoch [  4]: Loss 0.25209
Epoch [  4]: Loss 0.24520
Epoch [  4]: Loss 0.23017
Epoch [  4]: Loss 0.24575
Validation: Loss 0.21313 Accuracy 1.00000
Validation: Loss 0.21031 Accuracy 1.00000
Epoch [  5]: Loss 0.22150
Epoch [  5]: Loss 0.20521
Epoch [  5]: Loss 0.19930
Epoch [  5]: Loss 0.18867
Epoch [  5]: Loss 0.18568
Epoch [  5]: Loss 0.17074
Epoch [  5]: Loss 0.17204
Validation: Loss 0.15746 Accuracy 1.00000
Validation: Loss 0.15496 Accuracy 1.00000
Epoch [  6]: Loss 0.15970
Epoch [  6]: Loss 0.13205
Epoch [  6]: Loss 0.15669
Epoch [  6]: Loss 0.15128
Epoch [  6]: Loss 0.13629
Epoch [  6]: Loss 0.13000
Epoch [  6]: Loss 0.11765
Validation: Loss 0.11495 Accuracy 1.00000
Validation: Loss 0.11296 Accuracy 1.00000
Epoch [  7]: Loss 0.11980
Epoch [  7]: Loss 0.11451
Epoch [  7]: Loss 0.11211
Epoch [  7]: Loss 0.09977
Epoch [  7]: Loss 0.09863
Epoch [  7]: Loss 0.08840
Epoch [  7]: Loss 0.07912
Validation: Loss 0.08228 Accuracy 1.00000
Validation: Loss 0.08082 Accuracy 1.00000
Epoch [  8]: Loss 0.08656
Epoch [  8]: Loss 0.08304
Epoch [  8]: Loss 0.07644
Epoch [  8]: Loss 0.06626
Epoch [  8]: Loss 0.07073
Epoch [  8]: Loss 0.06500
Epoch [  8]: Loss 0.05950
Validation: Loss 0.05750 Accuracy 1.00000
Validation: Loss 0.05654 Accuracy 1.00000
Epoch [  9]: Loss 0.05698
Epoch [  9]: Loss 0.05342
Epoch [  9]: Loss 0.05181
Epoch [  9]: Loss 0.05536
Epoch [  9]: Loss 0.05212
Epoch [  9]: Loss 0.04511
Epoch [  9]: Loss 0.04361
Validation: Loss 0.04263 Accuracy 1.00000
Validation: Loss 0.04200 Accuracy 1.00000
Epoch [ 10]: Loss 0.04439
Epoch [ 10]: Loss 0.04347
Epoch [ 10]: Loss 0.04021
Epoch [ 10]: Loss 0.03888
Epoch [ 10]: Loss 0.03598
Epoch [ 10]: Loss 0.03881
Epoch [ 10]: Loss 0.03483
Validation: Loss 0.03447 Accuracy 1.00000
Validation: Loss 0.03398 Accuracy 1.00000
Epoch [ 11]: Loss 0.03252
Epoch [ 11]: Loss 0.03497
Epoch [ 11]: Loss 0.03439
Epoch [ 11]: Loss 0.03305
Epoch [ 11]: Loss 0.03149
Epoch [ 11]: Loss 0.03206
Epoch [ 11]: Loss 0.03145
Validation: Loss 0.02924 Accuracy 1.00000
Validation: Loss 0.02881 Accuracy 1.00000
Epoch [ 12]: Loss 0.02774
Epoch [ 12]: Loss 0.03031
Epoch [ 12]: Loss 0.03163
Epoch [ 12]: Loss 0.02678
Epoch [ 12]: Loss 0.02578
Epoch [ 12]: Loss 0.02732
Epoch [ 12]: Loss 0.02995
Validation: Loss 0.02543 Accuracy 1.00000
Validation: Loss 0.02505 Accuracy 1.00000
Epoch [ 13]: Loss 0.02522
Epoch [ 13]: Loss 0.02474
Epoch [ 13]: Loss 0.02833
Epoch [ 13]: Loss 0.02573
Epoch [ 13]: Loss 0.02098
Epoch [ 13]: Loss 0.02446
Epoch [ 13]: Loss 0.02244
Validation: Loss 0.02247 Accuracy 1.00000
Validation: Loss 0.02214 Accuracy 1.00000
Epoch [ 14]: Loss 0.02482
Epoch [ 14]: Loss 0.02483
Epoch [ 14]: Loss 0.02115
Epoch [ 14]: Loss 0.02101
Epoch [ 14]: Loss 0.02049
Epoch [ 14]: Loss 0.01979
Epoch [ 14]: Loss 0.02296
Validation: Loss 0.02009 Accuracy 1.00000
Validation: Loss 0.01979 Accuracy 1.00000
Epoch [ 15]: Loss 0.02099
Epoch [ 15]: Loss 0.02087
Epoch [ 15]: Loss 0.02019
Epoch [ 15]: Loss 0.01886
Epoch [ 15]: Loss 0.01825
Epoch [ 15]: Loss 0.02070
Epoch [ 15]: Loss 0.01535
Validation: Loss 0.01813 Accuracy 1.00000
Validation: Loss 0.01786 Accuracy 1.00000
Epoch [ 16]: Loss 0.01615
Epoch [ 16]: Loss 0.01854
Epoch [ 16]: Loss 0.01796
Epoch [ 16]: Loss 0.01724
Epoch [ 16]: Loss 0.01940
Epoch [ 16]: Loss 0.01851
Epoch [ 16]: Loss 0.01640
Validation: Loss 0.01649 Accuracy 1.00000
Validation: Loss 0.01624 Accuracy 1.00000
Epoch [ 17]: Loss 0.01617
Epoch [ 17]: Loss 0.01612
Epoch [ 17]: Loss 0.01668
Epoch [ 17]: Loss 0.01556
Epoch [ 17]: Loss 0.01700
Epoch [ 17]: Loss 0.01663
Epoch [ 17]: Loss 0.01557
Validation: Loss 0.01508 Accuracy 1.00000
Validation: Loss 0.01484 Accuracy 1.00000
Epoch [ 18]: Loss 0.01652
Epoch [ 18]: Loss 0.01518
Epoch [ 18]: Loss 0.01400
Epoch [ 18]: Loss 0.01579
Epoch [ 18]: Loss 0.01366
Epoch [ 18]: Loss 0.01473
Epoch [ 18]: Loss 0.01483
Validation: Loss 0.01385 Accuracy 1.00000
Validation: Loss 0.01363 Accuracy 1.00000
Epoch [ 19]: Loss 0.01521
Epoch [ 19]: Loss 0.01261
Epoch [ 19]: Loss 0.01432
Epoch [ 19]: Loss 0.01351
Epoch [ 19]: Loss 0.01362
Epoch [ 19]: Loss 0.01371
Epoch [ 19]: Loss 0.01267
Validation: Loss 0.01279 Accuracy 1.00000
Validation: Loss 0.01258 Accuracy 1.00000
Epoch [ 20]: Loss 0.01379
Epoch [ 20]: Loss 0.01352
Epoch [ 20]: Loss 0.01287
Epoch [ 20]: Loss 0.01335
Epoch [ 20]: Loss 0.01223
Epoch [ 20]: Loss 0.01161
Epoch [ 20]: Loss 0.00935
Validation: Loss 0.01184 Accuracy 1.00000
Validation: Loss 0.01164 Accuracy 1.00000
Epoch [ 21]: Loss 0.01266
Epoch [ 21]: Loss 0.01207
Epoch [ 21]: Loss 0.01113
Epoch [ 21]: Loss 0.01237
Epoch [ 21]: Loss 0.01212
Epoch [ 21]: Loss 0.01098
Epoch [ 21]: Loss 0.00998
Validation: Loss 0.01096 Accuracy 1.00000
Validation: Loss 0.01078 Accuracy 1.00000
Epoch [ 22]: Loss 0.01151
Epoch [ 22]: Loss 0.01029
Epoch [ 22]: Loss 0.01123
Epoch [ 22]: Loss 0.01122
Epoch [ 22]: Loss 0.01100
Epoch [ 22]: Loss 0.01014
Epoch [ 22]: Loss 0.01121
Validation: Loss 0.01007 Accuracy 1.00000
Validation: Loss 0.00990 Accuracy 1.00000
Epoch [ 23]: Loss 0.01075
Epoch [ 23]: Loss 0.01071
Epoch [ 23]: Loss 0.00989
Epoch [ 23]: Loss 0.00983
Epoch [ 23]: Loss 0.00970
Epoch [ 23]: Loss 0.00925
Epoch [ 23]: Loss 0.00829
Validation: Loss 0.00905 Accuracy 1.00000
Validation: Loss 0.00890 Accuracy 1.00000
Epoch [ 24]: Loss 0.00825
Epoch [ 24]: Loss 0.00950
Epoch [ 24]: Loss 0.00951
Epoch [ 24]: Loss 0.00885
Epoch [ 24]: Loss 0.00895
Epoch [ 24]: Loss 0.00847
Epoch [ 24]: Loss 0.00773
Validation: Loss 0.00805 Accuracy 1.00000
Validation: Loss 0.00792 Accuracy 1.00000
Epoch [ 25]: Loss 0.00708
Epoch [ 25]: Loss 0.00871
Epoch [ 25]: Loss 0.00888
Epoch [ 25]: Loss 0.00832
Epoch [ 25]: Loss 0.00751
Epoch [ 25]: Loss 0.00741
Epoch [ 25]: Loss 0.00643
Validation: Loss 0.00727 Accuracy 1.00000
Validation: Loss 0.00716 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model back.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
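
To run inference with the loaded parameters, one possible (hypothetical) sketch is to rebuild the model and call it in test mode on a fresh input:

julia
# Sketch: rebuild the architecture and reuse the saved parameters/states
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # a placeholder input sequence, not a real spiral
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))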

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

Preferences:
- CUDA_Driver_jll.compat: false

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)

This page was generated using Literate.jl.