Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
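
As a quick sanity check (not part of the original tutorial), we can peek at one training batch to confirm the 2 × seq_len × batch_size layout described above:

julia
# Fetch one minibatch and inspect its dimensions.
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x), size(y)  # expected: ((2, 50, 128), (128,))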

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch, but rather use the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We can use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model ourselves. But let's still do it for the sake of illustration.
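
For reference, here is a minimal sketch of that Recurrence-based alternative (assuming the same dimensions; this model is not used in the rest of the tutorial):

julia
# Sketch: the same classifier built purely from stock blocks. By default,
# `Recurrence` runs the cell over the time dimension and returns only the
# last output, so the result matches the manual implementation below.
SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims) = Chain(
    Recurrence(LSTMCell(in_dims => hidden_dims)),
    Dense(hidden_dims => out_dims, sigmoid), vec)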

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we need to run the sequence through the LSTM cell.
    # The first call to the LSTM cell will create the initial hidden state.
    # Note that the parameters and states are automatically populated into a field
    # called `lstm_cell`. We use `eachslice` to get the elements in the sequence
    # without copying, and `Iterators.peel` to split out the first element for
    # LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
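
Here is a minimal sketch (not part of the original tutorial) of exercising this forward pass once, using the input shape from the data section:

julia
# Set up the model and run a single forward pass on a small dummy batch.
rng = Xoshiro(0)
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 50, 4)
ŷ, st_new = model(x, ps, st)
size(ŷ)  # (4,): one prediction per sequence in the batch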

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy. (A sketch of the logit variant follows the code below.)

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
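
If we wanted the more numerically stable variant mentioned above, a sketch (assuming we also drop the `sigmoid` from the classifier head so the model emits raw logits):

julia
# Sketch: let the loss apply the sigmoid internally for numerical stability.
# The final `Dense` layer must then omit the `sigmoid` activation.
const logit_lossfn = BinaryCrossEntropyLoss(; logits=Val(true))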

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62228
Epoch [  1]: Loss 0.58571
Epoch [  1]: Loss 0.56973
Epoch [  1]: Loss 0.54320
Epoch [  1]: Loss 0.52264
Epoch [  1]: Loss 0.49863
Epoch [  1]: Loss 0.50773
Validation: Loss 0.46027 Accuracy 1.00000
Validation: Loss 0.47027 Accuracy 1.00000
Epoch [  2]: Loss 0.47163
Epoch [  2]: Loss 0.45379
Epoch [  2]: Loss 0.44359
Epoch [  2]: Loss 0.43374
Epoch [  2]: Loss 0.39947
Epoch [  2]: Loss 0.40825
Epoch [  2]: Loss 0.38567
Validation: Loss 0.36253 Accuracy 1.00000
Validation: Loss 0.37395 Accuracy 1.00000
Epoch [  3]: Loss 0.37715
Epoch [  3]: Loss 0.35866
Epoch [  3]: Loss 0.35465
Epoch [  3]: Loss 0.33274
Epoch [  3]: Loss 0.32503
Epoch [  3]: Loss 0.31172
Epoch [  3]: Loss 0.28375
Validation: Loss 0.27681 Accuracy 1.00000
Validation: Loss 0.28966 Accuracy 1.00000
Epoch [  4]: Loss 0.29246
Epoch [  4]: Loss 0.27398
Epoch [  4]: Loss 0.25867
Epoch [  4]: Loss 0.26764
Epoch [  4]: Loss 0.25732
Epoch [  4]: Loss 0.23090
Epoch [  4]: Loss 0.23431
Validation: Loss 0.20692 Accuracy 1.00000
Validation: Loss 0.22037 Accuracy 1.00000
Epoch [  5]: Loss 0.22287
Epoch [  5]: Loss 0.20974
Epoch [  5]: Loss 0.20163
Epoch [  5]: Loss 0.19467
Epoch [  5]: Loss 0.18014
Epoch [  5]: Loss 0.18608
Epoch [  5]: Loss 0.17436
Validation: Loss 0.15213 Accuracy 1.00000
Validation: Loss 0.16489 Accuracy 1.00000
Epoch [  6]: Loss 0.16569
Epoch [  6]: Loss 0.15666
Epoch [  6]: Loss 0.15815
Epoch [  6]: Loss 0.13737
Epoch [  6]: Loss 0.13432
Epoch [  6]: Loss 0.14018
Epoch [  6]: Loss 0.11618
Validation: Loss 0.11088 Accuracy 1.00000
Validation: Loss 0.12165 Accuracy 1.00000
Epoch [  7]: Loss 0.12116
Epoch [  7]: Loss 0.12185
Epoch [  7]: Loss 0.11245
Epoch [  7]: Loss 0.10089
Epoch [  7]: Loss 0.10328
Epoch [  7]: Loss 0.09578
Epoch [  7]: Loss 0.07751
Validation: Loss 0.07946 Accuracy 1.00000
Validation: Loss 0.08757 Accuracy 1.00000
Epoch [  8]: Loss 0.08631
Epoch [  8]: Loss 0.07977
Epoch [  8]: Loss 0.07880
Epoch [  8]: Loss 0.07758
Epoch [  8]: Loss 0.06987
Epoch [  8]: Loss 0.06798
Epoch [  8]: Loss 0.07269
Validation: Loss 0.05564 Accuracy 1.00000
Validation: Loss 0.06111 Accuracy 1.00000
Epoch [  9]: Loss 0.06551
Epoch [  9]: Loss 0.05592
Epoch [  9]: Loss 0.05594
Epoch [  9]: Loss 0.05511
Epoch [  9]: Loss 0.05059
Epoch [  9]: Loss 0.04331
Epoch [  9]: Loss 0.04586
Validation: Loss 0.04106 Accuracy 1.00000
Validation: Loss 0.04490 Accuracy 1.00000
Epoch [ 10]: Loss 0.04347
Epoch [ 10]: Loss 0.04279
Epoch [ 10]: Loss 0.04173
Epoch [ 10]: Loss 0.04170
Epoch [ 10]: Loss 0.03853
Epoch [ 10]: Loss 0.04025
Epoch [ 10]: Loss 0.03299
Validation: Loss 0.03307 Accuracy 1.00000
Validation: Loss 0.03621 Accuracy 1.00000
Epoch [ 11]: Loss 0.03841
Epoch [ 11]: Loss 0.03623
Epoch [ 11]: Loss 0.03336
Epoch [ 11]: Loss 0.03336
Epoch [ 11]: Loss 0.03205
Epoch [ 11]: Loss 0.03028
Epoch [ 11]: Loss 0.03195
Validation: Loss 0.02795 Accuracy 1.00000
Validation: Loss 0.03069 Accuracy 1.00000
Epoch [ 12]: Loss 0.03101
Epoch [ 12]: Loss 0.03167
Epoch [ 12]: Loss 0.02849
Epoch [ 12]: Loss 0.02917
Epoch [ 12]: Loss 0.02664
Epoch [ 12]: Loss 0.02826
Epoch [ 12]: Loss 0.02401
Validation: Loss 0.02426 Accuracy 1.00000
Validation: Loss 0.02670 Accuracy 1.00000
Epoch [ 13]: Loss 0.02837
Epoch [ 13]: Loss 0.02402
Epoch [ 13]: Loss 0.02581
Epoch [ 13]: Loss 0.02668
Epoch [ 13]: Loss 0.02360
Epoch [ 13]: Loss 0.02448
Epoch [ 13]: Loss 0.02282
Validation: Loss 0.02141 Accuracy 1.00000
Validation: Loss 0.02362 Accuracy 1.00000
Epoch [ 14]: Loss 0.02389
Epoch [ 14]: Loss 0.02247
Epoch [ 14]: Loss 0.02272
Epoch [ 14]: Loss 0.02337
Epoch [ 14]: Loss 0.02260
Epoch [ 14]: Loss 0.02167
Epoch [ 14]: Loss 0.01731
Validation: Loss 0.01914 Accuracy 1.00000
Validation: Loss 0.02116 Accuracy 1.00000
Epoch [ 15]: Loss 0.02064
Epoch [ 15]: Loss 0.02071
Epoch [ 15]: Loss 0.02003
Epoch [ 15]: Loss 0.02031
Epoch [ 15]: Loss 0.02182
Epoch [ 15]: Loss 0.01812
Epoch [ 15]: Loss 0.02050
Validation: Loss 0.01726 Accuracy 1.00000
Validation: Loss 0.01913 Accuracy 1.00000
Epoch [ 16]: Loss 0.01952
Epoch [ 16]: Loss 0.01856
Epoch [ 16]: Loss 0.01918
Epoch [ 16]: Loss 0.01822
Epoch [ 16]: Loss 0.01782
Epoch [ 16]: Loss 0.01707
Epoch [ 16]: Loss 0.01800
Validation: Loss 0.01566 Accuracy 1.00000
Validation: Loss 0.01739 Accuracy 1.00000
Epoch [ 17]: Loss 0.01710
Epoch [ 17]: Loss 0.01705
Epoch [ 17]: Loss 0.01684
Epoch [ 17]: Loss 0.01678
Epoch [ 17]: Loss 0.01609
Epoch [ 17]: Loss 0.01666
Epoch [ 17]: Loss 0.01646
Validation: Loss 0.01429 Accuracy 1.00000
Validation: Loss 0.01591 Accuracy 1.00000
Epoch [ 18]: Loss 0.01580
Epoch [ 18]: Loss 0.01545
Epoch [ 18]: Loss 0.01527
Epoch [ 18]: Loss 0.01554
Epoch [ 18]: Loss 0.01417
Epoch [ 18]: Loss 0.01573
Epoch [ 18]: Loss 0.01572
Validation: Loss 0.01311 Accuracy 1.00000
Validation: Loss 0.01463 Accuracy 1.00000
Epoch [ 19]: Loss 0.01424
Epoch [ 19]: Loss 0.01548
Epoch [ 19]: Loss 0.01488
Epoch [ 19]: Loss 0.01334
Epoch [ 19]: Loss 0.01387
Epoch [ 19]: Loss 0.01249
Epoch [ 19]: Loss 0.01609
Validation: Loss 0.01209 Accuracy 1.00000
Validation: Loss 0.01350 Accuracy 1.00000
Epoch [ 20]: Loss 0.01353
Epoch [ 20]: Loss 0.01326
Epoch [ 20]: Loss 0.01413
Epoch [ 20]: Loss 0.01271
Epoch [ 20]: Loss 0.01257
Epoch [ 20]: Loss 0.01241
Epoch [ 20]: Loss 0.01193
Validation: Loss 0.01117 Accuracy 1.00000
Validation: Loss 0.01248 Accuracy 1.00000
Epoch [ 21]: Loss 0.01276
Epoch [ 21]: Loss 0.01187
Epoch [ 21]: Loss 0.01248
Epoch [ 21]: Loss 0.01212
Epoch [ 21]: Loss 0.01186
Epoch [ 21]: Loss 0.01194
Epoch [ 21]: Loss 0.00963
Validation: Loss 0.01031 Accuracy 1.00000
Validation: Loss 0.01153 Accuracy 1.00000
Epoch [ 22]: Loss 0.01066
Epoch [ 22]: Loss 0.01131
Epoch [ 22]: Loss 0.01095
Epoch [ 22]: Loss 0.01103
Epoch [ 22]: Loss 0.01155
Epoch [ 22]: Loss 0.01138
Epoch [ 22]: Loss 0.01006
Validation: Loss 0.00943 Accuracy 1.00000
Validation: Loss 0.01053 Accuracy 1.00000
Epoch [ 23]: Loss 0.00964
Epoch [ 23]: Loss 0.01152
Epoch [ 23]: Loss 0.00906
Epoch [ 23]: Loss 0.01027
Epoch [ 23]: Loss 0.00973
Epoch [ 23]: Loss 0.01020
Epoch [ 23]: Loss 0.00980
Validation: Loss 0.00844 Accuracy 1.00000
Validation: Loss 0.00939 Accuracy 1.00000
Epoch [ 24]: Loss 0.01021
Epoch [ 24]: Loss 0.00935
Epoch [ 24]: Loss 0.00970
Epoch [ 24]: Loss 0.00858
Epoch [ 24]: Loss 0.00770
Epoch [ 24]: Loss 0.00805
Epoch [ 24]: Loss 0.00922
Validation: Loss 0.00750 Accuracy 1.00000
Validation: Loss 0.00831 Accuracy 1.00000
Epoch [ 25]: Loss 0.00772
Epoch [ 25]: Loss 0.00850
Epoch [ 25]: Loss 0.00815
Epoch [ 25]: Loss 0.00818
Epoch [ 25]: Loss 0.00772
Epoch [ 25]: Loss 0.00719
Epoch [ 25]: Loss 0.00914
Validation: Loss 0.00682 Accuracy 1.00000
Validation: Loss 0.00753 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.61509
Epoch [  1]: Loss 0.60108
Epoch [  1]: Loss 0.55682
Epoch [  1]: Loss 0.54901
Epoch [  1]: Loss 0.49920
Epoch [  1]: Loss 0.51514
Epoch [  1]: Loss 0.51064
Validation: Loss 0.46906 Accuracy 1.00000
Validation: Loss 0.46767 Accuracy 1.00000
Epoch [  2]: Loss 0.47772
Epoch [  2]: Loss 0.45928
Epoch [  2]: Loss 0.43215
Epoch [  2]: Loss 0.42920
Epoch [  2]: Loss 0.41722
Epoch [  2]: Loss 0.39954
Epoch [  2]: Loss 0.38827
Validation: Loss 0.37399 Accuracy 1.00000
Validation: Loss 0.37236 Accuracy 1.00000
Epoch [  3]: Loss 0.36867
Epoch [  3]: Loss 0.35577
Epoch [  3]: Loss 0.34842
Epoch [  3]: Loss 0.34695
Epoch [  3]: Loss 0.32450
Epoch [  3]: Loss 0.31801
Epoch [  3]: Loss 0.32518
Validation: Loss 0.28976 Accuracy 1.00000
Validation: Loss 0.28805 Accuracy 1.00000
Epoch [  4]: Loss 0.28978
Epoch [  4]: Loss 0.27647
Epoch [  4]: Loss 0.26276
Epoch [  4]: Loss 0.25823
Epoch [  4]: Loss 0.25967
Epoch [  4]: Loss 0.24144
Epoch [  4]: Loss 0.25356
Validation: Loss 0.22011 Accuracy 1.00000
Validation: Loss 0.21852 Accuracy 1.00000
Epoch [  5]: Loss 0.22836
Epoch [  5]: Loss 0.21406
Epoch [  5]: Loss 0.20441
Epoch [  5]: Loss 0.19362
Epoch [  5]: Loss 0.18782
Epoch [  5]: Loss 0.17750
Epoch [  5]: Loss 0.17369
Validation: Loss 0.16410 Accuracy 1.00000
Validation: Loss 0.16279 Accuracy 1.00000
Epoch [  6]: Loss 0.17219
Epoch [  6]: Loss 0.15211
Epoch [  6]: Loss 0.15563
Epoch [  6]: Loss 0.14199
Epoch [  6]: Loss 0.14171
Epoch [  6]: Loss 0.13298
Epoch [  6]: Loss 0.12676
Validation: Loss 0.12054 Accuracy 1.00000
Validation: Loss 0.11958 Accuracy 1.00000
Epoch [  7]: Loss 0.12465
Epoch [  7]: Loss 0.11327
Epoch [  7]: Loss 0.10910
Epoch [  7]: Loss 0.09648
Epoch [  7]: Loss 0.10936
Epoch [  7]: Loss 0.10177
Epoch [  7]: Loss 0.08885
Validation: Loss 0.08635 Accuracy 1.00000
Validation: Loss 0.08568 Accuracy 1.00000
Epoch [  8]: Loss 0.09532
Epoch [  8]: Loss 0.08611
Epoch [  8]: Loss 0.07121
Epoch [  8]: Loss 0.07892
Epoch [  8]: Loss 0.06568
Epoch [  8]: Loss 0.06907
Epoch [  8]: Loss 0.05675
Validation: Loss 0.05999 Accuracy 1.00000
Validation: Loss 0.05956 Accuracy 1.00000
Epoch [  9]: Loss 0.05736
Epoch [  9]: Loss 0.05612
Epoch [  9]: Loss 0.05023
Epoch [  9]: Loss 0.05699
Epoch [  9]: Loss 0.05182
Epoch [  9]: Loss 0.05200
Epoch [  9]: Loss 0.04850
Validation: Loss 0.04473 Accuracy 1.00000
Validation: Loss 0.04442 Accuracy 1.00000
Epoch [ 10]: Loss 0.04791
Epoch [ 10]: Loss 0.04287
Epoch [ 10]: Loss 0.04373
Epoch [ 10]: Loss 0.03918
Epoch [ 10]: Loss 0.03873
Epoch [ 10]: Loss 0.03829
Epoch [ 10]: Loss 0.04021
Validation: Loss 0.03626 Accuracy 1.00000
Validation: Loss 0.03600 Accuracy 1.00000
Epoch [ 11]: Loss 0.03885
Epoch [ 11]: Loss 0.03412
Epoch [ 11]: Loss 0.03265
Epoch [ 11]: Loss 0.03425
Epoch [ 11]: Loss 0.03439
Epoch [ 11]: Loss 0.03421
Epoch [ 11]: Loss 0.02723
Validation: Loss 0.03079 Accuracy 1.00000
Validation: Loss 0.03057 Accuracy 1.00000
Epoch [ 12]: Loss 0.03052
Epoch [ 12]: Loss 0.03016
Epoch [ 12]: Loss 0.02903
Epoch [ 12]: Loss 0.02899
Epoch [ 12]: Loss 0.03049
Epoch [ 12]: Loss 0.02760
Epoch [ 12]: Loss 0.03159
Validation: Loss 0.02682 Accuracy 1.00000
Validation: Loss 0.02662 Accuracy 1.00000
Epoch [ 13]: Loss 0.03079
Epoch [ 13]: Loss 0.02633
Epoch [ 13]: Loss 0.02706
Epoch [ 13]: Loss 0.02541
Epoch [ 13]: Loss 0.02263
Epoch [ 13]: Loss 0.02406
Epoch [ 13]: Loss 0.02344
Validation: Loss 0.02371 Accuracy 1.00000
Validation: Loss 0.02353 Accuracy 1.00000
Epoch [ 14]: Loss 0.02772
Epoch [ 14]: Loss 0.02266
Epoch [ 14]: Loss 0.02230
Epoch [ 14]: Loss 0.02132
Epoch [ 14]: Loss 0.02350
Epoch [ 14]: Loss 0.02066
Epoch [ 14]: Loss 0.02339
Validation: Loss 0.02123 Accuracy 1.00000
Validation: Loss 0.02107 Accuracy 1.00000
Epoch [ 15]: Loss 0.02165
Epoch [ 15]: Loss 0.02132
Epoch [ 15]: Loss 0.02082
Epoch [ 15]: Loss 0.02090
Epoch [ 15]: Loss 0.01924
Epoch [ 15]: Loss 0.02049
Epoch [ 15]: Loss 0.01966
Validation: Loss 0.01918 Accuracy 1.00000
Validation: Loss 0.01903 Accuracy 1.00000
Epoch [ 16]: Loss 0.02048
Epoch [ 16]: Loss 0.01983
Epoch [ 16]: Loss 0.01884
Epoch [ 16]: Loss 0.01768
Epoch [ 16]: Loss 0.01726
Epoch [ 16]: Loss 0.01914
Epoch [ 16]: Loss 0.01576
Validation: Loss 0.01743 Accuracy 1.00000
Validation: Loss 0.01729 Accuracy 1.00000
Epoch [ 17]: Loss 0.01949
Epoch [ 17]: Loss 0.01637
Epoch [ 17]: Loss 0.01750
Epoch [ 17]: Loss 0.01575
Epoch [ 17]: Loss 0.01605
Epoch [ 17]: Loss 0.01714
Epoch [ 17]: Loss 0.01758
Validation: Loss 0.01592 Accuracy 1.00000
Validation: Loss 0.01579 Accuracy 1.00000
Epoch [ 18]: Loss 0.01689
Epoch [ 18]: Loss 0.01667
Epoch [ 18]: Loss 0.01582
Epoch [ 18]: Loss 0.01393
Epoch [ 18]: Loss 0.01451
Epoch [ 18]: Loss 0.01568
Epoch [ 18]: Loss 0.01611
Validation: Loss 0.01456 Accuracy 1.00000
Validation: Loss 0.01444 Accuracy 1.00000
Epoch [ 19]: Loss 0.01441
Epoch [ 19]: Loss 0.01467
Epoch [ 19]: Loss 0.01491
Epoch [ 19]: Loss 0.01301
Epoch [ 19]: Loss 0.01454
Epoch [ 19]: Loss 0.01463
Epoch [ 19]: Loss 0.01197
Validation: Loss 0.01328 Accuracy 1.00000
Validation: Loss 0.01318 Accuracy 1.00000
Epoch [ 20]: Loss 0.01368
Epoch [ 20]: Loss 0.01418
Epoch [ 20]: Loss 0.01122
Epoch [ 20]: Loss 0.01274
Epoch [ 20]: Loss 0.01410
Epoch [ 20]: Loss 0.01211
Epoch [ 20]: Loss 0.01230
Validation: Loss 0.01197 Accuracy 1.00000
Validation: Loss 0.01188 Accuracy 1.00000
Epoch [ 21]: Loss 0.01184
Epoch [ 21]: Loss 0.01250
Epoch [ 21]: Loss 0.01192
Epoch [ 21]: Loss 0.01045
Epoch [ 21]: Loss 0.01207
Epoch [ 21]: Loss 0.01107
Epoch [ 21]: Loss 0.01045
Validation: Loss 0.01055 Accuracy 1.00000
Validation: Loss 0.01047 Accuracy 1.00000
Epoch [ 22]: Loss 0.01058
Epoch [ 22]: Loss 0.01037
Epoch [ 22]: Loss 0.01077
Epoch [ 22]: Loss 0.00986
Epoch [ 22]: Loss 0.01028
Epoch [ 22]: Loss 0.00908
Epoch [ 22]: Loss 0.01101
Validation: Loss 0.00932 Accuracy 1.00000
Validation: Loss 0.00925 Accuracy 1.00000
Epoch [ 23]: Loss 0.00994
Epoch [ 23]: Loss 0.00873
Epoch [ 23]: Loss 0.00921
Epoch [ 23]: Loss 0.00933
Epoch [ 23]: Loss 0.00895
Epoch [ 23]: Loss 0.00843
Epoch [ 23]: Loss 0.00866
Validation: Loss 0.00844 Accuracy 1.00000
Validation: Loss 0.00838 Accuracy 1.00000
Epoch [ 24]: Loss 0.00869
Epoch [ 24]: Loss 0.00867
Epoch [ 24]: Loss 0.00778
Epoch [ 24]: Loss 0.00815
Epoch [ 24]: Loss 0.00829
Epoch [ 24]: Loss 0.00843
Epoch [ 24]: Loss 0.00738
Validation: Loss 0.00781 Accuracy 1.00000
Validation: Loss 0.00775 Accuracy 1.00000
Epoch [ 25]: Loss 0.00832
Epoch [ 25]: Loss 0.00760
Epoch [ 25]: Loss 0.00802
Epoch [ 25]: Loss 0.00745
Epoch [ 25]: Loss 0.00780
Epoch [ 25]: Loss 0.00725
Epoch [ 25]: Loss 0.00717
Validation: Loss 0.00732 Accuracy 1.00000
Validation: Loss 0.00726 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself, only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
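
As a quick sketch (not part of the original tutorial), we can run inference with the loaded parameters, assuming the SpiralClassifier definition is still in scope:

julia
# Rebuild the model structure and evaluate it with the restored parameters
# in test mode.
model = SpiralClassifier(2, 8, 1)
x = randn(Xoshiro(0), Float32, 2, 50, 1)  # one dummy input sequence
ŷ, _ = model(x, ps_trained, Lux.testmode(st_trained))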

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.4+0

Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)

This page was generated using Literate.jl.