
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. The dataloader yields sequences of size 2 × seq_len × batch_size, and the task is to predict a binary value indicating whether each sequence is clockwise or anticlockwise.
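The target array layout can be reproduced with random stand-in data. The plain-Julia sketch below is only a shape check: it uses random matrices instead of `make_spiral`, and `toy_batch` is a hypothetical helper. Each sample is a 2 × seq_len matrix, samples are stacked along the third dimension, and the labels are 0 for the first half and 1 for the second half, as in `get_dataloaders`.

```julia
using Random

# Hypothetical helper: random 2 × seq_len "trajectories" stand in for real spirals
function toy_batch(rng; n_samples=6, seq_len=5)
    samples = [randn(rng, Float32, 2, seq_len) for _ in 1:n_samples]
    # Give each sample a trailing singleton dimension, then stack along dims=3
    x = cat((reshape(s, 2, seq_len, 1) for s in samples)...; dims=3)
    # First half labelled 0 (clockwise), second half labelled 1 (anticlockwise)
    y = vcat(zeros(Float32, n_samples ÷ 2), ones(Float32, n_samples ÷ 2))
    return x, y
end

x, y = toy_batch(Xoshiro(0))
size(x)   # (2, 5, 6): features × time steps × samples
```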

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
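With the defaults above (1000 samples, an 80/20 split, `batchsize=128`), the number of minibatches per epoch is simple arithmetic. This sketch assumes the `DataLoader` keeps the final partial batch (its `partial=true` default). That assumption is consistent with the seven training-loss lines and two validation lines printed per epoch in the logs below.

```julia
n_total = 1000
n_train = round(Int, 0.8 * n_total)   # 0.8 × 1000 = 800 training sequences
n_val = n_total - n_train             # 200 validation sequences
batchsize = 128

# Keeping partial batches means the batch count is a ceiling division
n_train_batches = cld(n_train, batchsize)                         # 7
n_val_batches = cld(n_val, batchsize)                             # 2
last_train_batch = n_train - (n_train_batches - 1) * batchsize    # 32 sequences
```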

Creating a Classifier

We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are populated automatically, and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
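Conceptually, the container keeps one parameter (and state) entry per listed field name, nested in a `NamedTuple`. The plain-Julia sketch below imitates that structure with hypothetical initializer closures. `init_children` and `dense_params` are illustrative, not Lux functions, and the LSTM placeholder shape is made up. It shows why the forward pass can later access `ps.lstm_cell` and `ps.classifier`.

```julia
using Random

# Hypothetical stand-in for a Dense(in => out) initializer: weight is out × in
dense_params(rng, in, out) = (weight=randn(rng, Float32, out, in), bias=zeros(Float32, out))

# Hypothetical child initializers keyed by the container's field names
init_children = (
    lstm_cell = rng -> (weight=randn(rng, Float32, 32, 10),),   # placeholder shape only
    classifier = rng -> dense_params(rng, 8, 1),
)

# `map` over a NamedTuple keeps its keys, so the result is nested by field name
rng = Xoshiro(0)
ps = map(init -> init(rng), init_children)

keys(ps)                     # (:lstm_cell, :classifier)
size(ps.classifier.weight)   # (1, 8)
```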

We won't define the model from scratch. Instead, we build it from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
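The classifier head `Dense(hidden_dims => out_dims, sigmoid)` applies an affine map followed by a sigmoid. The plain-Julia sketch below reproduces that computation with random weights. `dense_sigmoid` and `logistic` are hypothetical names. An `8 × batch` hidden state becomes a `1 × batch` matrix of probabilities, which the forward pass then flattens with `vec` to match the label vector.

```julia
using Random

logistic(z) = 1 / (1 + exp(-z))

# Hypothetical re-implementation of the head's math: σ.(W * h .+ b)
dense_sigmoid(W, b, h) = logistic.(W * h .+ b)

rng = Xoshiro(0)
hidden_dims, out_dims, batch = 8, 1, 4
W = randn(rng, Float32, out_dims, hidden_dims)
b = zeros(Float32, out_dims)
h = randn(rng, Float32, hidden_dims, batch)   # final LSTM outputs, one column per sample

p = dense_sigmoid(W, b, h)   # 1 × 4 matrix of probabilities
vec(p)                       # length-4 vector, same shape as the labels
```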

We could use the built-in Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of writing the forward pass below by hand, but we define it ourselves here for illustration.
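The unrolling can be summarised as a left fold over the time dimension. The `features × time × batch` array is sliced along dimension 2, and the first slice is passed to the cell to create the carry. The carry is then threaded through the remaining slices, and the last output is kept. The plain-Julia sketch below shows only this pattern. It assumes a hypothetical toy cell (a running average, not an LSTM) and uses Base's `eachslice` rather than `LuxOps.eachslice`.

```julia
# Hypothetical toy cell: the carry is a running average of the inputs.
# A bare input creates the carry; an (input, carry) tuple updates it.
toy_cell(x::AbstractMatrix) = (x, x)
function toy_cell(input::Tuple)
    x, carry = input
    new_carry = 0.5f0 .* carry .+ 0.5f0 .* x
    return new_carry, new_carry   # (output, carry)
end

function unroll(cell, x::AbstractArray{T, 3}) where {T}
    # Split the sequence into its first time step and the remaining steps
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    y, carry = cell(x_init)
    for xt in x_rest
        y, carry = cell((xt, carry))
    end
    return y   # output after the last time step, size features × batch
end

x = reshape(Float32.(1:12), 2, 3, 2)   # 2 features, 3 time steps, 2 samples
unroll(toy_cell, x)                     # [3.5 9.5; 4.5 10.5]
```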

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
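For reference, one step of a standard LSTM cell can be written out directly. The plain-Julia sketch below uses a textbook formulation, not Lux's implementation. The stacked gate order `[input; forget; cell; output]`, the single combined bias, and the names `lstm_step`, `W`, `U`, `b` are assumptions made for illustration. The carry is the pair `(h, c)` of hidden state and memory, mirroring the `(y, carry)` pattern above. Unlike the code above, it starts from an explicit all-zero carry.

```julia
using Random

logistic(z) = 1 / (1 + exp(-z))

# One LSTM step. W: 4H × in, U: 4H × H, b: 4H; x: in × batch; h, c: H × batch.
function lstm_step(W, U, b, x, carry)
    h, c = carry
    H = size(h, 1)
    z = W * x .+ U * h .+ b                # all four gate pre-activations at once
    i = logistic.(z[1:H, :])               # input gate
    f = logistic.(z[(H + 1):(2H), :])      # forget gate
    g = tanh.(z[(2H + 1):(3H), :])         # candidate memory
    o = logistic.(z[(3H + 1):(4H), :])     # output gate
    c_new = f .* c .+ i .* g               # updated memory
    h_new = o .* tanh.(c_new)              # new hidden state, also the step's output
    return h_new, (h_new, c_new)
end

rng = Xoshiro(0)
in_dims, H, batch = 2, 8, 4
W = randn(rng, Float32, 4H, in_dims)
U = randn(rng, Float32, 4H, H)
b = zeros(Float32, 4H)
x_t = randn(rng, Float32, in_dims, batch)                    # one time step of input
carry0 = (zeros(Float32, H, batch), zeros(Float32, H, batch))

y, (h, c) = lstm_step(W, U, b, x_t, carry0)
size(y)   # (8, 4)
```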

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. The macro handles the boilerplate for you, which is why we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Computing the loss directly on logits (logit binary cross-entropy) is usually recommended because it is more numerically stable. For simplicity, we instead apply a sigmoid in the model and use the plain binary cross-entropy on probabilities.
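To see why the logit form is more stable, compare the textbook cross-entropy on probabilities with the same quantity written in terms of the pre-sigmoid logit `z`. The plain-Julia sketch below uses hypothetical helpers `bce_prob` and `bce_logit`, with no epsilon clamping. The two forms agree for moderate logits. The probability form breaks down in Float32 once `sigmoid(z)` rounds to exactly 1.

```julia
logistic(z) = 1 / (1 + exp(-z))

# Cross-entropy on a probability p = sigmoid(z), for a label y ∈ {0, 1}
bce_prob(p, y) = -(y * log(p) + (1 - y) * log(1 - p))

# The same loss in terms of the logit z, rearranged to avoid overflow:
# max(z, 0) - z*y + log(1 + exp(-|z|))
bce_logit(z, y) = max(z, zero(z)) - z * y + log1p(exp(-abs(z)))

bce_prob(logistic(2.0f0), 0.0f0) ≈ bce_logit(2.0f0, 0.0f0)    # true
bce_prob(logistic(-3.0f0), 1.0f0) ≈ bce_logit(-3.0f0, 1.0f0)  # true

logistic(20.0f0) == 1.0f0           # true: the probability saturates in Float32
bce_prob(logistic(20.0f0), 0.0f0)   # Inf, because log(1 - 1) = log(0)
bce_logit(20.0f0, 0.0f0)            # ≈ 20.0, still finite
```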

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
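As a quick sanity check, here are the two helpers applied to a hypothetical batch of four predictions. Predictions above 0.5 count as class 1, so three of the four match the labels.

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = Float32[0.9, 0.2, 0.6, 0.4]   # predicted probabilities of class 1
y_true = Float32[1, 0, 0, 0]           # ground-truth labels

matches(y_pred, y_true)    # 3
accuracy(y_pred, y_true)   # 0.75
```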

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
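            # Lux loss objects such as `lossfn` can be passed directly as the objective;
            # `compute_loss` above is an equivalent alternative that also returns predictions.
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)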

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.61453
Epoch [  1]: Loss 0.59243
Epoch [  1]: Loss 0.56526
Epoch [  1]: Loss 0.54178
Epoch [  1]: Loss 0.52307
Epoch [  1]: Loss 0.49751
Epoch [  1]: Loss 0.48688
Validation: Loss 0.46892 Accuracy 1.00000
Validation: Loss 0.46603 Accuracy 1.00000
Epoch [  2]: Loss 0.45749
Epoch [  2]: Loss 0.45752
Epoch [  2]: Loss 0.44836
Epoch [  2]: Loss 0.42228
Epoch [  2]: Loss 0.41130
Epoch [  2]: Loss 0.40114
Epoch [  2]: Loss 0.37467
Validation: Loss 0.37161 Accuracy 1.00000
Validation: Loss 0.36814 Accuracy 1.00000
Epoch [  3]: Loss 0.38008
Epoch [  3]: Loss 0.35219
Epoch [  3]: Loss 0.34750
Epoch [  3]: Loss 0.33244
Epoch [  3]: Loss 0.32540
Epoch [  3]: Loss 0.30633
Epoch [  3]: Loss 0.26119
Validation: Loss 0.28683 Accuracy 1.00000
Validation: Loss 0.28290 Accuracy 1.00000
Epoch [  4]: Loss 0.27922
Epoch [  4]: Loss 0.28308
Epoch [  4]: Loss 0.27231
Epoch [  4]: Loss 0.24842
Epoch [  4]: Loss 0.24707
Epoch [  4]: Loss 0.23129
Epoch [  4]: Loss 0.22081
Validation: Loss 0.21783 Accuracy 1.00000
Validation: Loss 0.21373 Accuracy 1.00000
Epoch [  5]: Loss 0.21684
Epoch [  5]: Loss 0.20931
Epoch [  5]: Loss 0.21209
Epoch [  5]: Loss 0.19558
Epoch [  5]: Loss 0.17426
Epoch [  5]: Loss 0.16805
Epoch [  5]: Loss 0.16822
Validation: Loss 0.16272 Accuracy 1.00000
Validation: Loss 0.15894 Accuracy 1.00000
Epoch [  6]: Loss 0.16020
Epoch [  6]: Loss 0.14623
Epoch [  6]: Loss 0.13986
Epoch [  6]: Loss 0.15242
Epoch [  6]: Loss 0.13891
Epoch [  6]: Loss 0.12584
Epoch [  6]: Loss 0.15504
Validation: Loss 0.11991 Accuracy 1.00000
Validation: Loss 0.11681 Accuracy 1.00000
Epoch [  7]: Loss 0.11641
Epoch [  7]: Loss 0.12243
Epoch [  7]: Loss 0.11302
Epoch [  7]: Loss 0.10348
Epoch [  7]: Loss 0.09517
Epoch [  7]: Loss 0.08958
Epoch [  7]: Loss 0.07990
Validation: Loss 0.08583 Accuracy 1.00000
Validation: Loss 0.08359 Accuracy 1.00000
Epoch [  8]: Loss 0.08263
Epoch [  8]: Loss 0.08199
Epoch [  8]: Loss 0.08579
Epoch [  8]: Loss 0.07038
Epoch [  8]: Loss 0.06689
Epoch [  8]: Loss 0.06069
Epoch [  8]: Loss 0.07326
Validation: Loss 0.05976 Accuracy 1.00000
Validation: Loss 0.05830 Accuracy 1.00000
Epoch [  9]: Loss 0.06201
Epoch [  9]: Loss 0.06054
Epoch [  9]: Loss 0.05309
Epoch [  9]: Loss 0.05027
Epoch [  9]: Loss 0.04856
Epoch [  9]: Loss 0.04365
Epoch [  9]: Loss 0.04361
Validation: Loss 0.04394 Accuracy 1.00000
Validation: Loss 0.04295 Accuracy 1.00000
Epoch [ 10]: Loss 0.04252
Epoch [ 10]: Loss 0.04446
Epoch [ 10]: Loss 0.04069
Epoch [ 10]: Loss 0.04039
Epoch [ 10]: Loss 0.03668
Epoch [ 10]: Loss 0.03685
Epoch [ 10]: Loss 0.03488
Validation: Loss 0.03540 Accuracy 1.00000
Validation: Loss 0.03460 Accuracy 1.00000
Epoch [ 11]: Loss 0.03621
Epoch [ 11]: Loss 0.03394
Epoch [ 11]: Loss 0.03340
Epoch [ 11]: Loss 0.03121
Epoch [ 11]: Loss 0.03415
Epoch [ 11]: Loss 0.02958
Epoch [ 11]: Loss 0.02935
Validation: Loss 0.02998 Accuracy 1.00000
Validation: Loss 0.02929 Accuracy 1.00000
Epoch [ 12]: Loss 0.02952
Epoch [ 12]: Loss 0.02671
Epoch [ 12]: Loss 0.02943
Epoch [ 12]: Loss 0.03006
Epoch [ 12]: Loss 0.02760
Epoch [ 12]: Loss 0.02620
Epoch [ 12]: Loss 0.02636
Validation: Loss 0.02607 Accuracy 1.00000
Validation: Loss 0.02545 Accuracy 1.00000
Epoch [ 13]: Loss 0.02239
Epoch [ 13]: Loss 0.02604
Epoch [ 13]: Loss 0.02621
Epoch [ 13]: Loss 0.02447
Epoch [ 13]: Loss 0.02466
Epoch [ 13]: Loss 0.02445
Epoch [ 13]: Loss 0.02324
Validation: Loss 0.02305 Accuracy 1.00000
Validation: Loss 0.02249 Accuracy 1.00000
Epoch [ 14]: Loss 0.02541
Epoch [ 14]: Loss 0.02190
Epoch [ 14]: Loss 0.02094
Epoch [ 14]: Loss 0.02080
Epoch [ 14]: Loss 0.02159
Epoch [ 14]: Loss 0.02090
Epoch [ 14]: Loss 0.02148
Validation: Loss 0.02061 Accuracy 1.00000
Validation: Loss 0.02009 Accuracy 1.00000
Epoch [ 15]: Loss 0.02206
Epoch [ 15]: Loss 0.01957
Epoch [ 15]: Loss 0.02051
Epoch [ 15]: Loss 0.01829
Epoch [ 15]: Loss 0.01838
Epoch [ 15]: Loss 0.01949
Epoch [ 15]: Loss 0.01803
Validation: Loss 0.01859 Accuracy 1.00000
Validation: Loss 0.01811 Accuracy 1.00000
Epoch [ 16]: Loss 0.01826
Epoch [ 16]: Loss 0.01751
Epoch [ 16]: Loss 0.01808
Epoch [ 16]: Loss 0.01733
Epoch [ 16]: Loss 0.01791
Epoch [ 16]: Loss 0.01719
Epoch [ 16]: Loss 0.01883
Validation: Loss 0.01689 Accuracy 1.00000
Validation: Loss 0.01645 Accuracy 1.00000
Epoch [ 17]: Loss 0.01596
Epoch [ 17]: Loss 0.01667
Epoch [ 17]: Loss 0.01585
Epoch [ 17]: Loss 0.01639
Epoch [ 17]: Loss 0.01620
Epoch [ 17]: Loss 0.01630
Epoch [ 17]: Loss 0.01489
Validation: Loss 0.01545 Accuracy 1.00000
Validation: Loss 0.01504 Accuracy 1.00000
Epoch [ 18]: Loss 0.01691
Epoch [ 18]: Loss 0.01426
Epoch [ 18]: Loss 0.01493
Epoch [ 18]: Loss 0.01461
Epoch [ 18]: Loss 0.01430
Epoch [ 18]: Loss 0.01398
Epoch [ 18]: Loss 0.01476
Validation: Loss 0.01421 Accuracy 1.00000
Validation: Loss 0.01383 Accuracy 1.00000
Epoch [ 19]: Loss 0.01530
Epoch [ 19]: Loss 0.01417
Epoch [ 19]: Loss 0.01402
Epoch [ 19]: Loss 0.01294
Epoch [ 19]: Loss 0.01344
Epoch [ 19]: Loss 0.01226
Epoch [ 19]: Loss 0.01312
Validation: Loss 0.01313 Accuracy 1.00000
Validation: Loss 0.01277 Accuracy 1.00000
Epoch [ 20]: Loss 0.01338
Epoch [ 20]: Loss 0.01260
Epoch [ 20]: Loss 0.01260
Epoch [ 20]: Loss 0.01355
Epoch [ 20]: Loss 0.01232
Epoch [ 20]: Loss 0.01124
Epoch [ 20]: Loss 0.01310
Validation: Loss 0.01216 Accuracy 1.00000
Validation: Loss 0.01183 Accuracy 1.00000
Epoch [ 21]: Loss 0.01310
Epoch [ 21]: Loss 0.01260
Epoch [ 21]: Loss 0.01120
Epoch [ 21]: Loss 0.01152
Epoch [ 21]: Loss 0.01073
Epoch [ 21]: Loss 0.01149
Epoch [ 21]: Loss 0.00993
Validation: Loss 0.01123 Accuracy 1.00000
Validation: Loss 0.01093 Accuracy 1.00000
Epoch [ 22]: Loss 0.01198
Epoch [ 22]: Loss 0.01088
Epoch [ 22]: Loss 0.01106
Epoch [ 22]: Loss 0.01020
Epoch [ 22]: Loss 0.01039
Epoch [ 22]: Loss 0.01005
Epoch [ 22]: Loss 0.01112
Validation: Loss 0.01026 Accuracy 1.00000
Validation: Loss 0.00999 Accuracy 1.00000
Epoch [ 23]: Loss 0.00900
Epoch [ 23]: Loss 0.00974
Epoch [ 23]: Loss 0.00979
Epoch [ 23]: Loss 0.00970
Epoch [ 23]: Loss 0.00982
Epoch [ 23]: Loss 0.00994
Epoch [ 23]: Loss 0.01165
Validation: Loss 0.00918 Accuracy 1.00000
Validation: Loss 0.00894 Accuracy 1.00000
Epoch [ 24]: Loss 0.00874
Epoch [ 24]: Loss 0.00925
Epoch [ 24]: Loss 0.00887
Epoch [ 24]: Loss 0.00822
Epoch [ 24]: Loss 0.00786
Epoch [ 24]: Loss 0.00885
Epoch [ 24]: Loss 0.00966
Validation: Loss 0.00814 Accuracy 1.00000
Validation: Loss 0.00793 Accuracy 1.00000
Epoch [ 25]: Loss 0.00770
Epoch [ 25]: Loss 0.00836
Epoch [ 25]: Loss 0.00721
Epoch [ 25]: Loss 0.00798
Epoch [ 25]: Loss 0.00772
Epoch [ 25]: Loss 0.00756
Epoch [ 25]: Loss 0.00767
Validation: Loss 0.00738 Accuracy 1.00000
Validation: Loss 0.00720 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.62219
Epoch [  1]: Loss 0.59276
Epoch [  1]: Loss 0.55983
Epoch [  1]: Loss 0.54930
Epoch [  1]: Loss 0.51796
Epoch [  1]: Loss 0.50347
Epoch [  1]: Loss 0.48278
Validation: Loss 0.46349 Accuracy 1.00000
Validation: Loss 0.46455 Accuracy 1.00000
Epoch [  2]: Loss 0.46559
Epoch [  2]: Loss 0.45462
Epoch [  2]: Loss 0.43891
Epoch [  2]: Loss 0.43242
Epoch [  2]: Loss 0.41942
Epoch [  2]: Loss 0.39238
Epoch [  2]: Loss 0.37993
Validation: Loss 0.36509 Accuracy 1.00000
Validation: Loss 0.36619 Accuracy 1.00000
Epoch [  3]: Loss 0.37283
Epoch [  3]: Loss 0.35590
Epoch [  3]: Loss 0.34827
Epoch [  3]: Loss 0.32381
Epoch [  3]: Loss 0.32661
Epoch [  3]: Loss 0.31426
Epoch [  3]: Loss 0.29730
Validation: Loss 0.27915 Accuracy 1.00000
Validation: Loss 0.28035 Accuracy 1.00000
Epoch [  4]: Loss 0.28883
Epoch [  4]: Loss 0.27465
Epoch [  4]: Loss 0.26130
Epoch [  4]: Loss 0.24991
Epoch [  4]: Loss 0.24886
Epoch [  4]: Loss 0.23842
Epoch [  4]: Loss 0.22975
Validation: Loss 0.20896 Accuracy 1.00000
Validation: Loss 0.21025 Accuracy 1.00000
Epoch [  5]: Loss 0.21262
Epoch [  5]: Loss 0.20886
Epoch [  5]: Loss 0.19512
Epoch [  5]: Loss 0.19980
Epoch [  5]: Loss 0.18216
Epoch [  5]: Loss 0.17696
Epoch [  5]: Loss 0.16177
Validation: Loss 0.15372 Accuracy 1.00000
Validation: Loss 0.15510 Accuracy 1.00000
Epoch [  6]: Loss 0.16420
Epoch [  6]: Loss 0.14920
Epoch [  6]: Loss 0.14400
Epoch [  6]: Loss 0.15145
Epoch [  6]: Loss 0.12623
Epoch [  6]: Loss 0.13021
Epoch [  6]: Loss 0.13497
Validation: Loss 0.11202 Accuracy 1.00000
Validation: Loss 0.11330 Accuracy 1.00000
Epoch [  7]: Loss 0.11323
Epoch [  7]: Loss 0.11843
Epoch [  7]: Loss 0.10933
Epoch [  7]: Loss 0.10795
Epoch [  7]: Loss 0.09362
Epoch [  7]: Loss 0.09097
Epoch [  7]: Loss 0.08788
Validation: Loss 0.08015 Accuracy 1.00000
Validation: Loss 0.08113 Accuracy 1.00000
Epoch [  8]: Loss 0.08178
Epoch [  8]: Loss 0.07825
Epoch [  8]: Loss 0.07771
Epoch [  8]: Loss 0.07457
Epoch [  8]: Loss 0.06921
Epoch [  8]: Loss 0.06664
Epoch [  8]: Loss 0.06374
Validation: Loss 0.05604 Accuracy 1.00000
Validation: Loss 0.05669 Accuracy 1.00000
Epoch [  9]: Loss 0.05913
Epoch [  9]: Loss 0.05745
Epoch [  9]: Loss 0.05215
Epoch [  9]: Loss 0.05102
Epoch [  9]: Loss 0.05000
Epoch [  9]: Loss 0.04797
Epoch [  9]: Loss 0.03756
Validation: Loss 0.04138 Accuracy 1.00000
Validation: Loss 0.04184 Accuracy 1.00000
Epoch [ 10]: Loss 0.04689
Epoch [ 10]: Loss 0.04123
Epoch [ 10]: Loss 0.03834
Epoch [ 10]: Loss 0.04142
Epoch [ 10]: Loss 0.03869
Epoch [ 10]: Loss 0.03502
Epoch [ 10]: Loss 0.03346
Validation: Loss 0.03337 Accuracy 1.00000
Validation: Loss 0.03374 Accuracy 1.00000
Epoch [ 11]: Loss 0.03420
Epoch [ 11]: Loss 0.03600
Epoch [ 11]: Loss 0.03172
Epoch [ 11]: Loss 0.03482
Epoch [ 11]: Loss 0.03199
Epoch [ 11]: Loss 0.02977
Epoch [ 11]: Loss 0.02810
Validation: Loss 0.02824 Accuracy 1.00000
Validation: Loss 0.02857 Accuracy 1.00000
Epoch [ 12]: Loss 0.03090
Epoch [ 12]: Loss 0.03075
Epoch [ 12]: Loss 0.02543
Epoch [ 12]: Loss 0.02789
Epoch [ 12]: Loss 0.02710
Epoch [ 12]: Loss 0.02684
Epoch [ 12]: Loss 0.02873
Validation: Loss 0.02454 Accuracy 1.00000
Validation: Loss 0.02484 Accuracy 1.00000
Epoch [ 13]: Loss 0.02604
Epoch [ 13]: Loss 0.02703
Epoch [ 13]: Loss 0.02407
Epoch [ 13]: Loss 0.02322
Epoch [ 13]: Loss 0.02484
Epoch [ 13]: Loss 0.02322
Epoch [ 13]: Loss 0.02303
Validation: Loss 0.02167 Accuracy 1.00000
Validation: Loss 0.02193 Accuracy 1.00000
Epoch [ 14]: Loss 0.02342
Epoch [ 14]: Loss 0.02245
Epoch [ 14]: Loss 0.02209
Epoch [ 14]: Loss 0.02118
Epoch [ 14]: Loss 0.02029
Epoch [ 14]: Loss 0.02234
Epoch [ 14]: Loss 0.02059
Validation: Loss 0.01936 Accuracy 1.00000
Validation: Loss 0.01961 Accuracy 1.00000
Epoch [ 15]: Loss 0.02060
Epoch [ 15]: Loss 0.01960
Epoch [ 15]: Loss 0.02051
Epoch [ 15]: Loss 0.01953
Epoch [ 15]: Loss 0.01856
Epoch [ 15]: Loss 0.01947
Epoch [ 15]: Loss 0.01854
Validation: Loss 0.01745 Accuracy 1.00000
Validation: Loss 0.01768 Accuracy 1.00000
Epoch [ 16]: Loss 0.01760
Epoch [ 16]: Loss 0.01850
Epoch [ 16]: Loss 0.01906
Epoch [ 16]: Loss 0.01753
Epoch [ 16]: Loss 0.01791
Epoch [ 16]: Loss 0.01596
Epoch [ 16]: Loss 0.01871
Validation: Loss 0.01584 Accuracy 1.00000
Validation: Loss 0.01605 Accuracy 1.00000
Epoch [ 17]: Loss 0.01660
Epoch [ 17]: Loss 0.01579
Epoch [ 17]: Loss 0.01721
Epoch [ 17]: Loss 0.01626
Epoch [ 17]: Loss 0.01532
Epoch [ 17]: Loss 0.01556
Epoch [ 17]: Loss 0.01824
Validation: Loss 0.01446 Accuracy 1.00000
Validation: Loss 0.01466 Accuracy 1.00000
Epoch [ 18]: Loss 0.01460
Epoch [ 18]: Loss 0.01428
Epoch [ 18]: Loss 0.01620
Epoch [ 18]: Loss 0.01573
Epoch [ 18]: Loss 0.01485
Epoch [ 18]: Loss 0.01334
Epoch [ 18]: Loss 0.01517
Validation: Loss 0.01328 Accuracy 1.00000
Validation: Loss 0.01347 Accuracy 1.00000
Epoch [ 19]: Loss 0.01474
Epoch [ 19]: Loss 0.01405
Epoch [ 19]: Loss 0.01337
Epoch [ 19]: Loss 0.01412
Epoch [ 19]: Loss 0.01370
Epoch [ 19]: Loss 0.01278
Epoch [ 19]: Loss 0.01101
Validation: Loss 0.01226 Accuracy 1.00000
Validation: Loss 0.01244 Accuracy 1.00000
Epoch [ 20]: Loss 0.01337
Epoch [ 20]: Loss 0.01250
Epoch [ 20]: Loss 0.01249
Epoch [ 20]: Loss 0.01186
Epoch [ 20]: Loss 0.01282
Epoch [ 20]: Loss 0.01316
Epoch [ 20]: Loss 0.01161
Validation: Loss 0.01137 Accuracy 1.00000
Validation: Loss 0.01154 Accuracy 1.00000
Epoch [ 21]: Loss 0.01171
Epoch [ 21]: Loss 0.01183
Epoch [ 21]: Loss 0.01237
Epoch [ 21]: Loss 0.01192
Epoch [ 21]: Loss 0.01191
Epoch [ 21]: Loss 0.01035
Epoch [ 21]: Loss 0.01337
Validation: Loss 0.01055 Accuracy 1.00000
Validation: Loss 0.01070 Accuracy 1.00000
Epoch [ 22]: Loss 0.01153
Epoch [ 22]: Loss 0.01097
Epoch [ 22]: Loss 0.01116
Epoch [ 22]: Loss 0.01062
Epoch [ 22]: Loss 0.01069
Epoch [ 22]: Loss 0.01016
Epoch [ 22]: Loss 0.01153
Validation: Loss 0.00972 Accuracy 1.00000
Validation: Loss 0.00986 Accuracy 1.00000
Epoch [ 23]: Loss 0.01012
Epoch [ 23]: Loss 0.01085
Epoch [ 23]: Loss 0.00939
Epoch [ 23]: Loss 0.00972
Epoch [ 23]: Loss 0.00931
Epoch [ 23]: Loss 0.01062
Epoch [ 23]: Loss 0.00908
Validation: Loss 0.00880 Accuracy 1.00000
Validation: Loss 0.00893 Accuracy 1.00000
Epoch [ 24]: Loss 0.00890
Epoch [ 24]: Loss 0.01011
Epoch [ 24]: Loss 0.00911
Epoch [ 24]: Loss 0.00855
Epoch [ 24]: Loss 0.00827
Epoch [ 24]: Loss 0.00889
Epoch [ 24]: Loss 0.00833
Validation: Loss 0.00784 Accuracy 1.00000
Validation: Loss 0.00794 Accuracy 1.00000
Epoch [ 25]: Loss 0.00819
Epoch [ 25]: Loss 0.00844
Epoch [ 25]: Loss 0.00795
Epoch [ 25]: Loss 0.00784
Epoch [ 25]: Loss 0.00810
Epoch [ 25]: Loss 0.00725
Epoch [ 25]: Loss 0.00774
Validation: Loss 0.00704 Accuracy 1.00000
Validation: Loss 0.00714 Accuracy 1.00000
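Both runs build `TrainState` with `Adam(0.01f0)` from Optimisers.jl. For intuition, a single Adam update on one array can be sketched in plain Julia. The sketch follows the standard bias-corrected Adam formulas with commonly used defaults (β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8). `adam_step` is a hypothetical helper, and Optimisers.jl's internal bookkeeping may differ in detail.

```julia
# Hypothetical single Adam step (standard formulas with bias-corrected moments)
function adam_step(θ, g, m, v, t; η=0.01f0, β1=0.9f0, β2=0.999f0, ϵ=1.0f-8)
    m = β1 .* m .+ (1 - β1) .* g          # first-moment (mean) estimate
    v = β2 .* v .+ (1 - β2) .* g .^ 2     # second-moment estimate
    m̂ = m ./ (1 - β1^t)                   # bias corrections for step t
    v̂ = v ./ (1 - β2^t)
    θ = θ .- η .* m̂ ./ (sqrt.(v̂) .+ ϵ)
    return θ, m, v
end

θ0 = Float32[1.0, -2.0, 0.5]
g = Float32[0.3, -4.0, 0.0]
θ1, m1, v1 = adam_step(θ0, g, zero(θ0), zero(θ0), 1)
θ1   # ≈ θ0 .- 0.01f0 .* sign.(g): the first step moves each nonzero-gradient weight by about η
```

Because the first step divides the bias-corrected mean by the root of the bias-corrected second moment, its size is roughly η regardless of the gradient's magnitude.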

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that `main` already transfers the parameters and states to the CPU before returning them. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the parameters and states back:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
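Since any serialization library works, here is a minimal sketch using Julia's standard-library `Serialization` instead of JLD2, with hypothetical stand-in parameters and states. The `Serialization` format is not guaranteed to be readable across different Julia versions, so JLD2 is usually the safer choice for long-term storage.

```julia
using Serialization

# Hypothetical stand-ins for the trained parameters and states returned by `main`
ps_demo = (classifier=(weight=Float32[0.5 -0.25], bias=Float32[0.1]),)
st_demo = (classifier=NamedTuple(),)

path = joinpath(mktempdir(), "trained_model.jls")
serialize(path, (ps=ps_demo, st=st_demo))   # write one NamedTuple to disk
restored = deserialize(path)

restored.ps == ps_demo   # true
```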

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.4+0

Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)

This page was generated using Literate.jl.