
Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Precompiling Lux...
   2766.1 ms  ✓ GPUArrays
   1633.5 ms  ✓ MLDataDevices → MLDataDevicesGPUArraysExt
   1753.1 ms  ✓ WeightInitializers → WeightInitializersGPUArraysExt
  59930.9 ms  ✓ CUDA
   5735.1 ms  ✓ CUDA → ChainRulesCoreExt
   5786.6 ms  ✓ CUDA → EnzymeCoreExt
   6076.3 ms  ✓ CUDA → SpecialFunctionsExt
   5698.4 ms  ✓ ArrayInterface → ArrayInterfaceCUDAExt
   5738.0 ms  ✓ MLDataDevices → MLDataDevicesCUDAExt
   5876.0 ms  ✓ WeightInitializers → WeightInitializersCUDAExt
   6086.1 ms  ✓ NNlib → NNlibCUDAExt
   9077.5 ms  ✓ cuDNN
   5694.5 ms  ✓ MLDataDevices → MLDataDevicescuDNNExt
   5901.6 ms  ✓ NNlib → NNlibCUDACUDNNExt
   7804.4 ms  ✓ MLUtils
   2793.8 ms  ✓ MLDataDevices → MLDataDevicesMLUtilsExt
   6155.1 ms  ✓ LuxLib → LuxLibCUDAExt
   6392.0 ms  ✓ LuxLib → LuxLibcuDNNExt
  10777.6 ms  ✓ Lux
   3522.7 ms  ✓ Lux → LuxMLUtilsExt
   4244.8 ms  ✓ Lux → LuxZygoteExt
  21 dependencies successfully precompiled in 122 seconds. 231 already precompiled.
Precompiling LuxCUDA...
   5590.8 ms  ✓ LuxCUDA
  1 dependency successfully precompiled in 6 seconds. 121 already precompiled.
Precompiling JLD2...
  34948.1 ms  ✓ JLD2
  1 dependency successfully precompiled in 35 seconds. 31 already precompiled.

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether the sequence is clockwise or anticlockwise.
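To make the layout concrete, here is a hand-rolled spiral in the same 2 × seq_len arrangement. Note that toy_spiral is a hypothetical helper for intuition only; the tutorial itself uses MLUtils.Datasets.make_spiral.

```julia
using Random

# Hypothetical helper (not part of MLUtils): a noisy 2D spiral laid out the
# same way as the dataset -- row 1 is x, row 2 is y, columns are timesteps.
function toy_spiral(n; noise=0.01f0)
    t = range(0.0f0, 4.0f0 * Float32(π); length=n)  # winding angle
    r = range(0.1f0, 1.0f0; length=n)               # growing radius
    xy = permutedims([r .* cos.(t) r .* sin.(t)])   # 2 × n matrix
    return xy .+ noise .* randn(Float32, 2, n)
end

size(toy_spiral(50))  # (2, 50)
```

Flipping the sign of the angle would wind the spiral the other way, which is exactly the distinction the classifier must learn.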

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; rather, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##230".SpiralClassifier

We could have used the built-in Recurrence(LSTMCell(in_dims => hidden_dims)) block instead of writing the forward pass below, but we define it manually for illustration.

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell` We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
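The slice-and-peel pattern above can be tried with plain Base arrays (a standalone sketch; Base.eachslice stands in here for LuxOps.eachslice):

```julia
# Slice a features × timesteps × batch array along the time dimension and
# peel off the first timestep, exactly as the forward pass does.
x = reshape(Float32.(1:12), 2, 3, 2)   # 2 features, 3 timesteps, batch of 2
slices = eachslice(x; dims=2)          # one 2 × 2 matrix per timestep
x_init, x_rest = Iterators.peel(slices)

size(x_init)             # (2, 2): first timestep for the whole batch
length(collect(x_rest))  # 2 remaining timesteps
```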

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
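For completeness, the Recurrence shortcut mentioned earlier replaces both hand-written loops. Here is a sketch assuming the default ordering, where the input is in_dims × seq_len × batch:

```julia
using Lux, Random

# Stock Lux blocks: Recurrence drives the LSTMCell over the sequence and
# returns the final hidden state, which the Dense head then classifies.
model = Chain(
    Recurrence(LSTMCell(2 => 8)),
    Dense(8 => 1, sigmoid),
)
ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Xoshiro(1), Float32, 2, 50, 3)  # 3 sequences of 50 steps each
ŷ, _ = model(x, ps, Lux.testmode(st))
size(ŷ)  # (1, 3): one probability per sequence
```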

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
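To see what BinaryCrossEntropyLoss computes under the hood, here is the formula applied by hand (bce is a toy re-implementation for intuition only; use the Lux loss in practice):

```julia
# Mean binary cross-entropy over predicted probabilities ŷ and labels y.
# The small ϵ guards against log(0), as library implementations typically do.
bce(ŷ, y; ϵ=1.0f-7) = -sum(@. y * log(ŷ + ϵ) + (1 - y) * log(1 - ŷ + ϵ)) / length(ŷ)

ŷ = [0.9f0, 0.2f0, 0.8f0]  # predicted probabilities
y = [1.0f0, 0.0f0, 1.0f0]  # ground-truth labels
bce(ŷ, y)                  # ≈ 0.1839
```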

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.61355
Epoch [  1]: Loss 0.60122
Epoch [  1]: Loss 0.55937
Epoch [  1]: Loss 0.53879
Epoch [  1]: Loss 0.51808
Epoch [  1]: Loss 0.51605
Epoch [  1]: Loss 0.47831
Validation: Loss 0.46661 Accuracy 1.00000
Validation: Loss 0.45916 Accuracy 1.00000
Epoch [  2]: Loss 0.46628
Epoch [  2]: Loss 0.44990
Epoch [  2]: Loss 0.44128
Epoch [  2]: Loss 0.43255
Epoch [  2]: Loss 0.40714
Epoch [  2]: Loss 0.40766
Epoch [  2]: Loss 0.39937
Validation: Loss 0.36971 Accuracy 1.00000
Validation: Loss 0.36096 Accuracy 1.00000
Epoch [  3]: Loss 0.37687
Epoch [  3]: Loss 0.35752
Epoch [  3]: Loss 0.34448
Epoch [  3]: Loss 0.34105
Epoch [  3]: Loss 0.31895
Epoch [  3]: Loss 0.31300
Epoch [  3]: Loss 0.30025
Validation: Loss 0.28482 Accuracy 1.00000
Validation: Loss 0.27534 Accuracy 1.00000
Epoch [  4]: Loss 0.29736
Epoch [  4]: Loss 0.27661
Epoch [  4]: Loss 0.27365
Epoch [  4]: Loss 0.25206
Epoch [  4]: Loss 0.24862
Epoch [  4]: Loss 0.23127
Epoch [  4]: Loss 0.20953
Validation: Loss 0.21468 Accuracy 1.00000
Validation: Loss 0.20514 Accuracy 1.00000
Epoch [  5]: Loss 0.22380
Epoch [  5]: Loss 0.20012
Epoch [  5]: Loss 0.19739
Epoch [  5]: Loss 0.20192
Epoch [  5]: Loss 0.19808
Epoch [  5]: Loss 0.16450
Epoch [  5]: Loss 0.16835
Validation: Loss 0.15907 Accuracy 1.00000
Validation: Loss 0.15023 Accuracy 1.00000
Epoch [  6]: Loss 0.16398
Epoch [  6]: Loss 0.15826
Epoch [  6]: Loss 0.16086
Epoch [  6]: Loss 0.14523
Epoch [  6]: Loss 0.12280
Epoch [  6]: Loss 0.12723
Epoch [  6]: Loss 0.12907
Validation: Loss 0.11640 Accuracy 1.00000
Validation: Loss 0.10915 Accuracy 1.00000
Epoch [  7]: Loss 0.12678
Epoch [  7]: Loss 0.11296
Epoch [  7]: Loss 0.10677
Epoch [  7]: Loss 0.09760
Epoch [  7]: Loss 0.10082
Epoch [  7]: Loss 0.09412
Epoch [  7]: Loss 0.09540
Validation: Loss 0.08320 Accuracy 1.00000
Validation: Loss 0.07790 Accuracy 1.00000
Epoch [  8]: Loss 0.08385
Epoch [  8]: Loss 0.08116
Epoch [  8]: Loss 0.07856
Epoch [  8]: Loss 0.07251
Epoch [  8]: Loss 0.06865
Epoch [  8]: Loss 0.06813
Epoch [  8]: Loss 0.05979
Validation: Loss 0.05788 Accuracy 1.00000
Validation: Loss 0.05440 Accuracy 1.00000
Epoch [  9]: Loss 0.05526
Epoch [  9]: Loss 0.05986
Epoch [  9]: Loss 0.05482
Epoch [  9]: Loss 0.05024
Epoch [  9]: Loss 0.05068
Epoch [  9]: Loss 0.04910
Epoch [  9]: Loss 0.03947
Validation: Loss 0.04309 Accuracy 1.00000
Validation: Loss 0.04062 Accuracy 1.00000
Epoch [ 10]: Loss 0.04734
Epoch [ 10]: Loss 0.04295
Epoch [ 10]: Loss 0.04133
Epoch [ 10]: Loss 0.03836
Epoch [ 10]: Loss 0.03951
Epoch [ 10]: Loss 0.03558
Epoch [ 10]: Loss 0.03752
Validation: Loss 0.03489 Accuracy 1.00000
Validation: Loss 0.03285 Accuracy 1.00000
Epoch [ 11]: Loss 0.03542
Epoch [ 11]: Loss 0.03343
Epoch [ 11]: Loss 0.03485
Epoch [ 11]: Loss 0.03239
Epoch [ 11]: Loss 0.03243
Epoch [ 11]: Loss 0.03304
Epoch [ 11]: Loss 0.03280
Validation: Loss 0.02961 Accuracy 1.00000
Validation: Loss 0.02784 Accuracy 1.00000
Epoch [ 12]: Loss 0.03061
Epoch [ 12]: Loss 0.02998
Epoch [ 12]: Loss 0.02837
Epoch [ 12]: Loss 0.02854
Epoch [ 12]: Loss 0.02796
Epoch [ 12]: Loss 0.02674
Epoch [ 12]: Loss 0.03127
Validation: Loss 0.02574 Accuracy 1.00000
Validation: Loss 0.02417 Accuracy 1.00000
Epoch [ 13]: Loss 0.02820
Epoch [ 13]: Loss 0.02770
Epoch [ 13]: Loss 0.02764
Epoch [ 13]: Loss 0.02521
Epoch [ 13]: Loss 0.02223
Epoch [ 13]: Loss 0.02183
Epoch [ 13]: Loss 0.02022
Validation: Loss 0.02272 Accuracy 1.00000
Validation: Loss 0.02131 Accuracy 1.00000
Epoch [ 14]: Loss 0.02412
Epoch [ 14]: Loss 0.02214
Epoch [ 14]: Loss 0.02296
Epoch [ 14]: Loss 0.02144
Epoch [ 14]: Loss 0.02131
Epoch [ 14]: Loss 0.02155
Epoch [ 14]: Loss 0.02549
Validation: Loss 0.02034 Accuracy 1.00000
Validation: Loss 0.01905 Accuracy 1.00000
Epoch [ 15]: Loss 0.02016
Epoch [ 15]: Loss 0.02038
Epoch [ 15]: Loss 0.02045
Epoch [ 15]: Loss 0.02121
Epoch [ 15]: Loss 0.01968
Epoch [ 15]: Loss 0.01939
Epoch [ 15]: Loss 0.01760
Validation: Loss 0.01834 Accuracy 1.00000
Validation: Loss 0.01716 Accuracy 1.00000
Epoch [ 16]: Loss 0.01903
Epoch [ 16]: Loss 0.01876
Epoch [ 16]: Loss 0.01733
Epoch [ 16]: Loss 0.01890
Epoch [ 16]: Loss 0.01801
Epoch [ 16]: Loss 0.01808
Epoch [ 16]: Loss 0.01462
Validation: Loss 0.01666 Accuracy 1.00000
Validation: Loss 0.01556 Accuracy 1.00000
Epoch [ 17]: Loss 0.01729
Epoch [ 17]: Loss 0.01612
Epoch [ 17]: Loss 0.01670
Epoch [ 17]: Loss 0.01664
Epoch [ 17]: Loss 0.01657
Epoch [ 17]: Loss 0.01651
Epoch [ 17]: Loss 0.01512
Validation: Loss 0.01522 Accuracy 1.00000
Validation: Loss 0.01420 Accuracy 1.00000
Epoch [ 18]: Loss 0.01596
Epoch [ 18]: Loss 0.01561
Epoch [ 18]: Loss 0.01616
Epoch [ 18]: Loss 0.01452
Epoch [ 18]: Loss 0.01505
Epoch [ 18]: Loss 0.01361
Epoch [ 18]: Loss 0.01611
Validation: Loss 0.01397 Accuracy 1.00000
Validation: Loss 0.01302 Accuracy 1.00000
Epoch [ 19]: Loss 0.01425
Epoch [ 19]: Loss 0.01503
Epoch [ 19]: Loss 0.01391
Epoch [ 19]: Loss 0.01348
Epoch [ 19]: Loss 0.01437
Epoch [ 19]: Loss 0.01224
Epoch [ 19]: Loss 0.01583
Validation: Loss 0.01285 Accuracy 1.00000
Validation: Loss 0.01196 Accuracy 1.00000
Epoch [ 20]: Loss 0.01350
Epoch [ 20]: Loss 0.01295
Epoch [ 20]: Loss 0.01291
Epoch [ 20]: Loss 0.01313
Epoch [ 20]: Loss 0.01210
Epoch [ 20]: Loss 0.01280
Epoch [ 20]: Loss 0.01120
Validation: Loss 0.01178 Accuracy 1.00000
Validation: Loss 0.01097 Accuracy 1.00000
Epoch [ 21]: Loss 0.01177
Epoch [ 21]: Loss 0.01220
Epoch [ 21]: Loss 0.01308
Epoch [ 21]: Loss 0.01087
Epoch [ 21]: Loss 0.01095
Epoch [ 21]: Loss 0.01153
Epoch [ 21]: Loss 0.01146
Validation: Loss 0.01066 Accuracy 1.00000
Validation: Loss 0.00994 Accuracy 1.00000
Epoch [ 22]: Loss 0.01003
Epoch [ 22]: Loss 0.01039
Epoch [ 22]: Loss 0.01010
Epoch [ 22]: Loss 0.01073
Epoch [ 22]: Loss 0.01084
Epoch [ 22]: Loss 0.01084
Epoch [ 22]: Loss 0.01059
Validation: Loss 0.00944 Accuracy 1.00000
Validation: Loss 0.00883 Accuracy 1.00000
Epoch [ 23]: Loss 0.01069
Epoch [ 23]: Loss 0.00971
Epoch [ 23]: Loss 0.00869
Epoch [ 23]: Loss 0.00928
Epoch [ 23]: Loss 0.00962
Epoch [ 23]: Loss 0.00785
Epoch [ 23]: Loss 0.00896
Validation: Loss 0.00838 Accuracy 1.00000
Validation: Loss 0.00786 Accuracy 1.00000
Epoch [ 24]: Loss 0.00938
Epoch [ 24]: Loss 0.00810
Epoch [ 24]: Loss 0.00859
Epoch [ 24]: Loss 0.00738
Epoch [ 24]: Loss 0.00828
Epoch [ 24]: Loss 0.00831
Epoch [ 24]: Loss 0.00759
Validation: Loss 0.00763 Accuracy 1.00000
Validation: Loss 0.00716 Accuracy 1.00000
Epoch [ 25]: Loss 0.00798
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00779
Epoch [ 25]: Loss 0.00764
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00767
Epoch [ 25]: Loss 0.00627
Validation: Loss 0.00709 Accuracy 1.00000
Validation: Loss 0.00666 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.59817
Epoch [  1]: Loss 0.61437
Epoch [  1]: Loss 0.56622
Epoch [  1]: Loss 0.54189
Epoch [  1]: Loss 0.51737
Epoch [  1]: Loss 0.49618
Epoch [  1]: Loss 0.49495
Validation: Loss 0.47899 Accuracy 1.00000
Validation: Loss 0.47553 Accuracy 1.00000
Epoch [  2]: Loss 0.46698
Epoch [  2]: Loss 0.46119
Epoch [  2]: Loss 0.45414
Epoch [  2]: Loss 0.43124
Epoch [  2]: Loss 0.40566
Epoch [  2]: Loss 0.38524
Epoch [  2]: Loss 0.37662
Validation: Loss 0.38435 Accuracy 1.00000
Validation: Loss 0.38085 Accuracy 1.00000
Epoch [  3]: Loss 0.37390
Epoch [  3]: Loss 0.36331
Epoch [  3]: Loss 0.34397
Epoch [  3]: Loss 0.32278
Epoch [  3]: Loss 0.33548
Epoch [  3]: Loss 0.30345
Epoch [  3]: Loss 0.33287
Validation: Loss 0.30156 Accuracy 1.00000
Validation: Loss 0.29772 Accuracy 1.00000
Epoch [  4]: Loss 0.27521
Epoch [  4]: Loss 0.28071
Epoch [  4]: Loss 0.28083
Epoch [  4]: Loss 0.26859
Epoch [  4]: Loss 0.24388
Epoch [  4]: Loss 0.23667
Epoch [  4]: Loss 0.20741
Validation: Loss 0.23261 Accuracy 1.00000
Validation: Loss 0.22866 Accuracy 1.00000
Epoch [  5]: Loss 0.21795
Epoch [  5]: Loss 0.21692
Epoch [  5]: Loss 0.20361
Epoch [  5]: Loss 0.18745
Epoch [  5]: Loss 0.18196
Epoch [  5]: Loss 0.18741
Epoch [  5]: Loss 0.17958
Validation: Loss 0.17655 Accuracy 1.00000
Validation: Loss 0.17258 Accuracy 1.00000
Epoch [  6]: Loss 0.16314
Epoch [  6]: Loss 0.16417
Epoch [  6]: Loss 0.14852
Epoch [  6]: Loss 0.15359
Epoch [  6]: Loss 0.12981
Epoch [  6]: Loss 0.13258
Epoch [  6]: Loss 0.13318
Validation: Loss 0.13129 Accuracy 1.00000
Validation: Loss 0.12777 Accuracy 1.00000
Epoch [  7]: Loss 0.11685
Epoch [  7]: Loss 0.12432
Epoch [  7]: Loss 0.10753
Epoch [  7]: Loss 0.10374
Epoch [  7]: Loss 0.10754
Epoch [  7]: Loss 0.09146
Epoch [  7]: Loss 0.10040
Validation: Loss 0.09435 Accuracy 1.00000
Validation: Loss 0.09168 Accuracy 1.00000
Epoch [  8]: Loss 0.08305
Epoch [  8]: Loss 0.08316
Epoch [  8]: Loss 0.07832
Epoch [  8]: Loss 0.07814
Epoch [  8]: Loss 0.07008
Epoch [  8]: Loss 0.06555
Epoch [  8]: Loss 0.07635
Validation: Loss 0.06526 Accuracy 1.00000
Validation: Loss 0.06347 Accuracy 1.00000
Epoch [  9]: Loss 0.06184
Epoch [  9]: Loss 0.05639
Epoch [  9]: Loss 0.05512
Epoch [  9]: Loss 0.05228
Epoch [  9]: Loss 0.05140
Epoch [  9]: Loss 0.04724
Epoch [  9]: Loss 0.04700
Validation: Loss 0.04806 Accuracy 1.00000
Validation: Loss 0.04679 Accuracy 1.00000
Epoch [ 10]: Loss 0.04574
Epoch [ 10]: Loss 0.04227
Epoch [ 10]: Loss 0.04450
Epoch [ 10]: Loss 0.03964
Epoch [ 10]: Loss 0.03944
Epoch [ 10]: Loss 0.03695
Epoch [ 10]: Loss 0.03498
Validation: Loss 0.03876 Accuracy 1.00000
Validation: Loss 0.03771 Accuracy 1.00000
Epoch [ 11]: Loss 0.03502
Epoch [ 11]: Loss 0.03347
Epoch [ 11]: Loss 0.03354
Epoch [ 11]: Loss 0.03395
Epoch [ 11]: Loss 0.03252
Epoch [ 11]: Loss 0.03411
Epoch [ 11]: Loss 0.03504
Validation: Loss 0.03292 Accuracy 1.00000
Validation: Loss 0.03200 Accuracy 1.00000
Epoch [ 12]: Loss 0.03184
Epoch [ 12]: Loss 0.03221
Epoch [ 12]: Loss 0.02706
Epoch [ 12]: Loss 0.02729
Epoch [ 12]: Loss 0.02980
Epoch [ 12]: Loss 0.02648
Epoch [ 12]: Loss 0.02640
Validation: Loss 0.02861 Accuracy 1.00000
Validation: Loss 0.02780 Accuracy 1.00000
Epoch [ 13]: Loss 0.02523
Epoch [ 13]: Loss 0.02739
Epoch [ 13]: Loss 0.02618
Epoch [ 13]: Loss 0.02526
Epoch [ 13]: Loss 0.02435
Epoch [ 13]: Loss 0.02349
Epoch [ 13]: Loss 0.02597
Validation: Loss 0.02532 Accuracy 1.00000
Validation: Loss 0.02459 Accuracy 1.00000
Epoch [ 14]: Loss 0.02337
Epoch [ 14]: Loss 0.02298
Epoch [ 14]: Loss 0.02154
Epoch [ 14]: Loss 0.02370
Epoch [ 14]: Loss 0.02168
Epoch [ 14]: Loss 0.02094
Epoch [ 14]: Loss 0.02563
Validation: Loss 0.02268 Accuracy 1.00000
Validation: Loss 0.02201 Accuracy 1.00000
Epoch [ 15]: Loss 0.02049
Epoch [ 15]: Loss 0.01980
Epoch [ 15]: Loss 0.02240
Epoch [ 15]: Loss 0.01997
Epoch [ 15]: Loss 0.01759
Epoch [ 15]: Loss 0.02125
Epoch [ 15]: Loss 0.01877
Validation: Loss 0.02048 Accuracy 1.00000
Validation: Loss 0.01987 Accuracy 1.00000
Epoch [ 16]: Loss 0.01847
Epoch [ 16]: Loss 0.01965
Epoch [ 16]: Loss 0.01812
Epoch [ 16]: Loss 0.01737
Epoch [ 16]: Loss 0.01818
Epoch [ 16]: Loss 0.01827
Epoch [ 16]: Loss 0.01633
Validation: Loss 0.01862 Accuracy 1.00000
Validation: Loss 0.01806 Accuracy 1.00000
Epoch [ 17]: Loss 0.01711
Epoch [ 17]: Loss 0.01622
Epoch [ 17]: Loss 0.01781
Epoch [ 17]: Loss 0.01513
Epoch [ 17]: Loss 0.01721
Epoch [ 17]: Loss 0.01667
Epoch [ 17]: Loss 0.01487
Validation: Loss 0.01702 Accuracy 1.00000
Validation: Loss 0.01650 Accuracy 1.00000
Epoch [ 18]: Loss 0.01609
Epoch [ 18]: Loss 0.01480
Epoch [ 18]: Loss 0.01630
Epoch [ 18]: Loss 0.01564
Epoch [ 18]: Loss 0.01499
Epoch [ 18]: Loss 0.01363
Epoch [ 18]: Loss 0.01420
Validation: Loss 0.01560 Accuracy 1.00000
Validation: Loss 0.01512 Accuracy 1.00000
Epoch [ 19]: Loss 0.01473
Epoch [ 19]: Loss 0.01484
Epoch [ 19]: Loss 0.01451
Epoch [ 19]: Loss 0.01235
Epoch [ 19]: Loss 0.01359
Epoch [ 19]: Loss 0.01329
Epoch [ 19]: Loss 0.01463
Validation: Loss 0.01427 Accuracy 1.00000
Validation: Loss 0.01384 Accuracy 1.00000
Epoch [ 20]: Loss 0.01296
Epoch [ 20]: Loss 0.01158
Epoch [ 20]: Loss 0.01246
Epoch [ 20]: Loss 0.01361
Epoch [ 20]: Loss 0.01262
Epoch [ 20]: Loss 0.01244
Epoch [ 20]: Loss 0.01423
Validation: Loss 0.01291 Accuracy 1.00000
Validation: Loss 0.01253 Accuracy 1.00000
Epoch [ 21]: Loss 0.01245
Epoch [ 21]: Loss 0.01116
Epoch [ 21]: Loss 0.01088
Epoch [ 21]: Loss 0.01069
Epoch [ 21]: Loss 0.01219
Epoch [ 21]: Loss 0.01143
Epoch [ 21]: Loss 0.00963
Validation: Loss 0.01141 Accuracy 1.00000
Validation: Loss 0.01109 Accuracy 1.00000
Epoch [ 22]: Loss 0.01079
Epoch [ 22]: Loss 0.01060
Epoch [ 22]: Loss 0.01021
Epoch [ 22]: Loss 0.00991
Epoch [ 22]: Loss 0.00946
Epoch [ 22]: Loss 0.00958
Epoch [ 22]: Loss 0.00924
Validation: Loss 0.01005 Accuracy 1.00000
Validation: Loss 0.00976 Accuracy 1.00000
Epoch [ 23]: Loss 0.00944
Epoch [ 23]: Loss 0.00867
Epoch [ 23]: Loss 0.00923
Epoch [ 23]: Loss 0.00891
Epoch [ 23]: Loss 0.00863
Epoch [ 23]: Loss 0.00866
Epoch [ 23]: Loss 0.00930
Validation: Loss 0.00907 Accuracy 1.00000
Validation: Loss 0.00881 Accuracy 1.00000
Epoch [ 24]: Loss 0.00726
Epoch [ 24]: Loss 0.00911
Epoch [ 24]: Loss 0.00847
Epoch [ 24]: Loss 0.00837
Epoch [ 24]: Loss 0.00767
Epoch [ 24]: Loss 0.00796
Epoch [ 24]: Loss 0.00822
Validation: Loss 0.00836 Accuracy 1.00000
Validation: Loss 0.00812 Accuracy 1.00000
Epoch [ 25]: Loss 0.00807
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00707
Epoch [ 25]: Loss 0.00784
Epoch [ 25]: Loss 0.00728
Epoch [ 25]: Loss 0.00740
Epoch [ 25]: Loss 0.00823
Validation: Loss 0.00781 Accuracy 1.00000
Validation: Loss 0.00759 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.1
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.141 GiB / 4.750 GiB available)

This page was generated using Literate.jl.