
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. The dataloader will yield batches of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary label: clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
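As a quick sanity check, we can inspect one batch from the training loader. This is an illustrative sketch (assuming the definitions above have been evaluated); the exact batch contents depend on shuffling:

```julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
# With the defaults above (1000 spirals, sequence length 50, 80/20 split,
# batchsize 128) we expect:
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence
```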

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the individual layers from scratch; instead we will compose the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the stock Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the recurrence by hand, but we will do it manually to illustrate the recurrent API.
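As a hedged sketch, such a stock-block version could look like the following (the constructor name is ours; Recurrence runs the cell over the sequence dimension and, by default, returns only the final output):

```julia
# Sketch: the same classifier built purely from stock Lux blocks.
# Plain functions such as `vec` are wrapped automatically when placed in a Chain.
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end
```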

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
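For reference, the numerically stabler logit-based loss mentioned above can be sketched as follows (it operates on the raw Dense outputs, so the sigmoid would be dropped from the classifier head; this is our illustration, not the tutorial's code):

```julia
# Sketch: binary cross-entropy computed directly from logits z.
# Algebraically, -y*log(σ(z)) - (1-y)*log(1-σ(z)) = (1-y)*z + log(1 + exp(-z)),
# which avoids taking log of a saturated sigmoid. (A fully robust version
# would additionally guard against overflow of exp(-z) for large negative z.)
logitbinarycrossentropy(logits, y_true) = mean(@. (1 - y_true) * logits + log1p(exp(-logits)))
```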

Training the Model

julia
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56230
Epoch [  1]: Loss 0.50787
Epoch [  1]: Loss 0.47183
Epoch [  1]: Loss 0.44831
Epoch [  1]: Loss 0.42515
Epoch [  1]: Loss 0.39968
Epoch [  1]: Loss 0.38847
Validation: Loss 0.37128 Accuracy 1.00000
Validation: Loss 0.36232 Accuracy 1.00000
Epoch [  2]: Loss 0.36567
Epoch [  2]: Loss 0.35083
Epoch [  2]: Loss 0.32326
Epoch [  2]: Loss 0.31313
Epoch [  2]: Loss 0.29613
Epoch [  2]: Loss 0.29390
Epoch [  2]: Loss 0.26999
Validation: Loss 0.25869 Accuracy 1.00000
Validation: Loss 0.25399 Accuracy 1.00000
Epoch [  3]: Loss 0.25840
Epoch [  3]: Loss 0.23970
Epoch [  3]: Loss 0.22939
Epoch [  3]: Loss 0.22172
Epoch [  3]: Loss 0.20551
Epoch [  3]: Loss 0.19663
Epoch [  3]: Loss 0.19071
Validation: Loss 0.17909 Accuracy 1.00000
Validation: Loss 0.17725 Accuracy 1.00000
Epoch [  4]: Loss 0.17683
Epoch [  4]: Loss 0.16940
Epoch [  4]: Loss 0.16207
Epoch [  4]: Loss 0.15299
Epoch [  4]: Loss 0.14434
Epoch [  4]: Loss 0.13954
Epoch [  4]: Loss 0.13248
Validation: Loss 0.12694 Accuracy 1.00000
Validation: Loss 0.12591 Accuracy 1.00000
Epoch [  5]: Loss 0.12561
Epoch [  5]: Loss 0.12243
Epoch [  5]: Loss 0.11473
Epoch [  5]: Loss 0.10923
Epoch [  5]: Loss 0.10484
Epoch [  5]: Loss 0.09969
Epoch [  5]: Loss 0.09117
Validation: Loss 0.09210 Accuracy 1.00000
Validation: Loss 0.09087 Accuracy 1.00000
Epoch [  6]: Loss 0.09119
Epoch [  6]: Loss 0.08721
Epoch [  6]: Loss 0.08429
Epoch [  6]: Loss 0.08053
Epoch [  6]: Loss 0.07627
Epoch [  6]: Loss 0.07256
Epoch [  6]: Loss 0.06711
Validation: Loss 0.06816 Accuracy 1.00000
Validation: Loss 0.06653 Accuracy 1.00000
Epoch [  7]: Loss 0.06757
Epoch [  7]: Loss 0.06418
Epoch [  7]: Loss 0.06203
Epoch [  7]: Loss 0.05897
Epoch [  7]: Loss 0.05659
Epoch [  7]: Loss 0.05474
Epoch [  7]: Loss 0.05000
Validation: Loss 0.05117 Accuracy 1.00000
Validation: Loss 0.04938 Accuracy 1.00000
Epoch [  8]: Loss 0.05164
Epoch [  8]: Loss 0.04705
Epoch [  8]: Loss 0.04762
Epoch [  8]: Loss 0.04371
Epoch [  8]: Loss 0.04206
Epoch [  8]: Loss 0.04097
Epoch [  8]: Loss 0.03638
Validation: Loss 0.03881 Accuracy 1.00000
Validation: Loss 0.03704 Accuracy 1.00000
Epoch [  9]: Loss 0.03905
Epoch [  9]: Loss 0.03536
Epoch [  9]: Loss 0.03580
Epoch [  9]: Loss 0.03409
Epoch [  9]: Loss 0.03250
Epoch [  9]: Loss 0.02967
Epoch [  9]: Loss 0.02998
Validation: Loss 0.02991 Accuracy 1.00000
Validation: Loss 0.02827 Accuracy 1.00000
Epoch [ 10]: Loss 0.02873
Epoch [ 10]: Loss 0.02851
Epoch [ 10]: Loss 0.02644
Epoch [ 10]: Loss 0.02676
Epoch [ 10]: Loss 0.02606
Epoch [ 10]: Loss 0.02348
Epoch [ 10]: Loss 0.02260
Validation: Loss 0.02368 Accuracy 1.00000
Validation: Loss 0.02222 Accuracy 1.00000
Epoch [ 11]: Loss 0.02268
Epoch [ 11]: Loss 0.02264
Epoch [ 11]: Loss 0.02253
Epoch [ 11]: Loss 0.02078
Epoch [ 11]: Loss 0.01990
Epoch [ 11]: Loss 0.01909
Epoch [ 11]: Loss 0.01871
Validation: Loss 0.01934 Accuracy 1.00000
Validation: Loss 0.01808 Accuracy 1.00000
Epoch [ 12]: Loss 0.01856
Epoch [ 12]: Loss 0.01859
Epoch [ 12]: Loss 0.01782
Epoch [ 12]: Loss 0.01613
Epoch [ 12]: Loss 0.01744
Epoch [ 12]: Loss 0.01647
Epoch [ 12]: Loss 0.01616
Validation: Loss 0.01628 Accuracy 1.00000
Validation: Loss 0.01520 Accuracy 1.00000
Epoch [ 13]: Loss 0.01541
Epoch [ 13]: Loss 0.01563
Epoch [ 13]: Loss 0.01470
Epoch [ 13]: Loss 0.01589
Epoch [ 13]: Loss 0.01363
Epoch [ 13]: Loss 0.01424
Epoch [ 13]: Loss 0.01273
Validation: Loss 0.01406 Accuracy 1.00000
Validation: Loss 0.01313 Accuracy 1.00000
Epoch [ 14]: Loss 0.01424
Epoch [ 14]: Loss 0.01366
Epoch [ 14]: Loss 0.01299
Epoch [ 14]: Loss 0.01243
Epoch [ 14]: Loss 0.01298
Epoch [ 14]: Loss 0.01158
Epoch [ 14]: Loss 0.01155
Validation: Loss 0.01240 Accuracy 1.00000
Validation: Loss 0.01158 Accuracy 1.00000
Epoch [ 15]: Loss 0.01231
Epoch [ 15]: Loss 0.01144
Epoch [ 15]: Loss 0.01168
Epoch [ 15]: Loss 0.01154
Epoch [ 15]: Loss 0.01102
Epoch [ 15]: Loss 0.01106
Epoch [ 15]: Loss 0.01025
Validation: Loss 0.01110 Accuracy 1.00000
Validation: Loss 0.01037 Accuracy 1.00000
Epoch [ 16]: Loss 0.01064
Epoch [ 16]: Loss 0.01049
Epoch [ 16]: Loss 0.01047
Epoch [ 16]: Loss 0.00999
Epoch [ 16]: Loss 0.01023
Epoch [ 16]: Loss 0.01026
Epoch [ 16]: Loss 0.00929
Validation: Loss 0.01006 Accuracy 1.00000
Validation: Loss 0.00939 Accuracy 1.00000
Epoch [ 17]: Loss 0.01024
Epoch [ 17]: Loss 0.00965
Epoch [ 17]: Loss 0.00946
Epoch [ 17]: Loss 0.00929
Epoch [ 17]: Loss 0.00959
Epoch [ 17]: Loss 0.00816
Epoch [ 17]: Loss 0.00874
Validation: Loss 0.00919 Accuracy 1.00000
Validation: Loss 0.00857 Accuracy 1.00000
Epoch [ 18]: Loss 0.00881
Epoch [ 18]: Loss 0.00875
Epoch [ 18]: Loss 0.00920
Epoch [ 18]: Loss 0.00820
Epoch [ 18]: Loss 0.00820
Epoch [ 18]: Loss 0.00842
Epoch [ 18]: Loss 0.00803
Validation: Loss 0.00845 Accuracy 1.00000
Validation: Loss 0.00788 Accuracy 1.00000
Epoch [ 19]: Loss 0.00827
Epoch [ 19]: Loss 0.00772
Epoch [ 19]: Loss 0.00804
Epoch [ 19]: Loss 0.00804
Epoch [ 19]: Loss 0.00791
Epoch [ 19]: Loss 0.00741
Epoch [ 19]: Loss 0.00786
Validation: Loss 0.00781 Accuracy 1.00000
Validation: Loss 0.00728 Accuracy 1.00000
Epoch [ 20]: Loss 0.00706
Epoch [ 20]: Loss 0.00747
Epoch [ 20]: Loss 0.00737
Epoch [ 20]: Loss 0.00740
Epoch [ 20]: Loss 0.00702
Epoch [ 20]: Loss 0.00752
Epoch [ 20]: Loss 0.00742
Validation: Loss 0.00726 Accuracy 1.00000
Validation: Loss 0.00676 Accuracy 1.00000
Epoch [ 21]: Loss 0.00680
Epoch [ 21]: Loss 0.00692
Epoch [ 21]: Loss 0.00706
Epoch [ 21]: Loss 0.00728
Epoch [ 21]: Loss 0.00658
Epoch [ 21]: Loss 0.00612
Epoch [ 21]: Loss 0.00699
Validation: Loss 0.00676 Accuracy 1.00000
Validation: Loss 0.00630 Accuracy 1.00000
Epoch [ 22]: Loss 0.00669
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00620
Epoch [ 22]: Loss 0.00679
Epoch [ 22]: Loss 0.00600
Epoch [ 22]: Loss 0.00630
Epoch [ 22]: Loss 0.00616
Validation: Loss 0.00633 Accuracy 1.00000
Validation: Loss 0.00589 Accuracy 1.00000
Epoch [ 23]: Loss 0.00643
Epoch [ 23]: Loss 0.00583
Epoch [ 23]: Loss 0.00599
Epoch [ 23]: Loss 0.00578
Epoch [ 23]: Loss 0.00626
Epoch [ 23]: Loss 0.00562
Epoch [ 23]: Loss 0.00493
Validation: Loss 0.00594 Accuracy 1.00000
Validation: Loss 0.00553 Accuracy 1.00000
Epoch [ 24]: Loss 0.00601
Epoch [ 24]: Loss 0.00548
Epoch [ 24]: Loss 0.00526
Epoch [ 24]: Loss 0.00567
Epoch [ 24]: Loss 0.00569
Epoch [ 24]: Loss 0.00551
Epoch [ 24]: Loss 0.00508
Validation: Loss 0.00559 Accuracy 1.00000
Validation: Loss 0.00520 Accuracy 1.00000
Epoch [ 25]: Loss 0.00539
Epoch [ 25]: Loss 0.00515
Epoch [ 25]: Loss 0.00546
Epoch [ 25]: Loss 0.00503
Epoch [ 25]: Loss 0.00528
Epoch [ 25]: Loss 0.00533
Epoch [ 25]: Loss 0.00500
Validation: Loss 0.00528 Accuracy 1.00000
Validation: Loss 0.00491 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56248
Epoch [  1]: Loss 0.51458
Epoch [  1]: Loss 0.48072
Epoch [  1]: Loss 0.45577
Epoch [  1]: Loss 0.43041
Epoch [  1]: Loss 0.40489
Epoch [  1]: Loss 0.38433
Validation: Loss 0.36590 Accuracy 1.00000
Validation: Loss 0.36632 Accuracy 1.00000
Epoch [  2]: Loss 0.36544
Epoch [  2]: Loss 0.35846
Epoch [  2]: Loss 0.33377
Epoch [  2]: Loss 0.31587
Epoch [  2]: Loss 0.30113
Epoch [  2]: Loss 0.29569
Epoch [  2]: Loss 0.28527
Validation: Loss 0.25747 Accuracy 1.00000
Validation: Loss 0.25652 Accuracy 1.00000
Epoch [  3]: Loss 0.26236
Epoch [  3]: Loss 0.24692
Epoch [  3]: Loss 0.23333
Epoch [  3]: Loss 0.22068
Epoch [  3]: Loss 0.21304
Epoch [  3]: Loss 0.20160
Epoch [  3]: Loss 0.19337
Validation: Loss 0.18109 Accuracy 1.00000
Validation: Loss 0.17993 Accuracy 1.00000
Epoch [  4]: Loss 0.18231
Epoch [  4]: Loss 0.17341
Epoch [  4]: Loss 0.16459
Epoch [  4]: Loss 0.15686
Epoch [  4]: Loss 0.15184
Epoch [  4]: Loss 0.14197
Epoch [  4]: Loss 0.13611
Validation: Loss 0.12983 Accuracy 1.00000
Validation: Loss 0.12897 Accuracy 1.00000
Epoch [  5]: Loss 0.12916
Epoch [  5]: Loss 0.12544
Epoch [  5]: Loss 0.11855
Epoch [  5]: Loss 0.11328
Epoch [  5]: Loss 0.10950
Epoch [  5]: Loss 0.10356
Epoch [  5]: Loss 0.09961
Validation: Loss 0.09428 Accuracy 1.00000
Validation: Loss 0.09376 Accuracy 1.00000
Epoch [  6]: Loss 0.09507
Epoch [  6]: Loss 0.09080
Epoch [  6]: Loss 0.08619
Epoch [  6]: Loss 0.08269
Epoch [  6]: Loss 0.07970
Epoch [  6]: Loss 0.07736
Epoch [  6]: Loss 0.07522
Validation: Loss 0.06923 Accuracy 1.00000
Validation: Loss 0.06901 Accuracy 1.00000
Epoch [  7]: Loss 0.06988
Epoch [  7]: Loss 0.06733
Epoch [  7]: Loss 0.06386
Epoch [  7]: Loss 0.06356
Epoch [  7]: Loss 0.05810
Epoch [  7]: Loss 0.05724
Epoch [  7]: Loss 0.05489
Validation: Loss 0.05140 Accuracy 1.00000
Validation: Loss 0.05138 Accuracy 1.00000
Epoch [  8]: Loss 0.05414
Epoch [  8]: Loss 0.04993
Epoch [  8]: Loss 0.04671
Epoch [  8]: Loss 0.04568
Epoch [  8]: Loss 0.04472
Epoch [  8]: Loss 0.04278
Epoch [  8]: Loss 0.04363
Validation: Loss 0.03847 Accuracy 1.00000
Validation: Loss 0.03856 Accuracy 1.00000
Epoch [  9]: Loss 0.03891
Epoch [  9]: Loss 0.03931
Epoch [  9]: Loss 0.03689
Epoch [  9]: Loss 0.03574
Epoch [  9]: Loss 0.03315
Epoch [  9]: Loss 0.03136
Epoch [  9]: Loss 0.03059
Validation: Loss 0.02921 Accuracy 1.00000
Validation: Loss 0.02934 Accuracy 1.00000
Epoch [ 10]: Loss 0.03010
Epoch [ 10]: Loss 0.02901
Epoch [ 10]: Loss 0.02917
Epoch [ 10]: Loss 0.02671
Epoch [ 10]: Loss 0.02502
Epoch [ 10]: Loss 0.02494
Epoch [ 10]: Loss 0.02767
Validation: Loss 0.02283 Accuracy 1.00000
Validation: Loss 0.02298 Accuracy 1.00000
Epoch [ 11]: Loss 0.02308
Epoch [ 11]: Loss 0.02403
Epoch [ 11]: Loss 0.02281
Epoch [ 11]: Loss 0.02141
Epoch [ 11]: Loss 0.02036
Epoch [ 11]: Loss 0.01984
Epoch [ 11]: Loss 0.02063
Validation: Loss 0.01850 Accuracy 1.00000
Validation: Loss 0.01863 Accuracy 1.00000
Epoch [ 12]: Loss 0.01859
Epoch [ 12]: Loss 0.01850
Epoch [ 12]: Loss 0.01799
Epoch [ 12]: Loss 0.01830
Epoch [ 12]: Loss 0.01793
Epoch [ 12]: Loss 0.01714
Epoch [ 12]: Loss 0.01552
Validation: Loss 0.01551 Accuracy 1.00000
Validation: Loss 0.01561 Accuracy 1.00000
Epoch [ 13]: Loss 0.01598
Epoch [ 13]: Loss 0.01515
Epoch [ 13]: Loss 0.01463
Epoch [ 13]: Loss 0.01499
Epoch [ 13]: Loss 0.01540
Epoch [ 13]: Loss 0.01569
Epoch [ 13]: Loss 0.01371
Validation: Loss 0.01337 Accuracy 1.00000
Validation: Loss 0.01346 Accuracy 1.00000
Epoch [ 14]: Loss 0.01438
Epoch [ 14]: Loss 0.01338
Epoch [ 14]: Loss 0.01328
Epoch [ 14]: Loss 0.01347
Epoch [ 14]: Loss 0.01312
Epoch [ 14]: Loss 0.01245
Epoch [ 14]: Loss 0.01167
Validation: Loss 0.01178 Accuracy 1.00000
Validation: Loss 0.01185 Accuracy 1.00000
Epoch [ 15]: Loss 0.01268
Epoch [ 15]: Loss 0.01226
Epoch [ 15]: Loss 0.01078
Epoch [ 15]: Loss 0.01249
Epoch [ 15]: Loss 0.01085
Epoch [ 15]: Loss 0.01165
Epoch [ 15]: Loss 0.01146
Validation: Loss 0.01053 Accuracy 1.00000
Validation: Loss 0.01060 Accuracy 1.00000
Epoch [ 16]: Loss 0.01183
Epoch [ 16]: Loss 0.01098
Epoch [ 16]: Loss 0.01000
Epoch [ 16]: Loss 0.01024
Epoch [ 16]: Loss 0.01003
Epoch [ 16]: Loss 0.01042
Epoch [ 16]: Loss 0.01055
Validation: Loss 0.00952 Accuracy 1.00000
Validation: Loss 0.00958 Accuracy 1.00000
Epoch [ 17]: Loss 0.00961
Epoch [ 17]: Loss 0.00990
Epoch [ 17]: Loss 0.00988
Epoch [ 17]: Loss 0.00918
Epoch [ 17]: Loss 0.00968
Epoch [ 17]: Loss 0.00932
Epoch [ 17]: Loss 0.00968
Validation: Loss 0.00868 Accuracy 1.00000
Validation: Loss 0.00873 Accuracy 1.00000
Epoch [ 18]: Loss 0.00899
Epoch [ 18]: Loss 0.00831
Epoch [ 18]: Loss 0.00917
Epoch [ 18]: Loss 0.00907
Epoch [ 18]: Loss 0.00836
Epoch [ 18]: Loss 0.00884
Epoch [ 18]: Loss 0.00855
Validation: Loss 0.00797 Accuracy 1.00000
Validation: Loss 0.00802 Accuracy 1.00000
Epoch [ 19]: Loss 0.00837
Epoch [ 19]: Loss 0.00777
Epoch [ 19]: Loss 0.00851
Epoch [ 19]: Loss 0.00779
Epoch [ 19]: Loss 0.00817
Epoch [ 19]: Loss 0.00805
Epoch [ 19]: Loss 0.00749
Validation: Loss 0.00736 Accuracy 1.00000
Validation: Loss 0.00740 Accuracy 1.00000
Epoch [ 20]: Loss 0.00780
Epoch [ 20]: Loss 0.00755
Epoch [ 20]: Loss 0.00798
Epoch [ 20]: Loss 0.00762
Epoch [ 20]: Loss 0.00675
Epoch [ 20]: Loss 0.00723
Epoch [ 20]: Loss 0.00741
Validation: Loss 0.00683 Accuracy 1.00000
Validation: Loss 0.00687 Accuracy 1.00000
Epoch [ 21]: Loss 0.00709
Epoch [ 21]: Loss 0.00691
Epoch [ 21]: Loss 0.00693
Epoch [ 21]: Loss 0.00683
Epoch [ 21]: Loss 0.00725
Epoch [ 21]: Loss 0.00669
Epoch [ 21]: Loss 0.00710
Validation: Loss 0.00636 Accuracy 1.00000
Validation: Loss 0.00640 Accuracy 1.00000
Epoch [ 22]: Loss 0.00656
Epoch [ 22]: Loss 0.00672
Epoch [ 22]: Loss 0.00692
Epoch [ 22]: Loss 0.00585
Epoch [ 22]: Loss 0.00658
Epoch [ 22]: Loss 0.00623
Epoch [ 22]: Loss 0.00684
Validation: Loss 0.00595 Accuracy 1.00000
Validation: Loss 0.00598 Accuracy 1.00000
Epoch [ 23]: Loss 0.00560
Epoch [ 23]: Loss 0.00635
Epoch [ 23]: Loss 0.00610
Epoch [ 23]: Loss 0.00596
Epoch [ 23]: Loss 0.00632
Epoch [ 23]: Loss 0.00623
Epoch [ 23]: Loss 0.00568
Validation: Loss 0.00558 Accuracy 1.00000
Validation: Loss 0.00561 Accuracy 1.00000
Epoch [ 24]: Loss 0.00576
Epoch [ 24]: Loss 0.00603
Epoch [ 24]: Loss 0.00541
Epoch [ 24]: Loss 0.00561
Epoch [ 24]: Loss 0.00591
Epoch [ 24]: Loss 0.00551
Epoch [ 24]: Loss 0.00582
Validation: Loss 0.00525 Accuracy 1.00000
Validation: Loss 0.00528 Accuracy 1.00000
Epoch [ 25]: Loss 0.00542
Epoch [ 25]: Loss 0.00535
Epoch [ 25]: Loss 0.00489
Epoch [ 25]: Loss 0.00584
Epoch [ 25]: Loss 0.00537
Epoch [ 25]: Loss 0.00534
Epoch [ 25]: Loss 0.00552
Validation: Loss 0.00495 Accuracy 1.00000
Validation: Loss 0.00498 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the parameters and states to the CPU before saving. Additionally, we recommend saving only the parameters and states, not the model struct itself.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
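With the parameters and states restored, inference only requires reconstructing the model struct. A sketch (assuming the model and dataloaders defined earlier in this tutorial):

```julia
# Rebuild the model and run the restored parameters on validation data.
model = SpiralClassifier(2, 8, 1)
_, val_loader = get_dataloaders()
x, y = first(val_loader)
# Use test mode for the states during inference.
y_pred, _ = model(x, ps_trained, Lux.testmode(st_trained))
accuracy(y_pred, y)
```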

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.391 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.