Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
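
As a quick sanity check, we can inspect the shape of one batch. This is a sketch, not part of the original pipeline: with the defaults above the data tensor is 2 × 50 × 1000, so the 80% training split with batchsize=128 should yield full batches of 128 sequences.

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x), size(y)  # expect ((2, 50, 128), (128,))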

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
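
To see what this declaration buys us, here is a minimal sketch (using the concrete layers we construct below): Lux.setup returns parameters and states as NamedTuples keyed by the declared field names.

julia
ps, st = Lux.setup(Xoshiro(0), SpiralClassifier(LSTMCell(2 => 8), Dense(8 => 1, sigmoid)))
keys(ps)  # (:lstm_cell, :classifier)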

We won't define the individual components from scratch; instead we will use the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following (a sketch of that variant is shown below). But let's still write it out manually for the sake of illustration.
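
For reference, a minimal sketch of that built-in formulation follows. Recurrence iterates the cell over the time dimension and, with the default settings, returns only the last output, so its result has shape out_dims × batch_size and would still need a vec before computing the loss.

julia
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        # Run the LSTM cell across the time dimension, keeping only the final output
        Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid))
end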

Now we need to define the behavior of the classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
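
As a quick smoke test (a sketch with hypothetical shapes: 2 input features, 50 time steps, and a batch of 4), the forward pass should return one probability per sequence:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
y, st_ = model(randn(Float32, 2, 50, 4), ps, st)
size(y)  # (4,)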

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
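
If you do want the more stable variant mentioned above, a minimal sketch is to drop the sigmoid from the classifier head (i.e. Dense(hidden_dims => out_dims)) and compare raw logits instead:

julia
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))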

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62263
Epoch [  1]: Loss 0.59399
Epoch [  1]: Loss 0.56550
Epoch [  1]: Loss 0.53704
Epoch [  1]: Loss 0.50868
Epoch [  1]: Loss 0.50280
Epoch [  1]: Loss 0.47381
Validation: Loss 0.47891 Accuracy 1.00000
Validation: Loss 0.47035 Accuracy 1.00000
Epoch [  2]: Loss 0.46086
Epoch [  2]: Loss 0.44802
Epoch [  2]: Loss 0.44150
Epoch [  2]: Loss 0.42453
Epoch [  2]: Loss 0.41094
Epoch [  2]: Loss 0.40257
Epoch [  2]: Loss 0.36869
Validation: Loss 0.38380 Accuracy 1.00000
Validation: Loss 0.37333 Accuracy 1.00000
Epoch [  3]: Loss 0.36389
Epoch [  3]: Loss 0.34828
Epoch [  3]: Loss 0.35541
Epoch [  3]: Loss 0.33913
Epoch [  3]: Loss 0.31160
Epoch [  3]: Loss 0.29873
Epoch [  3]: Loss 0.31439
Validation: Loss 0.30000 Accuracy 1.00000
Validation: Loss 0.28858 Accuracy 1.00000
Epoch [  4]: Loss 0.29708
Epoch [  4]: Loss 0.27315
Epoch [  4]: Loss 0.25883
Epoch [  4]: Loss 0.26385
Epoch [  4]: Loss 0.23952
Epoch [  4]: Loss 0.21266
Epoch [  4]: Loss 0.21457
Validation: Loss 0.22973 Accuracy 1.00000
Validation: Loss 0.21849 Accuracy 1.00000
Epoch [  5]: Loss 0.21333
Epoch [  5]: Loss 0.19749
Epoch [  5]: Loss 0.20678
Epoch [  5]: Loss 0.18353
Epoch [  5]: Loss 0.17721
Epoch [  5]: Loss 0.17316
Epoch [  5]: Loss 0.17852
Validation: Loss 0.17300 Accuracy 1.00000
Validation: Loss 0.16268 Accuracy 1.00000
Epoch [  6]: Loss 0.15564
Epoch [  6]: Loss 0.16077
Epoch [  6]: Loss 0.13975
Epoch [  6]: Loss 0.13800
Epoch [  6]: Loss 0.12762
Epoch [  6]: Loss 0.12303
Epoch [  6]: Loss 0.14861
Validation: Loss 0.12766 Accuracy 1.00000
Validation: Loss 0.11929 Accuracy 1.00000
Epoch [  7]: Loss 0.11430
Epoch [  7]: Loss 0.11170
Epoch [  7]: Loss 0.10897
Epoch [  7]: Loss 0.09104
Epoch [  7]: Loss 0.09552
Epoch [  7]: Loss 0.09343
Epoch [  7]: Loss 0.09741
Validation: Loss 0.09109 Accuracy 1.00000
Validation: Loss 0.08504 Accuracy 1.00000
Epoch [  8]: Loss 0.08001
Epoch [  8]: Loss 0.07238
Epoch [  8]: Loss 0.07332
Epoch [  8]: Loss 0.07675
Epoch [  8]: Loss 0.06976
Epoch [  8]: Loss 0.06325
Epoch [  8]: Loss 0.05714
Validation: Loss 0.06301 Accuracy 1.00000
Validation: Loss 0.05901 Accuracy 1.00000
Epoch [  9]: Loss 0.05653
Epoch [  9]: Loss 0.05541
Epoch [  9]: Loss 0.05022
Epoch [  9]: Loss 0.05388
Epoch [  9]: Loss 0.04537
Epoch [  9]: Loss 0.04733
Epoch [  9]: Loss 0.03902
Validation: Loss 0.04679 Accuracy 1.00000
Validation: Loss 0.04392 Accuracy 1.00000
Epoch [ 10]: Loss 0.04261
Epoch [ 10]: Loss 0.04384
Epoch [ 10]: Loss 0.04056
Epoch [ 10]: Loss 0.03736
Epoch [ 10]: Loss 0.03694
Epoch [ 10]: Loss 0.03556
Epoch [ 10]: Loss 0.03602
Validation: Loss 0.03797 Accuracy 1.00000
Validation: Loss 0.03562 Accuracy 1.00000
Epoch [ 11]: Loss 0.03599
Epoch [ 11]: Loss 0.03258
Epoch [ 11]: Loss 0.03311
Epoch [ 11]: Loss 0.02910
Epoch [ 11]: Loss 0.03106
Epoch [ 11]: Loss 0.03245
Epoch [ 11]: Loss 0.03458
Validation: Loss 0.03233 Accuracy 1.00000
Validation: Loss 0.03028 Accuracy 1.00000
Epoch [ 12]: Loss 0.02909
Epoch [ 12]: Loss 0.02751
Epoch [ 12]: Loss 0.02837
Epoch [ 12]: Loss 0.02873
Epoch [ 12]: Loss 0.02771
Epoch [ 12]: Loss 0.02534
Epoch [ 12]: Loss 0.02918
Validation: Loss 0.02817 Accuracy 1.00000
Validation: Loss 0.02635 Accuracy 1.00000
Epoch [ 13]: Loss 0.02616
Epoch [ 13]: Loss 0.02504
Epoch [ 13]: Loss 0.02630
Epoch [ 13]: Loss 0.02434
Epoch [ 13]: Loss 0.02101
Epoch [ 13]: Loss 0.02270
Epoch [ 13]: Loss 0.02740
Validation: Loss 0.02493 Accuracy 1.00000
Validation: Loss 0.02329 Accuracy 1.00000
Epoch [ 14]: Loss 0.02073
Epoch [ 14]: Loss 0.02333
Epoch [ 14]: Loss 0.02201
Epoch [ 14]: Loss 0.02182
Epoch [ 14]: Loss 0.02147
Epoch [ 14]: Loss 0.02040
Epoch [ 14]: Loss 0.02156
Validation: Loss 0.02232 Accuracy 1.00000
Validation: Loss 0.02083 Accuracy 1.00000
Epoch [ 15]: Loss 0.01868
Epoch [ 15]: Loss 0.02054
Epoch [ 15]: Loss 0.01917
Epoch [ 15]: Loss 0.01907
Epoch [ 15]: Loss 0.01932
Epoch [ 15]: Loss 0.01954
Epoch [ 15]: Loss 0.01943
Validation: Loss 0.02018 Accuracy 1.00000
Validation: Loss 0.01880 Accuracy 1.00000
Epoch [ 16]: Loss 0.01833
Epoch [ 16]: Loss 0.01655
Epoch [ 16]: Loss 0.01930
Epoch [ 16]: Loss 0.01789
Epoch [ 16]: Loss 0.01685
Epoch [ 16]: Loss 0.01598
Epoch [ 16]: Loss 0.01893
Validation: Loss 0.01836 Accuracy 1.00000
Validation: Loss 0.01708 Accuracy 1.00000
Epoch [ 17]: Loss 0.01491
Epoch [ 17]: Loss 0.01472
Epoch [ 17]: Loss 0.01651
Epoch [ 17]: Loss 0.01525
Epoch [ 17]: Loss 0.01687
Epoch [ 17]: Loss 0.01694
Epoch [ 17]: Loss 0.01792
Validation: Loss 0.01682 Accuracy 1.00000
Validation: Loss 0.01563 Accuracy 1.00000
Epoch [ 18]: Loss 0.01594
Epoch [ 18]: Loss 0.01445
Epoch [ 18]: Loss 0.01500
Epoch [ 18]: Loss 0.01510
Epoch [ 18]: Loss 0.01408
Epoch [ 18]: Loss 0.01321
Epoch [ 18]: Loss 0.01482
Validation: Loss 0.01548 Accuracy 1.00000
Validation: Loss 0.01437 Accuracy 1.00000
Epoch [ 19]: Loss 0.01465
Epoch [ 19]: Loss 0.01424
Epoch [ 19]: Loss 0.01381
Epoch [ 19]: Loss 0.01365
Epoch [ 19]: Loss 0.01192
Epoch [ 19]: Loss 0.01257
Epoch [ 19]: Loss 0.01372
Validation: Loss 0.01432 Accuracy 1.00000
Validation: Loss 0.01328 Accuracy 1.00000
Epoch [ 20]: Loss 0.01330
Epoch [ 20]: Loss 0.01280
Epoch [ 20]: Loss 0.01228
Epoch [ 20]: Loss 0.01287
Epoch [ 20]: Loss 0.01214
Epoch [ 20]: Loss 0.01141
Epoch [ 20]: Loss 0.01272
Validation: Loss 0.01328 Accuracy 1.00000
Validation: Loss 0.01232 Accuracy 1.00000
Epoch [ 21]: Loss 0.01129
Epoch [ 21]: Loss 0.01198
Epoch [ 21]: Loss 0.01173
Epoch [ 21]: Loss 0.01149
Epoch [ 21]: Loss 0.01206
Epoch [ 21]: Loss 0.01079
Epoch [ 21]: Loss 0.01179
Validation: Loss 0.01231 Accuracy 1.00000
Validation: Loss 0.01142 Accuracy 1.00000
Epoch [ 22]: Loss 0.01088
Epoch [ 22]: Loss 0.01077
Epoch [ 22]: Loss 0.01140
Epoch [ 22]: Loss 0.01084
Epoch [ 22]: Loss 0.00999
Epoch [ 22]: Loss 0.01038
Epoch [ 22]: Loss 0.01045
Validation: Loss 0.01131 Accuracy 1.00000
Validation: Loss 0.01050 Accuracy 1.00000
Epoch [ 23]: Loss 0.01062
Epoch [ 23]: Loss 0.01028
Epoch [ 23]: Loss 0.01034
Epoch [ 23]: Loss 0.00990
Epoch [ 23]: Loss 0.00892
Epoch [ 23]: Loss 0.00904
Epoch [ 23]: Loss 0.00839
Validation: Loss 0.01018 Accuracy 1.00000
Validation: Loss 0.00946 Accuracy 1.00000
Epoch [ 24]: Loss 0.00878
Epoch [ 24]: Loss 0.00932
Epoch [ 24]: Loss 0.00901
Epoch [ 24]: Loss 0.00935
Epoch [ 24]: Loss 0.00712
Epoch [ 24]: Loss 0.00905
Epoch [ 24]: Loss 0.00872
Validation: Loss 0.00903 Accuracy 1.00000
Validation: Loss 0.00842 Accuracy 1.00000
Epoch [ 25]: Loss 0.00745
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00833
Epoch [ 25]: Loss 0.00854
Epoch [ 25]: Loss 0.00738
Epoch [ 25]: Loss 0.00796
Epoch [ 25]: Loss 0.00705
Validation: Loss 0.00813 Accuracy 1.00000
Validation: Loss 0.00759 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.62288
Epoch [  1]: Loss 0.60156
Epoch [  1]: Loss 0.56575
Epoch [  1]: Loss 0.54397
Epoch [  1]: Loss 0.50466
Epoch [  1]: Loss 0.50173
Epoch [  1]: Loss 0.47170
Validation: Loss 0.46604 Accuracy 1.00000
Validation: Loss 0.48339 Accuracy 1.00000
Epoch [  2]: Loss 0.46482
Epoch [  2]: Loss 0.45374
Epoch [  2]: Loss 0.43084
Epoch [  2]: Loss 0.42236
Epoch [  2]: Loss 0.41458
Epoch [  2]: Loss 0.40944
Epoch [  2]: Loss 0.37447
Validation: Loss 0.36841 Accuracy 1.00000
Validation: Loss 0.39020 Accuracy 1.00000
Epoch [  3]: Loss 0.37004
Epoch [  3]: Loss 0.36218
Epoch [  3]: Loss 0.34469
Epoch [  3]: Loss 0.33805
Epoch [  3]: Loss 0.31394
Epoch [  3]: Loss 0.30801
Epoch [  3]: Loss 0.28145
Validation: Loss 0.28322 Accuracy 1.00000
Validation: Loss 0.30696 Accuracy 1.00000
Epoch [  4]: Loss 0.27850
Epoch [  4]: Loss 0.28725
Epoch [  4]: Loss 0.26155
Epoch [  4]: Loss 0.25966
Epoch [  4]: Loss 0.24059
Epoch [  4]: Loss 0.23101
Epoch [  4]: Loss 0.21410
Validation: Loss 0.21337 Accuracy 1.00000
Validation: Loss 0.23700 Accuracy 1.00000
Epoch [  5]: Loss 0.21875
Epoch [  5]: Loss 0.19822
Epoch [  5]: Loss 0.20377
Epoch [  5]: Loss 0.18605
Epoch [  5]: Loss 0.19454
Epoch [  5]: Loss 0.17456
Epoch [  5]: Loss 0.13679
Validation: Loss 0.15817 Accuracy 1.00000
Validation: Loss 0.17965 Accuracy 1.00000
Epoch [  6]: Loss 0.14921
Epoch [  6]: Loss 0.15577
Epoch [  6]: Loss 0.15232
Epoch [  6]: Loss 0.13977
Epoch [  6]: Loss 0.12848
Epoch [  6]: Loss 0.13338
Epoch [  6]: Loss 0.14178
Validation: Loss 0.11593 Accuracy 1.00000
Validation: Loss 0.13357 Accuracy 1.00000
Epoch [  7]: Loss 0.12252
Epoch [  7]: Loss 0.11735
Epoch [  7]: Loss 0.10403
Epoch [  7]: Loss 0.09437
Epoch [  7]: Loss 0.09209
Epoch [  7]: Loss 0.09785
Epoch [  7]: Loss 0.09238
Validation: Loss 0.08257 Accuracy 1.00000
Validation: Loss 0.09541 Accuracy 1.00000
Epoch [  8]: Loss 0.08162
Epoch [  8]: Loss 0.08032
Epoch [  8]: Loss 0.07694
Epoch [  8]: Loss 0.06816
Epoch [  8]: Loss 0.06543
Epoch [  8]: Loss 0.06746
Epoch [  8]: Loss 0.07005
Validation: Loss 0.05738 Accuracy 1.00000
Validation: Loss 0.06581 Accuracy 1.00000
Epoch [  9]: Loss 0.05757
Epoch [  9]: Loss 0.05529
Epoch [  9]: Loss 0.05748
Epoch [  9]: Loss 0.04998
Epoch [  9]: Loss 0.04759
Epoch [  9]: Loss 0.04576
Epoch [  9]: Loss 0.03935
Validation: Loss 0.04279 Accuracy 1.00000
Validation: Loss 0.04867 Accuracy 1.00000
Epoch [ 10]: Loss 0.04467
Epoch [ 10]: Loss 0.04392
Epoch [ 10]: Loss 0.03815
Epoch [ 10]: Loss 0.03728
Epoch [ 10]: Loss 0.03651
Epoch [ 10]: Loss 0.04009
Epoch [ 10]: Loss 0.03388
Validation: Loss 0.03471 Accuracy 1.00000
Validation: Loss 0.03950 Accuracy 1.00000
Epoch [ 11]: Loss 0.03537
Epoch [ 11]: Loss 0.03502
Epoch [ 11]: Loss 0.03277
Epoch [ 11]: Loss 0.03312
Epoch [ 11]: Loss 0.03344
Epoch [ 11]: Loss 0.02930
Epoch [ 11]: Loss 0.02759
Validation: Loss 0.02949 Accuracy 1.00000
Validation: Loss 0.03364 Accuracy 1.00000
Epoch [ 12]: Loss 0.02901
Epoch [ 12]: Loss 0.02869
Epoch [ 12]: Loss 0.02981
Epoch [ 12]: Loss 0.02728
Epoch [ 12]: Loss 0.02793
Epoch [ 12]: Loss 0.02718
Epoch [ 12]: Loss 0.02664
Validation: Loss 0.02568 Accuracy 1.00000
Validation: Loss 0.02938 Accuracy 1.00000
Epoch [ 13]: Loss 0.02727
Epoch [ 13]: Loss 0.02457
Epoch [ 13]: Loss 0.02457
Epoch [ 13]: Loss 0.02482
Epoch [ 13]: Loss 0.02480
Epoch [ 13]: Loss 0.02331
Epoch [ 13]: Loss 0.02188
Validation: Loss 0.02272 Accuracy 1.00000
Validation: Loss 0.02605 Accuracy 1.00000
Epoch [ 14]: Loss 0.02234
Epoch [ 14]: Loss 0.02277
Epoch [ 14]: Loss 0.02421
Epoch [ 14]: Loss 0.02029
Epoch [ 14]: Loss 0.02074
Epoch [ 14]: Loss 0.02138
Epoch [ 14]: Loss 0.02334
Validation: Loss 0.02034 Accuracy 1.00000
Validation: Loss 0.02339 Accuracy 1.00000
Epoch [ 15]: Loss 0.01998
Epoch [ 15]: Loss 0.01987
Epoch [ 15]: Loss 0.02031
Epoch [ 15]: Loss 0.02163
Epoch [ 15]: Loss 0.01918
Epoch [ 15]: Loss 0.01856
Epoch [ 15]: Loss 0.01618
Validation: Loss 0.01835 Accuracy 1.00000
Validation: Loss 0.02116 Accuracy 1.00000
Epoch [ 16]: Loss 0.01870
Epoch [ 16]: Loss 0.01830
Epoch [ 16]: Loss 0.01737
Epoch [ 16]: Loss 0.01729
Epoch [ 16]: Loss 0.01886
Epoch [ 16]: Loss 0.01732
Epoch [ 16]: Loss 0.01596
Validation: Loss 0.01668 Accuracy 1.00000
Validation: Loss 0.01929 Accuracy 1.00000
Epoch [ 17]: Loss 0.01733
Epoch [ 17]: Loss 0.01601
Epoch [ 17]: Loss 0.01736
Epoch [ 17]: Loss 0.01575
Epoch [ 17]: Loss 0.01554
Epoch [ 17]: Loss 0.01626
Epoch [ 17]: Loss 0.01457
Validation: Loss 0.01525 Accuracy 1.00000
Validation: Loss 0.01768 Accuracy 1.00000
Epoch [ 18]: Loss 0.01630
Epoch [ 18]: Loss 0.01551
Epoch [ 18]: Loss 0.01460
Epoch [ 18]: Loss 0.01413
Epoch [ 18]: Loss 0.01483
Epoch [ 18]: Loss 0.01380
Epoch [ 18]: Loss 0.01671
Validation: Loss 0.01401 Accuracy 1.00000
Validation: Loss 0.01628 Accuracy 1.00000
Epoch [ 19]: Loss 0.01485
Epoch [ 19]: Loss 0.01456
Epoch [ 19]: Loss 0.01383
Epoch [ 19]: Loss 0.01368
Epoch [ 19]: Loss 0.01365
Epoch [ 19]: Loss 0.01271
Epoch [ 19]: Loss 0.01071
Validation: Loss 0.01292 Accuracy 1.00000
Validation: Loss 0.01504 Accuracy 1.00000
Epoch [ 20]: Loss 0.01371
Epoch [ 20]: Loss 0.01331
Epoch [ 20]: Loss 0.01202
Epoch [ 20]: Loss 0.01214
Epoch [ 20]: Loss 0.01261
Epoch [ 20]: Loss 0.01248
Epoch [ 20]: Loss 0.01235
Validation: Loss 0.01195 Accuracy 1.00000
Validation: Loss 0.01393 Accuracy 1.00000
Epoch [ 21]: Loss 0.01281
Epoch [ 21]: Loss 0.01124
Epoch [ 21]: Loss 0.01193
Epoch [ 21]: Loss 0.01111
Epoch [ 21]: Loss 0.01108
Epoch [ 21]: Loss 0.01188
Epoch [ 21]: Loss 0.01318
Validation: Loss 0.01102 Accuracy 1.00000
Validation: Loss 0.01284 Accuracy 1.00000
Epoch [ 22]: Loss 0.01170
Epoch [ 22]: Loss 0.01155
Epoch [ 22]: Loss 0.00926
Epoch [ 22]: Loss 0.00976
Epoch [ 22]: Loss 0.01090
Epoch [ 22]: Loss 0.01156
Epoch [ 22]: Loss 0.01042
Validation: Loss 0.01002 Accuracy 1.00000
Validation: Loss 0.01165 Accuracy 1.00000
Epoch [ 23]: Loss 0.01043
Epoch [ 23]: Loss 0.00959
Epoch [ 23]: Loss 0.01012
Epoch [ 23]: Loss 0.00912
Epoch [ 23]: Loss 0.00986
Epoch [ 23]: Loss 0.00960
Epoch [ 23]: Loss 0.00821
Validation: Loss 0.00893 Accuracy 1.00000
Validation: Loss 0.01034 Accuracy 1.00000
Epoch [ 24]: Loss 0.00843
Epoch [ 24]: Loss 0.00871
Epoch [ 24]: Loss 0.00873
Epoch [ 24]: Loss 0.00883
Epoch [ 24]: Loss 0.00828
Epoch [ 24]: Loss 0.00899
Epoch [ 24]: Loss 0.00795
Validation: Loss 0.00797 Accuracy 1.00000
Validation: Loss 0.00918 Accuracy 1.00000
Epoch [ 25]: Loss 0.00871
Epoch [ 25]: Loss 0.00772
Epoch [ 25]: Loss 0.00776
Epoch [ 25]: Loss 0.00744
Epoch [ 25]: Loss 0.00720
Epoch [ 25]: Loss 0.00773
Epoch [ 25]: Loss 0.00802
Validation: Loss 0.00727 Accuracy 1.00000
Validation: Loss 0.00834 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
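
To reuse the loaded parameters, rebuild the same model structure and run it in test mode. A sketch with a hypothetical input batch:

julia
model = SpiralClassifier(2, 8, 1)
x = randn(Float32, 2, 50, 4)  # hypothetical batch: 2 features, 50 steps, 4 sequences
y, _ = model(x, ps_trained, Lux.testmode(st_trained))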

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)

This page was generated using Literate.jl.