
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and use this data to create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
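As a quick sanity check (a hypothetical snippet, not part of the tutorial, assuming the definitions above are in scope), we can pull one batch from the training loader and confirm the shapes described earlier:

```julia
train_loader, val_loader = get_dataloaders()
(x, y) = first(train_loader)
size(x)  # should be (2, 50, 128): 2 features × 50 timesteps × 128 samples
size(y)  # should be (128,): one binary label per sequence
```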

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer abstract type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated, and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the layers from scratch; instead we will compose the model from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
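Because we passed the field names to the container type, Lux.setup builds nested parameter and state NamedTuples with one entry per field. A quick sketch (assuming the SpiralClassifier constructor above is defined):

```julia
using Lux, Random

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # should be (:lstm_cell, :classifier)
keys(st)  # should be (:lstm_cell, :classifier)
```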

We could instead use a built-in Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – rather than defining the forward pass below, but we will write it out manually for the sake of illustration.
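For reference, the built-in variant could be sketched roughly as follows (SpiralClassifierBuiltin is a hypothetical name; Recurrence iterates the cell over the sequence dimension and feeds the final output to the Dense head):

```julia
using Lux

# Hypothetical built-in equivalent of the manual forward pass defined below.
function SpiralClassifierBuiltin(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end
```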

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it as the preferred way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
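For reference, the more numerically stable logitbinarycrossentropy mentioned above operates on raw logits (i.e. a Dense head without the sigmoid) and avoids log(0) via the standard stable rewrite. A sketch of such a loss:

```julia
using Statistics

# Stable binary cross-entropy on raw logits z:
#   BCE(sigmoid(z), y) = max(z, 0) - z * y + log(1 + exp(-|z|))
function logitbinarycrossentropy(z, y_true)
    return mean(@. max(z, 0) - z * y_true + log1p(exp(-abs(z))))
end

logitbinarycrossentropy([0.0f0], [1.0f0])  # ≈ log(2) ≈ 0.6931
```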

Training the Model

julia
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56027
Epoch [  1]: Loss 0.51734
Epoch [  1]: Loss 0.46993
Epoch [  1]: Loss 0.43978
Epoch [  1]: Loss 0.42088
Epoch [  1]: Loss 0.41137
Epoch [  1]: Loss 0.39442
Validation: Loss 0.36877 Accuracy 1.00000
Validation: Loss 0.37007 Accuracy 1.00000
Epoch [  2]: Loss 0.36342
Epoch [  2]: Loss 0.34780
Epoch [  2]: Loss 0.34298
Epoch [  2]: Loss 0.31380
Epoch [  2]: Loss 0.30056
Epoch [  2]: Loss 0.28931
Epoch [  2]: Loss 0.26284
Validation: Loss 0.25843 Accuracy 1.00000
Validation: Loss 0.25920 Accuracy 1.00000
Epoch [  3]: Loss 0.25257
Epoch [  3]: Loss 0.24946
Epoch [  3]: Loss 0.23134
Epoch [  3]: Loss 0.21942
Epoch [  3]: Loss 0.20871
Epoch [  3]: Loss 0.20082
Epoch [  3]: Loss 0.19377
Validation: Loss 0.18033 Accuracy 1.00000
Validation: Loss 0.18068 Accuracy 1.00000
Epoch [  4]: Loss 0.17998
Epoch [  4]: Loss 0.17144
Epoch [  4]: Loss 0.16417
Epoch [  4]: Loss 0.15330
Epoch [  4]: Loss 0.14728
Epoch [  4]: Loss 0.14089
Epoch [  4]: Loss 0.13621
Validation: Loss 0.12850 Accuracy 1.00000
Validation: Loss 0.12868 Accuracy 1.00000
Epoch [  5]: Loss 0.12672
Epoch [  5]: Loss 0.12143
Epoch [  5]: Loss 0.11789
Epoch [  5]: Loss 0.11045
Epoch [  5]: Loss 0.10922
Epoch [  5]: Loss 0.10184
Epoch [  5]: Loss 0.09568
Validation: Loss 0.09345 Accuracy 1.00000
Validation: Loss 0.09366 Accuracy 1.00000
Epoch [  6]: Loss 0.09427
Epoch [  6]: Loss 0.08937
Epoch [  6]: Loss 0.08508
Epoch [  6]: Loss 0.08143
Epoch [  6]: Loss 0.07863
Epoch [  6]: Loss 0.07405
Epoch [  6]: Loss 0.06973
Validation: Loss 0.06900 Accuracy 1.00000
Validation: Loss 0.06927 Accuracy 1.00000
Epoch [  7]: Loss 0.06958
Epoch [  7]: Loss 0.06498
Epoch [  7]: Loss 0.06178
Epoch [  7]: Loss 0.06212
Epoch [  7]: Loss 0.05779
Epoch [  7]: Loss 0.05512
Epoch [  7]: Loss 0.05789
Validation: Loss 0.05155 Accuracy 1.00000
Validation: Loss 0.05185 Accuracy 1.00000
Epoch [  8]: Loss 0.05049
Epoch [  8]: Loss 0.04799
Epoch [  8]: Loss 0.04885
Epoch [  8]: Loss 0.04573
Epoch [  8]: Loss 0.04329
Epoch [  8]: Loss 0.04224
Epoch [  8]: Loss 0.04190
Validation: Loss 0.03885 Accuracy 1.00000
Validation: Loss 0.03915 Accuracy 1.00000
Epoch [  9]: Loss 0.03635
Epoch [  9]: Loss 0.03723
Epoch [  9]: Loss 0.03617
Epoch [  9]: Loss 0.03516
Epoch [  9]: Loss 0.03346
Epoch [  9]: Loss 0.03240
Epoch [  9]: Loss 0.03211
Validation: Loss 0.02974 Accuracy 1.00000
Validation: Loss 0.03001 Accuracy 1.00000
Epoch [ 10]: Loss 0.02940
Epoch [ 10]: Loss 0.02917
Epoch [ 10]: Loss 0.02920
Epoch [ 10]: Loss 0.02618
Epoch [ 10]: Loss 0.02405
Epoch [ 10]: Loss 0.02518
Epoch [ 10]: Loss 0.02354
Validation: Loss 0.02340 Accuracy 1.00000
Validation: Loss 0.02364 Accuracy 1.00000
Epoch [ 11]: Loss 0.02363
Epoch [ 11]: Loss 0.02372
Epoch [ 11]: Loss 0.02133
Epoch [ 11]: Loss 0.02131
Epoch [ 11]: Loss 0.02107
Epoch [ 11]: Loss 0.01894
Epoch [ 11]: Loss 0.01889
Validation: Loss 0.01905 Accuracy 1.00000
Validation: Loss 0.01926 Accuracy 1.00000
Epoch [ 12]: Loss 0.01889
Epoch [ 12]: Loss 0.01839
Epoch [ 12]: Loss 0.01814
Epoch [ 12]: Loss 0.01767
Epoch [ 12]: Loss 0.01713
Epoch [ 12]: Loss 0.01707
Epoch [ 12]: Loss 0.01462
Validation: Loss 0.01603 Accuracy 1.00000
Validation: Loss 0.01620 Accuracy 1.00000
Epoch [ 13]: Loss 0.01619
Epoch [ 13]: Loss 0.01510
Epoch [ 13]: Loss 0.01530
Epoch [ 13]: Loss 0.01455
Epoch [ 13]: Loss 0.01503
Epoch [ 13]: Loss 0.01447
Epoch [ 13]: Loss 0.01449
Validation: Loss 0.01386 Accuracy 1.00000
Validation: Loss 0.01401 Accuracy 1.00000
Epoch [ 14]: Loss 0.01390
Epoch [ 14]: Loss 0.01433
Epoch [ 14]: Loss 0.01329
Epoch [ 14]: Loss 0.01214
Epoch [ 14]: Loss 0.01219
Epoch [ 14]: Loss 0.01304
Epoch [ 14]: Loss 0.01313
Validation: Loss 0.01222 Accuracy 1.00000
Validation: Loss 0.01236 Accuracy 1.00000
Epoch [ 15]: Loss 0.01164
Epoch [ 15]: Loss 0.01275
Epoch [ 15]: Loss 0.01127
Epoch [ 15]: Loss 0.01145
Epoch [ 15]: Loss 0.01140
Epoch [ 15]: Loss 0.01123
Epoch [ 15]: Loss 0.01236
Validation: Loss 0.01094 Accuracy 1.00000
Validation: Loss 0.01106 Accuracy 1.00000
Epoch [ 16]: Loss 0.01100
Epoch [ 16]: Loss 0.01044
Epoch [ 16]: Loss 0.01052
Epoch [ 16]: Loss 0.01071
Epoch [ 16]: Loss 0.00990
Epoch [ 16]: Loss 0.01020
Epoch [ 16]: Loss 0.01089
Validation: Loss 0.00990 Accuracy 1.00000
Validation: Loss 0.01001 Accuracy 1.00000
Epoch [ 17]: Loss 0.01016
Epoch [ 17]: Loss 0.00966
Epoch [ 17]: Loss 0.00927
Epoch [ 17]: Loss 0.00926
Epoch [ 17]: Loss 0.00994
Epoch [ 17]: Loss 0.00867
Epoch [ 17]: Loss 0.00987
Validation: Loss 0.00903 Accuracy 1.00000
Validation: Loss 0.00913 Accuracy 1.00000
Epoch [ 18]: Loss 0.00912
Epoch [ 18]: Loss 0.00869
Epoch [ 18]: Loss 0.00917
Epoch [ 18]: Loss 0.00820
Epoch [ 18]: Loss 0.00864
Epoch [ 18]: Loss 0.00806
Epoch [ 18]: Loss 0.00986
Validation: Loss 0.00829 Accuracy 1.00000
Validation: Loss 0.00839 Accuracy 1.00000
Epoch [ 19]: Loss 0.00830
Epoch [ 19]: Loss 0.00805
Epoch [ 19]: Loss 0.00860
Epoch [ 19]: Loss 0.00801
Epoch [ 19]: Loss 0.00756
Epoch [ 19]: Loss 0.00747
Epoch [ 19]: Loss 0.00817
Validation: Loss 0.00766 Accuracy 1.00000
Validation: Loss 0.00775 Accuracy 1.00000
Epoch [ 20]: Loss 0.00762
Epoch [ 20]: Loss 0.00754
Epoch [ 20]: Loss 0.00713
Epoch [ 20]: Loss 0.00765
Epoch [ 20]: Loss 0.00761
Epoch [ 20]: Loss 0.00705
Epoch [ 20]: Loss 0.00668
Validation: Loss 0.00710 Accuracy 1.00000
Validation: Loss 0.00719 Accuracy 1.00000
Epoch [ 21]: Loss 0.00707
Epoch [ 21]: Loss 0.00696
Epoch [ 21]: Loss 0.00652
Epoch [ 21]: Loss 0.00717
Epoch [ 21]: Loss 0.00695
Epoch [ 21]: Loss 0.00679
Epoch [ 21]: Loss 0.00625
Validation: Loss 0.00662 Accuracy 1.00000
Validation: Loss 0.00670 Accuracy 1.00000
Epoch [ 22]: Loss 0.00703
Epoch [ 22]: Loss 0.00655
Epoch [ 22]: Loss 0.00683
Epoch [ 22]: Loss 0.00631
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00576
Epoch [ 22]: Loss 0.00625
Validation: Loss 0.00620 Accuracy 1.00000
Validation: Loss 0.00627 Accuracy 1.00000
Epoch [ 23]: Loss 0.00604
Epoch [ 23]: Loss 0.00611
Epoch [ 23]: Loss 0.00606
Epoch [ 23]: Loss 0.00618
Epoch [ 23]: Loss 0.00604
Epoch [ 23]: Loss 0.00550
Epoch [ 23]: Loss 0.00676
Validation: Loss 0.00582 Accuracy 1.00000
Validation: Loss 0.00589 Accuracy 1.00000
Epoch [ 24]: Loss 0.00576
Epoch [ 24]: Loss 0.00529
Epoch [ 24]: Loss 0.00565
Epoch [ 24]: Loss 0.00575
Epoch [ 24]: Loss 0.00586
Epoch [ 24]: Loss 0.00579
Epoch [ 24]: Loss 0.00498
Validation: Loss 0.00547 Accuracy 1.00000
Validation: Loss 0.00554 Accuracy 1.00000
Epoch [ 25]: Loss 0.00532
Epoch [ 25]: Loss 0.00532
Epoch [ 25]: Loss 0.00537
Epoch [ 25]: Loss 0.00516
Epoch [ 25]: Loss 0.00534
Epoch [ 25]: Loss 0.00537
Epoch [ 25]: Loss 0.00567
Validation: Loss 0.00517 Accuracy 1.00000
Validation: Loss 0.00523 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56111
Epoch [  1]: Loss 0.51024
Epoch [  1]: Loss 0.47972
Epoch [  1]: Loss 0.45146
Epoch [  1]: Loss 0.43833
Epoch [  1]: Loss 0.39483
Epoch [  1]: Loss 0.38543
Validation: Loss 0.37798 Accuracy 1.00000
Validation: Loss 0.37733 Accuracy 1.00000
Epoch [  2]: Loss 0.38061
Epoch [  2]: Loss 0.36099
Epoch [  2]: Loss 0.34008
Epoch [  2]: Loss 0.30362
Epoch [  2]: Loss 0.30314
Epoch [  2]: Loss 0.28694
Epoch [  2]: Loss 0.25672
Validation: Loss 0.26610 Accuracy 1.00000
Validation: Loss 0.26558 Accuracy 1.00000
Epoch [  3]: Loss 0.25819
Epoch [  3]: Loss 0.25306
Epoch [  3]: Loss 0.23848
Epoch [  3]: Loss 0.21937
Epoch [  3]: Loss 0.21415
Epoch [  3]: Loss 0.19917
Epoch [  3]: Loss 0.19070
Validation: Loss 0.18673 Accuracy 1.00000
Validation: Loss 0.18637 Accuracy 1.00000
Epoch [  4]: Loss 0.18117
Epoch [  4]: Loss 0.17147
Epoch [  4]: Loss 0.16395
Epoch [  4]: Loss 0.15973
Epoch [  4]: Loss 0.15335
Epoch [  4]: Loss 0.14485
Epoch [  4]: Loss 0.14090
Validation: Loss 0.13398 Accuracy 1.00000
Validation: Loss 0.13369 Accuracy 1.00000
Epoch [  5]: Loss 0.13292
Epoch [  5]: Loss 0.12412
Epoch [  5]: Loss 0.11729
Epoch [  5]: Loss 0.11494
Epoch [  5]: Loss 0.10965
Epoch [  5]: Loss 0.10620
Epoch [  5]: Loss 0.10036
Validation: Loss 0.09812 Accuracy 1.00000
Validation: Loss 0.09785 Accuracy 1.00000
Epoch [  6]: Loss 0.09698
Epoch [  6]: Loss 0.09284
Epoch [  6]: Loss 0.08850
Epoch [  6]: Loss 0.08309
Epoch [  6]: Loss 0.07928
Epoch [  6]: Loss 0.07803
Epoch [  6]: Loss 0.07153
Validation: Loss 0.07301 Accuracy 1.00000
Validation: Loss 0.07274 Accuracy 1.00000
Epoch [  7]: Loss 0.06972
Epoch [  7]: Loss 0.06985
Epoch [  7]: Loss 0.06731
Epoch [  7]: Loss 0.06128
Epoch [  7]: Loss 0.05898
Epoch [  7]: Loss 0.05883
Epoch [  7]: Loss 0.05208
Validation: Loss 0.05495 Accuracy 1.00000
Validation: Loss 0.05467 Accuracy 1.00000
Epoch [  8]: Loss 0.05294
Epoch [  8]: Loss 0.05022
Epoch [  8]: Loss 0.05044
Epoch [  8]: Loss 0.04717
Epoch [  8]: Loss 0.04491
Epoch [  8]: Loss 0.04385
Epoch [  8]: Loss 0.03809
Validation: Loss 0.04167 Accuracy 1.00000
Validation: Loss 0.04142 Accuracy 1.00000
Epoch [  9]: Loss 0.04011
Epoch [  9]: Loss 0.03893
Epoch [  9]: Loss 0.03594
Epoch [  9]: Loss 0.03502
Epoch [  9]: Loss 0.03542
Epoch [  9]: Loss 0.03326
Epoch [  9]: Loss 0.02985
Validation: Loss 0.03205 Accuracy 1.00000
Validation: Loss 0.03183 Accuracy 1.00000
Epoch [ 10]: Loss 0.03157
Epoch [ 10]: Loss 0.03037
Epoch [ 10]: Loss 0.02796
Epoch [ 10]: Loss 0.02624
Epoch [ 10]: Loss 0.02717
Epoch [ 10]: Loss 0.02434
Epoch [ 10]: Loss 0.02706
Validation: Loss 0.02532 Accuracy 1.00000
Validation: Loss 0.02513 Accuracy 1.00000
Epoch [ 11]: Loss 0.02408
Epoch [ 11]: Loss 0.02160
Epoch [ 11]: Loss 0.02329
Epoch [ 11]: Loss 0.02210
Epoch [ 11]: Loss 0.02164
Epoch [ 11]: Loss 0.02085
Epoch [ 11]: Loss 0.02027
Validation: Loss 0.02064 Accuracy 1.00000
Validation: Loss 0.02048 Accuracy 1.00000
Epoch [ 12]: Loss 0.01858
Epoch [ 12]: Loss 0.01812
Epoch [ 12]: Loss 0.01884
Epoch [ 12]: Loss 0.01852
Epoch [ 12]: Loss 0.01762
Epoch [ 12]: Loss 0.01749
Epoch [ 12]: Loss 0.01906
Validation: Loss 0.01735 Accuracy 1.00000
Validation: Loss 0.01721 Accuracy 1.00000
Epoch [ 13]: Loss 0.01583
Epoch [ 13]: Loss 0.01730
Epoch [ 13]: Loss 0.01511
Epoch [ 13]: Loss 0.01524
Epoch [ 13]: Loss 0.01491
Epoch [ 13]: Loss 0.01520
Epoch [ 13]: Loss 0.01240
Validation: Loss 0.01494 Accuracy 1.00000
Validation: Loss 0.01482 Accuracy 1.00000
Epoch [ 14]: Loss 0.01439
Epoch [ 14]: Loss 0.01416
Epoch [ 14]: Loss 0.01269
Epoch [ 14]: Loss 0.01284
Epoch [ 14]: Loss 0.01385
Epoch [ 14]: Loss 0.01314
Epoch [ 14]: Loss 0.01162
Validation: Loss 0.01317 Accuracy 1.00000
Validation: Loss 0.01306 Accuracy 1.00000
Epoch [ 15]: Loss 0.01345
Epoch [ 15]: Loss 0.01244
Epoch [ 15]: Loss 0.01139
Epoch [ 15]: Loss 0.01141
Epoch [ 15]: Loss 0.01119
Epoch [ 15]: Loss 0.01195
Epoch [ 15]: Loss 0.01034
Validation: Loss 0.01179 Accuracy 1.00000
Validation: Loss 0.01168 Accuracy 1.00000
Epoch [ 16]: Loss 0.01060
Epoch [ 16]: Loss 0.01131
Epoch [ 16]: Loss 0.01151
Epoch [ 16]: Loss 0.01016
Epoch [ 16]: Loss 0.01028
Epoch [ 16]: Loss 0.01028
Epoch [ 16]: Loss 0.01062
Validation: Loss 0.01067 Accuracy 1.00000
Validation: Loss 0.01058 Accuracy 1.00000
Epoch [ 17]: Loss 0.01003
Epoch [ 17]: Loss 0.01000
Epoch [ 17]: Loss 0.00974
Epoch [ 17]: Loss 0.00961
Epoch [ 17]: Loss 0.00893
Epoch [ 17]: Loss 0.00973
Epoch [ 17]: Loss 0.01025
Validation: Loss 0.00974 Accuracy 1.00000
Validation: Loss 0.00965 Accuracy 1.00000
Epoch [ 18]: Loss 0.00883
Epoch [ 18]: Loss 0.00914
Epoch [ 18]: Loss 0.00863
Epoch [ 18]: Loss 0.00891
Epoch [ 18]: Loss 0.00902
Epoch [ 18]: Loss 0.00829
Epoch [ 18]: Loss 0.01028
Validation: Loss 0.00894 Accuracy 1.00000
Validation: Loss 0.00886 Accuracy 1.00000
Epoch [ 19]: Loss 0.00805
Epoch [ 19]: Loss 0.00792
Epoch [ 19]: Loss 0.00810
Epoch [ 19]: Loss 0.00855
Epoch [ 19]: Loss 0.00800
Epoch [ 19]: Loss 0.00826
Epoch [ 19]: Loss 0.00817
Validation: Loss 0.00825 Accuracy 1.00000
Validation: Loss 0.00818 Accuracy 1.00000
Epoch [ 20]: Loss 0.00748
Epoch [ 20]: Loss 0.00837
Epoch [ 20]: Loss 0.00771
Epoch [ 20]: Loss 0.00705
Epoch [ 20]: Loss 0.00757
Epoch [ 20]: Loss 0.00713
Epoch [ 20]: Loss 0.00714
Validation: Loss 0.00765 Accuracy 1.00000
Validation: Loss 0.00758 Accuracy 1.00000
Epoch [ 21]: Loss 0.00697
Epoch [ 21]: Loss 0.00710
Epoch [ 21]: Loss 0.00685
Epoch [ 21]: Loss 0.00739
Epoch [ 21]: Loss 0.00724
Epoch [ 21]: Loss 0.00641
Epoch [ 21]: Loss 0.00705
Validation: Loss 0.00713 Accuracy 1.00000
Validation: Loss 0.00706 Accuracy 1.00000
Epoch [ 22]: Loss 0.00684
Epoch [ 22]: Loss 0.00664
Epoch [ 22]: Loss 0.00648
Epoch [ 22]: Loss 0.00699
Epoch [ 22]: Loss 0.00626
Epoch [ 22]: Loss 0.00630
Epoch [ 22]: Loss 0.00512
Validation: Loss 0.00667 Accuracy 1.00000
Validation: Loss 0.00660 Accuracy 1.00000
Epoch [ 23]: Loss 0.00628
Epoch [ 23]: Loss 0.00616
Epoch [ 23]: Loss 0.00632
Epoch [ 23]: Loss 0.00591
Epoch [ 23]: Loss 0.00616
Epoch [ 23]: Loss 0.00595
Epoch [ 23]: Loss 0.00558
Validation: Loss 0.00626 Accuracy 1.00000
Validation: Loss 0.00620 Accuracy 1.00000
Epoch [ 24]: Loss 0.00590
Epoch [ 24]: Loss 0.00577
Epoch [ 24]: Loss 0.00556
Epoch [ 24]: Loss 0.00589
Epoch [ 24]: Loss 0.00560
Epoch [ 24]: Loss 0.00554
Epoch [ 24]: Loss 0.00636
Validation: Loss 0.00589 Accuracy 1.00000
Validation: Loss 0.00583 Accuracy 1.00000
Epoch [ 25]: Loss 0.00536
Epoch [ 25]: Loss 0.00554
Epoch [ 25]: Loss 0.00541
Epoch [ 25]: Loss 0.00570
Epoch [ 25]: Loss 0.00522
Epoch [ 25]: Loss 0.00520
Epoch [ 25]: Loss 0.00533
Validation: Loss 0.00556 Accuracy 1.00000
Validation: Loss 0.00550 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself, only the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
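The loaded parameters and states can then be used for inference. A hypothetical sketch (assuming SpiralClassifier from above is defined):

```julia
using Lux

model = SpiralClassifier(2, 8, 1)
st_inference = Lux.testmode(st_trained)  # switch states to test mode
x = randn(Float32, 2, 50, 4)             # dummy batch of 4 sequences
y_pred, _ = model(x, ps_trained, st_inference)
is_anticlockwise = y_pred .> 0.5f0       # threshold the sigmoid outputs
```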

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 4 default, 0 interactive, 2 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 4
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.2

CUDA libraries: 
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+555.42.2

Julia packages: 
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.4
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/DaGeB/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.