Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
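
To sanity-check the shapes, we can peek at the first batch from the training loader. With the default `dataset_size=1000` and `sequence_length=50`, the training split holds 800 sequences, so the first shuffled batch is a full 128:

```julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence
```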

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch, but rather use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end

We could instead use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – rather than writing the recurrence manually, but we will do it by hand here for illustration.
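
As a sketch of that built-in alternative (assuming `Recurrence`'s default behavior of returning the last time step's output, and that `Lux.Chain` wraps plain functions such as `vec`):

```julia
# Sketch of the same classifier built from stock Lux blocks. `Recurrence`
# iterates the cell over the sequence dimension and returns the final output.
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end
```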

Now we need to define the behavior of the classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
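
For reference, a sketch of the more numerically stable logit-based variant mentioned above. This assumes the sigmoid is removed from the final Dense layer so the model emits raw logits; `softplus` comes from NNlib, which Lux re-exports:

```julia
# logitbinarycrossentropy(z, y) = mean((1 - y) * z - log(sigmoid(z))),
# computed stably via -log(sigmoid(z)) == softplus(-z)
function logitbinarycrossentropy(logits, y_true)
    return mean(@. (1 - y_true) * logits + softplus(-logits))
end
```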

Training the Model

julia
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56347
Epoch [  1]: Loss 0.51145
Epoch [  1]: Loss 0.46801
Epoch [  1]: Loss 0.43717
Epoch [  1]: Loss 0.42777
Epoch [  1]: Loss 0.41613
Epoch [  1]: Loss 0.39294
Validation: Loss 0.36223 Accuracy 1.00000
Validation: Loss 0.35466 Accuracy 1.00000
Epoch [  2]: Loss 0.37034
Epoch [  2]: Loss 0.34839
Epoch [  2]: Loss 0.33154
Epoch [  2]: Loss 0.31867
Epoch [  2]: Loss 0.30396
Epoch [  2]: Loss 0.28995
Epoch [  2]: Loss 0.25848
Validation: Loss 0.25423 Accuracy 1.00000
Validation: Loss 0.24992 Accuracy 1.00000
Epoch [  3]: Loss 0.25191
Epoch [  3]: Loss 0.24680
Epoch [  3]: Loss 0.23408
Epoch [  3]: Loss 0.22293
Epoch [  3]: Loss 0.21038
Epoch [  3]: Loss 0.19983
Epoch [  3]: Loss 0.18618
Validation: Loss 0.17805 Accuracy 1.00000
Validation: Loss 0.17579 Accuracy 1.00000
Epoch [  4]: Loss 0.17986
Epoch [  4]: Loss 0.16977
Epoch [  4]: Loss 0.16377
Epoch [  4]: Loss 0.15651
Epoch [  4]: Loss 0.14737
Epoch [  4]: Loss 0.14050
Epoch [  4]: Loss 0.13777
Validation: Loss 0.12716 Accuracy 1.00000
Validation: Loss 0.12574 Accuracy 1.00000
Epoch [  5]: Loss 0.12793
Epoch [  5]: Loss 0.12405
Epoch [  5]: Loss 0.11673
Epoch [  5]: Loss 0.11214
Epoch [  5]: Loss 0.10696
Epoch [  5]: Loss 0.10181
Epoch [  5]: Loss 0.09565
Validation: Loss 0.09239 Accuracy 1.00000
Validation: Loss 0.09119 Accuracy 1.00000
Epoch [  6]: Loss 0.09509
Epoch [  6]: Loss 0.09007
Epoch [  6]: Loss 0.08456
Epoch [  6]: Loss 0.08112
Epoch [  6]: Loss 0.07758
Epoch [  6]: Loss 0.07556
Epoch [  6]: Loss 0.07289
Validation: Loss 0.06787 Accuracy 1.00000
Validation: Loss 0.06661 Accuracy 1.00000
Epoch [  7]: Loss 0.06871
Epoch [  7]: Loss 0.06479
Epoch [  7]: Loss 0.06377
Epoch [  7]: Loss 0.05913
Epoch [  7]: Loss 0.05952
Epoch [  7]: Loss 0.05766
Epoch [  7]: Loss 0.05535
Validation: Loss 0.05037 Accuracy 1.00000
Validation: Loss 0.04910 Accuracy 1.00000
Epoch [  8]: Loss 0.05128
Epoch [  8]: Loss 0.04904
Epoch [  8]: Loss 0.04810
Epoch [  8]: Loss 0.04694
Epoch [  8]: Loss 0.04289
Epoch [  8]: Loss 0.04283
Epoch [  8]: Loss 0.03664
Validation: Loss 0.03776 Accuracy 1.00000
Validation: Loss 0.03655 Accuracy 1.00000
Epoch [  9]: Loss 0.04009
Epoch [  9]: Loss 0.03849
Epoch [  9]: Loss 0.03499
Epoch [  9]: Loss 0.03388
Epoch [  9]: Loss 0.03400
Epoch [  9]: Loss 0.03097
Epoch [  9]: Loss 0.03007
Validation: Loss 0.02874 Accuracy 1.00000
Validation: Loss 0.02760 Accuracy 1.00000
Epoch [ 10]: Loss 0.03043
Epoch [ 10]: Loss 0.02818
Epoch [ 10]: Loss 0.02806
Epoch [ 10]: Loss 0.02751
Epoch [ 10]: Loss 0.02577
Epoch [ 10]: Loss 0.02379
Epoch [ 10]: Loss 0.02582
Validation: Loss 0.02255 Accuracy 1.00000
Validation: Loss 0.02151 Accuracy 1.00000
Epoch [ 11]: Loss 0.02392
Epoch [ 11]: Loss 0.02307
Epoch [ 11]: Loss 0.02175
Epoch [ 11]: Loss 0.02130
Epoch [ 11]: Loss 0.01989
Epoch [ 11]: Loss 0.02140
Epoch [ 11]: Loss 0.01799
Validation: Loss 0.01833 Accuracy 1.00000
Validation: Loss 0.01742 Accuracy 1.00000
Epoch [ 12]: Loss 0.02028
Epoch [ 12]: Loss 0.01944
Epoch [ 12]: Loss 0.01678
Epoch [ 12]: Loss 0.01764
Epoch [ 12]: Loss 0.01681
Epoch [ 12]: Loss 0.01712
Epoch [ 12]: Loss 0.01630
Validation: Loss 0.01541 Accuracy 1.00000
Validation: Loss 0.01461 Accuracy 1.00000
Epoch [ 13]: Loss 0.01650
Epoch [ 13]: Loss 0.01495
Epoch [ 13]: Loss 0.01554
Epoch [ 13]: Loss 0.01560
Epoch [ 13]: Loss 0.01432
Epoch [ 13]: Loss 0.01464
Epoch [ 13]: Loss 0.01517
Validation: Loss 0.01331 Accuracy 1.00000
Validation: Loss 0.01260 Accuracy 1.00000
Epoch [ 14]: Loss 0.01419
Epoch [ 14]: Loss 0.01341
Epoch [ 14]: Loss 0.01290
Epoch [ 14]: Loss 0.01320
Epoch [ 14]: Loss 0.01342
Epoch [ 14]: Loss 0.01264
Epoch [ 14]: Loss 0.01342
Validation: Loss 0.01173 Accuracy 1.00000
Validation: Loss 0.01110 Accuracy 1.00000
Epoch [ 15]: Loss 0.01268
Epoch [ 15]: Loss 0.01271
Epoch [ 15]: Loss 0.01125
Epoch [ 15]: Loss 0.01180
Epoch [ 15]: Loss 0.01178
Epoch [ 15]: Loss 0.01081
Epoch [ 15]: Loss 0.01107
Validation: Loss 0.01050 Accuracy 1.00000
Validation: Loss 0.00993 Accuracy 1.00000
Epoch [ 16]: Loss 0.01111
Epoch [ 16]: Loss 0.01139
Epoch [ 16]: Loss 0.01085
Epoch [ 16]: Loss 0.01118
Epoch [ 16]: Loss 0.01018
Epoch [ 16]: Loss 0.00961
Epoch [ 16]: Loss 0.00825
Validation: Loss 0.00950 Accuracy 1.00000
Validation: Loss 0.00898 Accuracy 1.00000
Epoch [ 17]: Loss 0.01004
Epoch [ 17]: Loss 0.00980
Epoch [ 17]: Loss 0.00999
Epoch [ 17]: Loss 0.00968
Epoch [ 17]: Loss 0.00961
Epoch [ 17]: Loss 0.00909
Epoch [ 17]: Loss 0.00833
Validation: Loss 0.00867 Accuracy 1.00000
Validation: Loss 0.00819 Accuracy 1.00000
Epoch [ 18]: Loss 0.00988
Epoch [ 18]: Loss 0.00908
Epoch [ 18]: Loss 0.00878
Epoch [ 18]: Loss 0.00862
Epoch [ 18]: Loss 0.00850
Epoch [ 18]: Loss 0.00852
Epoch [ 18]: Loss 0.00740
Validation: Loss 0.00797 Accuracy 1.00000
Validation: Loss 0.00751 Accuracy 1.00000
Epoch [ 19]: Loss 0.00803
Epoch [ 19]: Loss 0.00825
Epoch [ 19]: Loss 0.00814
Epoch [ 19]: Loss 0.00818
Epoch [ 19]: Loss 0.00836
Epoch [ 19]: Loss 0.00780
Epoch [ 19]: Loss 0.00847
Validation: Loss 0.00737 Accuracy 1.00000
Validation: Loss 0.00694 Accuracy 1.00000
Epoch [ 20]: Loss 0.00766
Epoch [ 20]: Loss 0.00807
Epoch [ 20]: Loss 0.00743
Epoch [ 20]: Loss 0.00725
Epoch [ 20]: Loss 0.00772
Epoch [ 20]: Loss 0.00728
Epoch [ 20]: Loss 0.00699
Validation: Loss 0.00684 Accuracy 1.00000
Validation: Loss 0.00644 Accuracy 1.00000
Epoch [ 21]: Loss 0.00720
Epoch [ 21]: Loss 0.00735
Epoch [ 21]: Loss 0.00722
Epoch [ 21]: Loss 0.00698
Epoch [ 21]: Loss 0.00719
Epoch [ 21]: Loss 0.00629
Epoch [ 21]: Loss 0.00648
Validation: Loss 0.00637 Accuracy 1.00000
Validation: Loss 0.00600 Accuracy 1.00000
Epoch [ 22]: Loss 0.00621
Epoch [ 22]: Loss 0.00638
Epoch [ 22]: Loss 0.00701
Epoch [ 22]: Loss 0.00678
Epoch [ 22]: Loss 0.00627
Epoch [ 22]: Loss 0.00676
Epoch [ 22]: Loss 0.00603
Validation: Loss 0.00596 Accuracy 1.00000
Validation: Loss 0.00560 Accuracy 1.00000
Epoch [ 23]: Loss 0.00672
Epoch [ 23]: Loss 0.00609
Epoch [ 23]: Loss 0.00565
Epoch [ 23]: Loss 0.00630
Epoch [ 23]: Loss 0.00602
Epoch [ 23]: Loss 0.00605
Epoch [ 23]: Loss 0.00606
Validation: Loss 0.00559 Accuracy 1.00000
Validation: Loss 0.00526 Accuracy 1.00000
Epoch [ 24]: Loss 0.00616
Epoch [ 24]: Loss 0.00549
Epoch [ 24]: Loss 0.00580
Epoch [ 24]: Loss 0.00566
Epoch [ 24]: Loss 0.00614
Epoch [ 24]: Loss 0.00540
Epoch [ 24]: Loss 0.00552
Validation: Loss 0.00526 Accuracy 1.00000
Validation: Loss 0.00494 Accuracy 1.00000
Epoch [ 25]: Loss 0.00573
Epoch [ 25]: Loss 0.00563
Epoch [ 25]: Loss 0.00526
Epoch [ 25]: Loss 0.00494
Epoch [ 25]: Loss 0.00570
Epoch [ 25]: Loss 0.00538
Epoch [ 25]: Loss 0.00519
Validation: Loss 0.00496 Accuracy 1.00000
Validation: Loss 0.00466 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56342
Epoch [  1]: Loss 0.51174
Epoch [  1]: Loss 0.47055
Epoch [  1]: Loss 0.44916
Epoch [  1]: Loss 0.42693
Epoch [  1]: Loss 0.39392
Epoch [  1]: Loss 0.38560
Validation: Loss 0.36276 Accuracy 1.00000
Validation: Loss 0.37331 Accuracy 1.00000
Epoch [  2]: Loss 0.36464
Epoch [  2]: Loss 0.34799
Epoch [  2]: Loss 0.33342
Epoch [  2]: Loss 0.31229
Epoch [  2]: Loss 0.29785
Epoch [  2]: Loss 0.28512
Epoch [  2]: Loss 0.28736
Validation: Loss 0.25380 Accuracy 1.00000
Validation: Loss 0.26008 Accuracy 1.00000
Epoch [  3]: Loss 0.25482
Epoch [  3]: Loss 0.24366
Epoch [  3]: Loss 0.23291
Epoch [  3]: Loss 0.21914
Epoch [  3]: Loss 0.20505
Epoch [  3]: Loss 0.19689
Epoch [  3]: Loss 0.19210
Validation: Loss 0.17712 Accuracy 1.00000
Validation: Loss 0.17983 Accuracy 1.00000
Epoch [  4]: Loss 0.17696
Epoch [  4]: Loss 0.16904
Epoch [  4]: Loss 0.16209
Epoch [  4]: Loss 0.15411
Epoch [  4]: Loss 0.14584
Epoch [  4]: Loss 0.13900
Epoch [  4]: Loss 0.13418
Validation: Loss 0.12602 Accuracy 1.00000
Validation: Loss 0.12763 Accuracy 1.00000
Epoch [  5]: Loss 0.12851
Epoch [  5]: Loss 0.12105
Epoch [  5]: Loss 0.11500
Epoch [  5]: Loss 0.10893
Epoch [  5]: Loss 0.10499
Epoch [  5]: Loss 0.10008
Epoch [  5]: Loss 0.09624
Validation: Loss 0.09114 Accuracy 1.00000
Validation: Loss 0.09274 Accuracy 1.00000
Epoch [  6]: Loss 0.09053
Epoch [  6]: Loss 0.08798
Epoch [  6]: Loss 0.08334
Epoch [  6]: Loss 0.08173
Epoch [  6]: Loss 0.07778
Epoch [  6]: Loss 0.07352
Epoch [  6]: Loss 0.06766
Validation: Loss 0.06682 Accuracy 1.00000
Validation: Loss 0.06869 Accuracy 1.00000
Epoch [  7]: Loss 0.06846
Epoch [  7]: Loss 0.06332
Epoch [  7]: Loss 0.06166
Epoch [  7]: Loss 0.06041
Epoch [  7]: Loss 0.05694
Epoch [  7]: Loss 0.05427
Epoch [  7]: Loss 0.05397
Validation: Loss 0.04964 Accuracy 1.00000
Validation: Loss 0.05162 Accuracy 1.00000
Epoch [  8]: Loss 0.05092
Epoch [  8]: Loss 0.04654
Epoch [  8]: Loss 0.04737
Epoch [  8]: Loss 0.04448
Epoch [  8]: Loss 0.04322
Epoch [  8]: Loss 0.04124
Epoch [  8]: Loss 0.03906
Validation: Loss 0.03729 Accuracy 1.00000
Validation: Loss 0.03920 Accuracy 1.00000
Epoch [  9]: Loss 0.03813
Epoch [  9]: Loss 0.03605
Epoch [  9]: Loss 0.03716
Epoch [  9]: Loss 0.03281
Epoch [  9]: Loss 0.03181
Epoch [  9]: Loss 0.03139
Epoch [  9]: Loss 0.03033
Validation: Loss 0.02846 Accuracy 1.00000
Validation: Loss 0.03023 Accuracy 1.00000
Epoch [ 10]: Loss 0.02829
Epoch [ 10]: Loss 0.02778
Epoch [ 10]: Loss 0.02547
Epoch [ 10]: Loss 0.02622
Epoch [ 10]: Loss 0.02728
Epoch [ 10]: Loss 0.02509
Epoch [ 10]: Loss 0.02426
Validation: Loss 0.02238 Accuracy 1.00000
Validation: Loss 0.02397 Accuracy 1.00000
Epoch [ 11]: Loss 0.02300
Epoch [ 11]: Loss 0.02166
Epoch [ 11]: Loss 0.02187
Epoch [ 11]: Loss 0.02083
Epoch [ 11]: Loss 0.02104
Epoch [ 11]: Loss 0.01967
Epoch [ 11]: Loss 0.01868
Validation: Loss 0.01821 Accuracy 1.00000
Validation: Loss 0.01958 Accuracy 1.00000
Epoch [ 12]: Loss 0.01887
Epoch [ 12]: Loss 0.01748
Epoch [ 12]: Loss 0.01753
Epoch [ 12]: Loss 0.01770
Epoch [ 12]: Loss 0.01643
Epoch [ 12]: Loss 0.01707
Epoch [ 12]: Loss 0.01751
Validation: Loss 0.01531 Accuracy 1.00000
Validation: Loss 0.01651 Accuracy 1.00000
Epoch [ 13]: Loss 0.01597
Epoch [ 13]: Loss 0.01597
Epoch [ 13]: Loss 0.01527
Epoch [ 13]: Loss 0.01417
Epoch [ 13]: Loss 0.01457
Epoch [ 13]: Loss 0.01362
Epoch [ 13]: Loss 0.01429
Validation: Loss 0.01323 Accuracy 1.00000
Validation: Loss 0.01427 Accuracy 1.00000
Epoch [ 14]: Loss 0.01391
Epoch [ 14]: Loss 0.01363
Epoch [ 14]: Loss 0.01304
Epoch [ 14]: Loss 0.01307
Epoch [ 14]: Loss 0.01237
Epoch [ 14]: Loss 0.01215
Epoch [ 14]: Loss 0.01187
Validation: Loss 0.01167 Accuracy 1.00000
Validation: Loss 0.01259 Accuracy 1.00000
Epoch [ 15]: Loss 0.01143
Epoch [ 15]: Loss 0.01206
Epoch [ 15]: Loss 0.01203
Epoch [ 15]: Loss 0.01157
Epoch [ 15]: Loss 0.01072
Epoch [ 15]: Loss 0.01116
Epoch [ 15]: Loss 0.01213
Validation: Loss 0.01046 Accuracy 1.00000
Validation: Loss 0.01129 Accuracy 1.00000
Epoch [ 16]: Loss 0.01069
Epoch [ 16]: Loss 0.01114
Epoch [ 16]: Loss 0.01030
Epoch [ 16]: Loss 0.00990
Epoch [ 16]: Loss 0.01003
Epoch [ 16]: Loss 0.01008
Epoch [ 16]: Loss 0.01070
Validation: Loss 0.00947 Accuracy 1.00000
Validation: Loss 0.01024 Accuracy 1.00000
Epoch [ 17]: Loss 0.00955
Epoch [ 17]: Loss 0.00912
Epoch [ 17]: Loss 0.00908
Epoch [ 17]: Loss 0.01029
Epoch [ 17]: Loss 0.00887
Epoch [ 17]: Loss 0.00972
Epoch [ 17]: Loss 0.00901
Validation: Loss 0.00865 Accuracy 1.00000
Validation: Loss 0.00935 Accuracy 1.00000
Epoch [ 18]: Loss 0.00884
Epoch [ 18]: Loss 0.00880
Epoch [ 18]: Loss 0.00863
Epoch [ 18]: Loss 0.00864
Epoch [ 18]: Loss 0.00874
Epoch [ 18]: Loss 0.00845
Epoch [ 18]: Loss 0.00738
Validation: Loss 0.00795 Accuracy 1.00000
Validation: Loss 0.00861 Accuracy 1.00000
Epoch [ 19]: Loss 0.00844
Epoch [ 19]: Loss 0.00816
Epoch [ 19]: Loss 0.00782
Epoch [ 19]: Loss 0.00827
Epoch [ 19]: Loss 0.00736
Epoch [ 19]: Loss 0.00805
Epoch [ 19]: Loss 0.00654
Validation: Loss 0.00735 Accuracy 1.00000
Validation: Loss 0.00796 Accuracy 1.00000
Epoch [ 20]: Loss 0.00792
Epoch [ 20]: Loss 0.00719
Epoch [ 20]: Loss 0.00741
Epoch [ 20]: Loss 0.00701
Epoch [ 20]: Loss 0.00763
Epoch [ 20]: Loss 0.00691
Epoch [ 20]: Loss 0.00797
Validation: Loss 0.00683 Accuracy 1.00000
Validation: Loss 0.00741 Accuracy 1.00000
Epoch [ 21]: Loss 0.00655
Epoch [ 21]: Loss 0.00744
Epoch [ 21]: Loss 0.00738
Epoch [ 21]: Loss 0.00648
Epoch [ 21]: Loss 0.00658
Epoch [ 21]: Loss 0.00664
Epoch [ 21]: Loss 0.00723
Validation: Loss 0.00637 Accuracy 1.00000
Validation: Loss 0.00691 Accuracy 1.00000
Epoch [ 22]: Loss 0.00686
Epoch [ 22]: Loss 0.00684
Epoch [ 22]: Loss 0.00601
Epoch [ 22]: Loss 0.00633
Epoch [ 22]: Loss 0.00598
Epoch [ 22]: Loss 0.00634
Epoch [ 22]: Loss 0.00676
Validation: Loss 0.00596 Accuracy 1.00000
Validation: Loss 0.00647 Accuracy 1.00000
Epoch [ 23]: Loss 0.00623
Epoch [ 23]: Loss 0.00632
Epoch [ 23]: Loss 0.00549
Epoch [ 23]: Loss 0.00607
Epoch [ 23]: Loss 0.00598
Epoch [ 23]: Loss 0.00593
Epoch [ 23]: Loss 0.00599
Validation: Loss 0.00559 Accuracy 1.00000
Validation: Loss 0.00607 Accuracy 1.00000
Epoch [ 24]: Loss 0.00616
Epoch [ 24]: Loss 0.00550
Epoch [ 24]: Loss 0.00566
Epoch [ 24]: Loss 0.00580
Epoch [ 24]: Loss 0.00537
Epoch [ 24]: Loss 0.00541
Epoch [ 24]: Loss 0.00546
Validation: Loss 0.00526 Accuracy 1.00000
Validation: Loss 0.00571 Accuracy 1.00000
Epoch [ 25]: Loss 0.00582
Epoch [ 25]: Loss 0.00523
Epoch [ 25]: Loss 0.00553
Epoch [ 25]: Loss 0.00479
Epoch [ 25]: Loss 0.00527
Epoch [ 25]: Loss 0.00525
Epoch [ 25]: Loss 0.00530
Validation: Loss 0.00497 Accuracy 1.00000
Validation: Loss 0.00539 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model back.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
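
A minimal sketch of running inference with the reloaded parameters (the dummy input below is purely illustrative):

```julia
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # one dummy sequence of length 50
y_pred, _ = model(x, ps_trained, Lux.testmode(st_trained))
```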

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.391 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.