Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader that yields sequences of size 2 × seq_len × batch_size, and our task is to predict a binary value indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
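
With the default dataset_size=1000 and an 80/20 split, the loaders above hold 800 training and 200 validation samples; at batchsize=128 (the partial final batch is kept by default), that gives 7 training batches and 2 validation batches per epoch, which matches the seven training-loss lines and two validation lines printed per epoch later on. A quick sanity check of that arithmetic (plain Julia, nothing Lux-specific):

```julia
# Number of minibatches when the partial final batch is kept
# (ceiling division, the DataLoader default behaviour):
nbatches(nsamples, batchsize) = cld(nsamples, batchsize)

train_batches = nbatches(800, 128)  # 800 = 80% of the 1000 samples
val_batches = nbatches(200, 128)    # 200 = the remaining 20%

println((train_batches, val_batches))  # → (7, 2)
```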

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather compose it from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass below, but we will write it out manually for illustration.
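
For reference, a sketch of that built-in alternative (hypothetical and untrained; it assumes Lux's Chain, which wraps plain functions such as vec, and that Recurrence returns only the final time step's output by default):

```julia
# Hypothetical equivalent of SpiralClassifier built from stock Lux blocks.
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(
        # Runs the cell over the time dimension and returns the last output
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid),
        vec)
end
```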

Now we need to define the behavior of the classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell` We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
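
The slice-and-peel pattern above can be demonstrated on a dummy array (plain Julia for illustration; LuxOps.eachslice behaves like Base.eachslice but is friendlier to GPUs and AD):

```julia
x = rand(Float32, 2, 5, 3)  # features × seq_len × batch_size, as in the tutorial

# Slice along the time dimension: five 2 × 3 views, one per time step
slices = eachslice(x; dims=2)

# Peel off the first time step; the rest stay behind a lazy iterator
x_init, x_rest = Iterators.peel(slices)

println(size(x_init))             # → (2, 3): one time step for the whole batch
println(length(collect(x_rest)))  # → 4 remaining time steps
```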

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
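
For intuition, the quantity BinaryCrossEntropyLoss computes can be written out by hand (a minimal sketch of the formula only; the library version is more careful about numerical edge cases such as log(0)):

```julia
using Statistics

# Binary cross entropy: mean over samples of -(y log ŷ + (1 - y) log(1 - ŷ))
bce(ŷ, y) = mean(@. -(y * log(ŷ) + (1 - y) * log(1 - ŷ)))

bce([0.9f0, 0.1f0], [1.0f0, 0.0f0])  # ≈ 0.10536, i.e. -log(0.9) for both samples
```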

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
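
A quick check of these helpers on a hand-made example (hypothetical probabilities and 0/1 labels):

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = [0.9f0, 0.2f0, 0.6f0]  # thresholded at 0.5 → predictions 1, 0, 1
y_true = [1.0f0, 0.0f0, 0.0f0]  # the last prediction is wrong

println(matches(y_pred, y_true))   # → 2 correct out of 3
println(accuracy(y_pred, y_true))  # → 0.6666666666666666
```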

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62553
Epoch [  1]: Loss 0.60510
Epoch [  1]: Loss 0.55835
Epoch [  1]: Loss 0.54149
Epoch [  1]: Loss 0.52378
Epoch [  1]: Loss 0.49478
Epoch [  1]: Loss 0.48398
Validation: Loss 0.46964 Accuracy 1.00000
Validation: Loss 0.45746 Accuracy 1.00000
Epoch [  2]: Loss 0.47575
Epoch [  2]: Loss 0.45342
Epoch [  2]: Loss 0.43241
Epoch [  2]: Loss 0.43282
Epoch [  2]: Loss 0.41279
Epoch [  2]: Loss 0.39988
Epoch [  2]: Loss 0.37074
Validation: Loss 0.37208 Accuracy 1.00000
Validation: Loss 0.35773 Accuracy 1.00000
Epoch [  3]: Loss 0.37759
Epoch [  3]: Loss 0.35141
Epoch [  3]: Loss 0.34756
Epoch [  3]: Loss 0.32144
Epoch [  3]: Loss 0.32601
Epoch [  3]: Loss 0.31558
Epoch [  3]: Loss 0.31383
Validation: Loss 0.28696 Accuracy 1.00000
Validation: Loss 0.27097 Accuracy 1.00000
Epoch [  4]: Loss 0.28812
Epoch [  4]: Loss 0.26925
Epoch [  4]: Loss 0.26329
Epoch [  4]: Loss 0.25165
Epoch [  4]: Loss 0.24186
Epoch [  4]: Loss 0.24788
Epoch [  4]: Loss 0.23699
Validation: Loss 0.21699 Accuracy 1.00000
Validation: Loss 0.20103 Accuracy 1.00000
Epoch [  5]: Loss 0.20486
Epoch [  5]: Loss 0.19917
Epoch [  5]: Loss 0.21229
Epoch [  5]: Loss 0.19976
Epoch [  5]: Loss 0.17600
Epoch [  5]: Loss 0.18594
Epoch [  5]: Loss 0.15670
Validation: Loss 0.16125 Accuracy 1.00000
Validation: Loss 0.14669 Accuracy 1.00000
Epoch [  6]: Loss 0.16396
Epoch [  6]: Loss 0.15604
Epoch [  6]: Loss 0.14779
Epoch [  6]: Loss 0.14857
Epoch [  6]: Loss 0.12744
Epoch [  6]: Loss 0.12807
Epoch [  6]: Loss 0.11374
Validation: Loss 0.11812 Accuracy 1.00000
Validation: Loss 0.10615 Accuracy 1.00000
Epoch [  7]: Loss 0.11222
Epoch [  7]: Loss 0.10495
Epoch [  7]: Loss 0.10957
Epoch [  7]: Loss 0.10119
Epoch [  7]: Loss 0.10683
Epoch [  7]: Loss 0.09341
Epoch [  7]: Loss 0.09857
Validation: Loss 0.08470 Accuracy 1.00000
Validation: Loss 0.07579 Accuracy 1.00000
Epoch [  8]: Loss 0.08797
Epoch [  8]: Loss 0.07320
Epoch [  8]: Loss 0.07783
Epoch [  8]: Loss 0.07383
Epoch [  8]: Loss 0.06510
Epoch [  8]: Loss 0.07110
Epoch [  8]: Loss 0.05521
Validation: Loss 0.05896 Accuracy 1.00000
Validation: Loss 0.05305 Accuracy 1.00000
Epoch [  9]: Loss 0.05691
Epoch [  9]: Loss 0.05191
Epoch [  9]: Loss 0.05573
Epoch [  9]: Loss 0.05292
Epoch [  9]: Loss 0.04957
Epoch [  9]: Loss 0.04796
Epoch [  9]: Loss 0.04510
Validation: Loss 0.04381 Accuracy 1.00000
Validation: Loss 0.03964 Accuracy 1.00000
Epoch [ 10]: Loss 0.05030
Epoch [ 10]: Loss 0.03891
Epoch [ 10]: Loss 0.04128
Epoch [ 10]: Loss 0.03873
Epoch [ 10]: Loss 0.03725
Epoch [ 10]: Loss 0.03647
Epoch [ 10]: Loss 0.03616
Validation: Loss 0.03545 Accuracy 1.00000
Validation: Loss 0.03206 Accuracy 1.00000
Epoch [ 11]: Loss 0.03576
Epoch [ 11]: Loss 0.03512
Epoch [ 11]: Loss 0.03400
Epoch [ 11]: Loss 0.03295
Epoch [ 11]: Loss 0.03057
Epoch [ 11]: Loss 0.03171
Epoch [ 11]: Loss 0.03048
Validation: Loss 0.03011 Accuracy 1.00000
Validation: Loss 0.02716 Accuracy 1.00000
Epoch [ 12]: Loss 0.02811
Epoch [ 12]: Loss 0.03051
Epoch [ 12]: Loss 0.02985
Epoch [ 12]: Loss 0.02784
Epoch [ 12]: Loss 0.02848
Epoch [ 12]: Loss 0.02678
Epoch [ 12]: Loss 0.02645
Validation: Loss 0.02622 Accuracy 1.00000
Validation: Loss 0.02360 Accuracy 1.00000
Epoch [ 13]: Loss 0.02600
Epoch [ 13]: Loss 0.02664
Epoch [ 13]: Loss 0.02552
Epoch [ 13]: Loss 0.02399
Epoch [ 13]: Loss 0.02336
Epoch [ 13]: Loss 0.02505
Epoch [ 13]: Loss 0.02253
Validation: Loss 0.02321 Accuracy 1.00000
Validation: Loss 0.02084 Accuracy 1.00000
Epoch [ 14]: Loss 0.02242
Epoch [ 14]: Loss 0.02084
Epoch [ 14]: Loss 0.02346
Epoch [ 14]: Loss 0.02310
Epoch [ 14]: Loss 0.02043
Epoch [ 14]: Loss 0.02277
Epoch [ 14]: Loss 0.02318
Validation: Loss 0.02080 Accuracy 1.00000
Validation: Loss 0.01863 Accuracy 1.00000
Epoch [ 15]: Loss 0.02136
Epoch [ 15]: Loss 0.01921
Epoch [ 15]: Loss 0.02144
Epoch [ 15]: Loss 0.02043
Epoch [ 15]: Loss 0.01801
Epoch [ 15]: Loss 0.01950
Epoch [ 15]: Loss 0.01957
Validation: Loss 0.01878 Accuracy 1.00000
Validation: Loss 0.01679 Accuracy 1.00000
Epoch [ 16]: Loss 0.01963
Epoch [ 16]: Loss 0.01869
Epoch [ 16]: Loss 0.01693
Epoch [ 16]: Loss 0.01735
Epoch [ 16]: Loss 0.01941
Epoch [ 16]: Loss 0.01672
Epoch [ 16]: Loss 0.01712
Validation: Loss 0.01707 Accuracy 1.00000
Validation: Loss 0.01522 Accuracy 1.00000
Epoch [ 17]: Loss 0.01644
Epoch [ 17]: Loss 0.01655
Epoch [ 17]: Loss 0.01663
Epoch [ 17]: Loss 0.01688
Epoch [ 17]: Loss 0.01635
Epoch [ 17]: Loss 0.01556
Epoch [ 17]: Loss 0.01802
Validation: Loss 0.01561 Accuracy 1.00000
Validation: Loss 0.01389 Accuracy 1.00000
Epoch [ 18]: Loss 0.01591
Epoch [ 18]: Loss 0.01498
Epoch [ 18]: Loss 0.01638
Epoch [ 18]: Loss 0.01533
Epoch [ 18]: Loss 0.01457
Epoch [ 18]: Loss 0.01398
Epoch [ 18]: Loss 0.01269
Validation: Loss 0.01433 Accuracy 1.00000
Validation: Loss 0.01273 Accuracy 1.00000
Epoch [ 19]: Loss 0.01489
Epoch [ 19]: Loss 0.01322
Epoch [ 19]: Loss 0.01360
Epoch [ 19]: Loss 0.01345
Epoch [ 19]: Loss 0.01382
Epoch [ 19]: Loss 0.01379
Epoch [ 19]: Loss 0.01584
Validation: Loss 0.01323 Accuracy 1.00000
Validation: Loss 0.01173 Accuracy 1.00000
Epoch [ 20]: Loss 0.01413
Epoch [ 20]: Loss 0.01346
Epoch [ 20]: Loss 0.01243
Epoch [ 20]: Loss 0.01341
Epoch [ 20]: Loss 0.01142
Epoch [ 20]: Loss 0.01238
Epoch [ 20]: Loss 0.01175
Validation: Loss 0.01222 Accuracy 1.00000
Validation: Loss 0.01083 Accuracy 1.00000
Epoch [ 21]: Loss 0.01175
Epoch [ 21]: Loss 0.01221
Epoch [ 21]: Loss 0.01252
Epoch [ 21]: Loss 0.01250
Epoch [ 21]: Loss 0.01099
Epoch [ 21]: Loss 0.01107
Epoch [ 21]: Loss 0.01175
Validation: Loss 0.01125 Accuracy 1.00000
Validation: Loss 0.00997 Accuracy 1.00000
Epoch [ 22]: Loss 0.01070
Epoch [ 22]: Loss 0.01111
Epoch [ 22]: Loss 0.01011
Epoch [ 22]: Loss 0.01072
Epoch [ 22]: Loss 0.01068
Epoch [ 22]: Loss 0.01163
Epoch [ 22]: Loss 0.01102
Validation: Loss 0.01020 Accuracy 1.00000
Validation: Loss 0.00906 Accuracy 1.00000
Epoch [ 23]: Loss 0.01049
Epoch [ 23]: Loss 0.01014
Epoch [ 23]: Loss 0.01008
Epoch [ 23]: Loss 0.00944
Epoch [ 23]: Loss 0.00962
Epoch [ 23]: Loss 0.00895
Epoch [ 23]: Loss 0.00913
Validation: Loss 0.00905 Accuracy 1.00000
Validation: Loss 0.00807 Accuracy 1.00000
Epoch [ 24]: Loss 0.00915
Epoch [ 24]: Loss 0.00921
Epoch [ 24]: Loss 0.00845
Epoch [ 24]: Loss 0.00845
Epoch [ 24]: Loss 0.00890
Epoch [ 24]: Loss 0.00803
Epoch [ 24]: Loss 0.00752
Validation: Loss 0.00808 Accuracy 1.00000
Validation: Loss 0.00723 Accuracy 1.00000
Epoch [ 25]: Loss 0.00786
Epoch [ 25]: Loss 0.00768
Epoch [ 25]: Loss 0.00805
Epoch [ 25]: Loss 0.00816
Epoch [ 25]: Loss 0.00758
Epoch [ 25]: Loss 0.00742
Epoch [ 25]: Loss 0.00741
Validation: Loss 0.00738 Accuracy 1.00000
Validation: Loss 0.00663 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.63003
Epoch [  1]: Loss 0.59696
Epoch [  1]: Loss 0.56691
Epoch [  1]: Loss 0.54274
Epoch [  1]: Loss 0.51708
Epoch [  1]: Loss 0.48755
Epoch [  1]: Loss 0.48718
Validation: Loss 0.46947 Accuracy 1.00000
Validation: Loss 0.46928 Accuracy 1.00000
Epoch [  2]: Loss 0.46952
Epoch [  2]: Loss 0.46210
Epoch [  2]: Loss 0.43708
Epoch [  2]: Loss 0.41283
Epoch [  2]: Loss 0.40757
Epoch [  2]: Loss 0.39893
Epoch [  2]: Loss 0.40373
Validation: Loss 0.37232 Accuracy 1.00000
Validation: Loss 0.37178 Accuracy 1.00000
Epoch [  3]: Loss 0.37683
Epoch [  3]: Loss 0.35058
Epoch [  3]: Loss 0.35233
Epoch [  3]: Loss 0.33516
Epoch [  3]: Loss 0.31468
Epoch [  3]: Loss 0.30104
Epoch [  3]: Loss 0.30092
Validation: Loss 0.28742 Accuracy 1.00000
Validation: Loss 0.28695 Accuracy 1.00000
Epoch [  4]: Loss 0.28437
Epoch [  4]: Loss 0.26257
Epoch [  4]: Loss 0.25741
Epoch [  4]: Loss 0.26158
Epoch [  4]: Loss 0.24157
Epoch [  4]: Loss 0.23809
Epoch [  4]: Loss 0.25725
Validation: Loss 0.21740 Accuracy 1.00000
Validation: Loss 0.21727 Accuracy 1.00000
Epoch [  5]: Loss 0.22413
Epoch [  5]: Loss 0.20429
Epoch [  5]: Loss 0.19420
Epoch [  5]: Loss 0.19286
Epoch [  5]: Loss 0.17194
Epoch [  5]: Loss 0.17953
Epoch [  5]: Loss 0.16164
Validation: Loss 0.16148 Accuracy 1.00000
Validation: Loss 0.16178 Accuracy 1.00000
Epoch [  6]: Loss 0.15849
Epoch [  6]: Loss 0.16297
Epoch [  6]: Loss 0.14785
Epoch [  6]: Loss 0.13243
Epoch [  6]: Loss 0.13312
Epoch [  6]: Loss 0.12825
Epoch [  6]: Loss 0.11395
Validation: Loss 0.11824 Accuracy 1.00000
Validation: Loss 0.11879 Accuracy 1.00000
Epoch [  7]: Loss 0.11945
Epoch [  7]: Loss 0.10755
Epoch [  7]: Loss 0.10389
Epoch [  7]: Loss 0.10288
Epoch [  7]: Loss 0.09822
Epoch [  7]: Loss 0.09268
Epoch [  7]: Loss 0.09171
Validation: Loss 0.08467 Accuracy 1.00000
Validation: Loss 0.08524 Accuracy 1.00000
Epoch [  8]: Loss 0.08380
Epoch [  8]: Loss 0.07620
Epoch [  8]: Loss 0.06893
Epoch [  8]: Loss 0.07261
Epoch [  8]: Loss 0.07318
Epoch [  8]: Loss 0.06899
Epoch [  8]: Loss 0.05518
Validation: Loss 0.05907 Accuracy 1.00000
Validation: Loss 0.05948 Accuracy 1.00000
Epoch [  9]: Loss 0.05651
Epoch [  9]: Loss 0.05527
Epoch [  9]: Loss 0.05700
Epoch [  9]: Loss 0.04890
Epoch [  9]: Loss 0.04720
Epoch [  9]: Loss 0.04867
Epoch [  9]: Loss 0.04540
Validation: Loss 0.04413 Accuracy 1.00000
Validation: Loss 0.04438 Accuracy 1.00000
Epoch [ 10]: Loss 0.04195
Epoch [ 10]: Loss 0.04340
Epoch [ 10]: Loss 0.04065
Epoch [ 10]: Loss 0.04005
Epoch [ 10]: Loss 0.03923
Epoch [ 10]: Loss 0.03813
Epoch [ 10]: Loss 0.03156
Validation: Loss 0.03586 Accuracy 1.00000
Validation: Loss 0.03605 Accuracy 1.00000
Epoch [ 11]: Loss 0.03494
Epoch [ 11]: Loss 0.03379
Epoch [ 11]: Loss 0.03358
Epoch [ 11]: Loss 0.03390
Epoch [ 11]: Loss 0.03312
Epoch [ 11]: Loss 0.03158
Epoch [ 11]: Loss 0.02879
Validation: Loss 0.03052 Accuracy 1.00000
Validation: Loss 0.03070 Accuracy 1.00000
Epoch [ 12]: Loss 0.03313
Epoch [ 12]: Loss 0.02708
Epoch [ 12]: Loss 0.02802
Epoch [ 12]: Loss 0.02921
Epoch [ 12]: Loss 0.02716
Epoch [ 12]: Loss 0.02787
Epoch [ 12]: Loss 0.02590
Validation: Loss 0.02661 Accuracy 1.00000
Validation: Loss 0.02677 Accuracy 1.00000
Epoch [ 13]: Loss 0.02568
Epoch [ 13]: Loss 0.02455
Epoch [ 13]: Loss 0.02725
Epoch [ 13]: Loss 0.02433
Epoch [ 13]: Loss 0.02632
Epoch [ 13]: Loss 0.02308
Epoch [ 13]: Loss 0.02287
Validation: Loss 0.02358 Accuracy 1.00000
Validation: Loss 0.02373 Accuracy 1.00000
Epoch [ 14]: Loss 0.02295
Epoch [ 14]: Loss 0.02235
Epoch [ 14]: Loss 0.02141
Epoch [ 14]: Loss 0.02329
Epoch [ 14]: Loss 0.02351
Epoch [ 14]: Loss 0.02120
Epoch [ 14]: Loss 0.01971
Validation: Loss 0.02114 Accuracy 1.00000
Validation: Loss 0.02127 Accuracy 1.00000
Epoch [ 15]: Loss 0.01979
Epoch [ 15]: Loss 0.02159
Epoch [ 15]: Loss 0.02015
Epoch [ 15]: Loss 0.01984
Epoch [ 15]: Loss 0.01975
Epoch [ 15]: Loss 0.02058
Epoch [ 15]: Loss 0.01560
Validation: Loss 0.01911 Accuracy 1.00000
Validation: Loss 0.01924 Accuracy 1.00000
Epoch [ 16]: Loss 0.01894
Epoch [ 16]: Loss 0.01885
Epoch [ 16]: Loss 0.01842
Epoch [ 16]: Loss 0.01818
Epoch [ 16]: Loss 0.01741
Epoch [ 16]: Loss 0.01752
Epoch [ 16]: Loss 0.01811
Validation: Loss 0.01740 Accuracy 1.00000
Validation: Loss 0.01751 Accuracy 1.00000
Epoch [ 17]: Loss 0.01640
Epoch [ 17]: Loss 0.01702
Epoch [ 17]: Loss 0.01696
Epoch [ 17]: Loss 0.01563
Epoch [ 17]: Loss 0.01621
Epoch [ 17]: Loss 0.01695
Epoch [ 17]: Loss 0.01844
Validation: Loss 0.01592 Accuracy 1.00000
Validation: Loss 0.01603 Accuracy 1.00000
Epoch [ 18]: Loss 0.01453
Epoch [ 18]: Loss 0.01511
Epoch [ 18]: Loss 0.01629
Epoch [ 18]: Loss 0.01505
Epoch [ 18]: Loss 0.01591
Epoch [ 18]: Loss 0.01426
Epoch [ 18]: Loss 0.01578
Validation: Loss 0.01463 Accuracy 1.00000
Validation: Loss 0.01473 Accuracy 1.00000
Epoch [ 19]: Loss 0.01349
Epoch [ 19]: Loss 0.01488
Epoch [ 19]: Loss 0.01332
Epoch [ 19]: Loss 0.01482
Epoch [ 19]: Loss 0.01563
Epoch [ 19]: Loss 0.01231
Epoch [ 19]: Loss 0.01245
Validation: Loss 0.01350 Accuracy 1.00000
Validation: Loss 0.01360 Accuracy 1.00000
Epoch [ 20]: Loss 0.01351
Epoch [ 20]: Loss 0.01323
Epoch [ 20]: Loss 0.01347
Epoch [ 20]: Loss 0.01267
Epoch [ 20]: Loss 0.01344
Epoch [ 20]: Loss 0.01164
Epoch [ 20]: Loss 0.01201
Validation: Loss 0.01250 Accuracy 1.00000
Validation: Loss 0.01260 Accuracy 1.00000
Epoch [ 21]: Loss 0.01225
Epoch [ 21]: Loss 0.01262
Epoch [ 21]: Loss 0.01212
Epoch [ 21]: Loss 0.01126
Epoch [ 21]: Loss 0.01137
Epoch [ 21]: Loss 0.01230
Epoch [ 21]: Loss 0.01205
Validation: Loss 0.01157 Accuracy 1.00000
Validation: Loss 0.01166 Accuracy 1.00000
Epoch [ 22]: Loss 0.01225
Epoch [ 22]: Loss 0.01110
Epoch [ 22]: Loss 0.01040
Epoch [ 22]: Loss 0.01072
Epoch [ 22]: Loss 0.01101
Epoch [ 22]: Loss 0.01100
Epoch [ 22]: Loss 0.01066
Validation: Loss 0.01059 Accuracy 1.00000
Validation: Loss 0.01068 Accuracy 1.00000
Epoch [ 23]: Loss 0.01040
Epoch [ 23]: Loss 0.00974
Epoch [ 23]: Loss 0.01038
Epoch [ 23]: Loss 0.01026
Epoch [ 23]: Loss 0.00959
Epoch [ 23]: Loss 0.00992
Epoch [ 23]: Loss 0.01033
Validation: Loss 0.00950 Accuracy 1.00000
Validation: Loss 0.00958 Accuracy 1.00000
Epoch [ 24]: Loss 0.00904
Epoch [ 24]: Loss 0.01043
Epoch [ 24]: Loss 0.00946
Epoch [ 24]: Loss 0.00901
Epoch [ 24]: Loss 0.00827
Epoch [ 24]: Loss 0.00763
Epoch [ 24]: Loss 0.00912
Validation: Loss 0.00841 Accuracy 1.00000
Validation: Loss 0.00848 Accuracy 1.00000
Epoch [ 25]: Loss 0.00873
Epoch [ 25]: Loss 0.00897
Epoch [ 25]: Loss 0.00763
Epoch [ 25]: Loss 0.00795
Epoch [ 25]: Loss 0.00742
Epoch [ 25]: Loss 0.00765
Epoch [ 25]: Loss 0.00652
Validation: Loss 0.00760 Accuracy 1.00000
Validation: Loss 0.00765 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend saving only the parameters and states, not the model struct itself.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)

This page was generated using Literate.jl.