Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
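
As a quick sanity check (a hedged sketch, not part of the original tutorial, assuming MLUtils is installed and the function above has been run), we can inspect the shape of the first training batch:

```julia
# Sketch: peek at the first batch produced by `get_dataloaders` above.
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # expected (2, 50, 128): features × sequence_length × batchsize
size(y)  # expected (128,): one binary label per sequence
```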

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We can use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the forward pass below, but let's still write it out for the sake of illustration.

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
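
For reference, here is a hedged sketch of the Recurrence-based alternative mentioned above; `SpiralClassifierRecurrence` is a hypothetical name, and the sketch assumes Recurrence's default behavior of returning only the final output of the sequence:

```julia
# Sketch: the same classifier built from stock Lux blocks instead of a custom
# container layer. `Recurrence` iterates the cell over the time dimension and
# (by default) returns only the last output, which then feeds the Dense head.
# The trailing `vec` mirrors the `vec(y)` in the custom forward pass above.
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end
```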

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
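
As a quick, pure-Julia sanity check (the hand-made predictions are ours, not part of the original tutorial), the helpers behave as expected on a tiny example:

```julia
using Statistics: mean

# The helper definitions are repeated here so this snippet is self-contained.
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)  # guards against 0 * log(0)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))  # avoid log(0)
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = Float32[0.9, 0.8, 0.2, 0.1]
y_true = Float32[1.0, 1.0, 0.0, 0.0]
accuracy(y_pred, y_true)            # 1.0: every prediction on the right side of 0.5
binarycrossentropy(y_pred, y_true)  # ≈ 0.164: small but nonzero loss
```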

Training the Model

julia
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56138
Epoch [  1]: Loss 0.51543
Epoch [  1]: Loss 0.47491
Epoch [  1]: Loss 0.45421
Epoch [  1]: Loss 0.42064
Epoch [  1]: Loss 0.40526
Epoch [  1]: Loss 0.39037
Validation: Loss 0.37900 Accuracy 1.00000
Validation: Loss 0.37851 Accuracy 1.00000
Epoch [  2]: Loss 0.37573
Epoch [  2]: Loss 0.34884
Epoch [  2]: Loss 0.32759
Epoch [  2]: Loss 0.32150
Epoch [  2]: Loss 0.30367
Epoch [  2]: Loss 0.28439
Epoch [  2]: Loss 0.28083
Validation: Loss 0.26544 Accuracy 1.00000
Validation: Loss 0.26492 Accuracy 1.00000
Epoch [  3]: Loss 0.25931
Epoch [  3]: Loss 0.25155
Epoch [  3]: Loss 0.23314
Epoch [  3]: Loss 0.22306
Epoch [  3]: Loss 0.20806
Epoch [  3]: Loss 0.19905
Epoch [  3]: Loss 0.19053
Validation: Loss 0.18445 Accuracy 1.00000
Validation: Loss 0.18396 Accuracy 1.00000
Epoch [  4]: Loss 0.17847
Epoch [  4]: Loss 0.17285
Epoch [  4]: Loss 0.16358
Epoch [  4]: Loss 0.15646
Epoch [  4]: Loss 0.15239
Epoch [  4]: Loss 0.14299
Epoch [  4]: Loss 0.13152
Validation: Loss 0.13168 Accuracy 1.00000
Validation: Loss 0.13129 Accuracy 1.00000
Epoch [  5]: Loss 0.13088
Epoch [  5]: Loss 0.12280
Epoch [  5]: Loss 0.11662
Epoch [  5]: Loss 0.11423
Epoch [  5]: Loss 0.10691
Epoch [  5]: Loss 0.10404
Epoch [  5]: Loss 0.09745
Validation: Loss 0.09632 Accuracy 1.00000
Validation: Loss 0.09601 Accuracy 1.00000
Epoch [  6]: Loss 0.09393
Epoch [  6]: Loss 0.08927
Epoch [  6]: Loss 0.08688
Epoch [  6]: Loss 0.08293
Epoch [  6]: Loss 0.07885
Epoch [  6]: Loss 0.07783
Epoch [  6]: Loss 0.06830
Validation: Loss 0.07166 Accuracy 1.00000
Validation: Loss 0.07145 Accuracy 1.00000
Epoch [  7]: Loss 0.06980
Epoch [  7]: Loss 0.06711
Epoch [  7]: Loss 0.06292
Epoch [  7]: Loss 0.06126
Epoch [  7]: Loss 0.05942
Epoch [  7]: Loss 0.05655
Epoch [  7]: Loss 0.05489
Validation: Loss 0.05396 Accuracy 1.00000
Validation: Loss 0.05381 Accuracy 1.00000
Epoch [  8]: Loss 0.05193
Epoch [  8]: Loss 0.04947
Epoch [  8]: Loss 0.04609
Epoch [  8]: Loss 0.04554
Epoch [  8]: Loss 0.04507
Epoch [  8]: Loss 0.04272
Epoch [  8]: Loss 0.04699
Validation: Loss 0.04089 Accuracy 1.00000
Validation: Loss 0.04080 Accuracy 1.00000
Epoch [  9]: Loss 0.03963
Epoch [  9]: Loss 0.03453
Epoch [  9]: Loss 0.03825
Epoch [  9]: Loss 0.03493
Epoch [  9]: Loss 0.03426
Epoch [  9]: Loss 0.03224
Epoch [  9]: Loss 0.02861
Validation: Loss 0.03137 Accuracy 1.00000
Validation: Loss 0.03131 Accuracy 1.00000
Epoch [ 10]: Loss 0.02911
Epoch [ 10]: Loss 0.02893
Epoch [ 10]: Loss 0.02840
Epoch [ 10]: Loss 0.02861
Epoch [ 10]: Loss 0.02599
Epoch [ 10]: Loss 0.02336
Epoch [ 10]: Loss 0.02314
Validation: Loss 0.02476 Accuracy 1.00000
Validation: Loss 0.02471 Accuracy 1.00000
Epoch [ 11]: Loss 0.02262
Epoch [ 11]: Loss 0.02253
Epoch [ 11]: Loss 0.02262
Epoch [ 11]: Loss 0.02113
Epoch [ 11]: Loss 0.02130
Epoch [ 11]: Loss 0.01967
Epoch [ 11]: Loss 0.02151
Validation: Loss 0.02021 Accuracy 1.00000
Validation: Loss 0.02017 Accuracy 1.00000
Epoch [ 12]: Loss 0.01902
Epoch [ 12]: Loss 0.01905
Epoch [ 12]: Loss 0.01843
Epoch [ 12]: Loss 0.01761
Epoch [ 12]: Loss 0.01761
Epoch [ 12]: Loss 0.01591
Epoch [ 12]: Loss 0.01515
Validation: Loss 0.01699 Accuracy 1.00000
Validation: Loss 0.01695 Accuracy 1.00000
Epoch [ 13]: Loss 0.01655
Epoch [ 13]: Loss 0.01613
Epoch [ 13]: Loss 0.01452
Epoch [ 13]: Loss 0.01515
Epoch [ 13]: Loss 0.01410
Epoch [ 13]: Loss 0.01455
Epoch [ 13]: Loss 0.01383
Validation: Loss 0.01466 Accuracy 1.00000
Validation: Loss 0.01463 Accuracy 1.00000
Epoch [ 14]: Loss 0.01398
Epoch [ 14]: Loss 0.01429
Epoch [ 14]: Loss 0.01315
Epoch [ 14]: Loss 0.01254
Epoch [ 14]: Loss 0.01296
Epoch [ 14]: Loss 0.01228
Epoch [ 14]: Loss 0.01180
Validation: Loss 0.01293 Accuracy 1.00000
Validation: Loss 0.01290 Accuracy 1.00000
Epoch [ 15]: Loss 0.01179
Epoch [ 15]: Loss 0.01202
Epoch [ 15]: Loss 0.01100
Epoch [ 15]: Loss 0.01172
Epoch [ 15]: Loss 0.01255
Epoch [ 15]: Loss 0.01077
Epoch [ 15]: Loss 0.01150
Validation: Loss 0.01157 Accuracy 1.00000
Validation: Loss 0.01154 Accuracy 1.00000
Epoch [ 16]: Loss 0.01084
Epoch [ 16]: Loss 0.01053
Epoch [ 16]: Loss 0.01013
Epoch [ 16]: Loss 0.01080
Epoch [ 16]: Loss 0.01059
Epoch [ 16]: Loss 0.01016
Epoch [ 16]: Loss 0.00916
Validation: Loss 0.01047 Accuracy 1.00000
Validation: Loss 0.01045 Accuracy 1.00000
Epoch [ 17]: Loss 0.00997
Epoch [ 17]: Loss 0.00974
Epoch [ 17]: Loss 0.00969
Epoch [ 17]: Loss 0.00985
Epoch [ 17]: Loss 0.00901
Epoch [ 17]: Loss 0.00868
Epoch [ 17]: Loss 0.00951
Validation: Loss 0.00955 Accuracy 1.00000
Validation: Loss 0.00953 Accuracy 1.00000
Epoch [ 18]: Loss 0.00897
Epoch [ 18]: Loss 0.00904
Epoch [ 18]: Loss 0.00892
Epoch [ 18]: Loss 0.00839
Epoch [ 18]: Loss 0.00817
Epoch [ 18]: Loss 0.00860
Epoch [ 18]: Loss 0.00848
Validation: Loss 0.00877 Accuracy 1.00000
Validation: Loss 0.00876 Accuracy 1.00000
Epoch [ 19]: Loss 0.00820
Epoch [ 19]: Loss 0.00824
Epoch [ 19]: Loss 0.00772
Epoch [ 19]: Loss 0.00835
Epoch [ 19]: Loss 0.00793
Epoch [ 19]: Loss 0.00715
Epoch [ 19]: Loss 0.00924
Validation: Loss 0.00811 Accuracy 1.00000
Validation: Loss 0.00809 Accuracy 1.00000
Epoch [ 20]: Loss 0.00744
Epoch [ 20]: Loss 0.00761
Epoch [ 20]: Loss 0.00777
Epoch [ 20]: Loss 0.00713
Epoch [ 20]: Loss 0.00765
Epoch [ 20]: Loss 0.00685
Epoch [ 20]: Loss 0.00683
Validation: Loss 0.00752 Accuracy 1.00000
Validation: Loss 0.00750 Accuracy 1.00000
Epoch [ 21]: Loss 0.00664
Epoch [ 21]: Loss 0.00713
Epoch [ 21]: Loss 0.00683
Epoch [ 21]: Loss 0.00677
Epoch [ 21]: Loss 0.00675
Epoch [ 21]: Loss 0.00716
Epoch [ 21]: Loss 0.00630
Validation: Loss 0.00701 Accuracy 1.00000
Validation: Loss 0.00700 Accuracy 1.00000
Epoch [ 22]: Loss 0.00614
Epoch [ 22]: Loss 0.00659
Epoch [ 22]: Loss 0.00678
Epoch [ 22]: Loss 0.00637
Epoch [ 22]: Loss 0.00632
Epoch [ 22]: Loss 0.00608
Epoch [ 22]: Loss 0.00687
Validation: Loss 0.00655 Accuracy 1.00000
Validation: Loss 0.00654 Accuracy 1.00000
Epoch [ 23]: Loss 0.00638
Epoch [ 23]: Loss 0.00595
Epoch [ 23]: Loss 0.00559
Epoch [ 23]: Loss 0.00624
Epoch [ 23]: Loss 0.00601
Epoch [ 23]: Loss 0.00594
Epoch [ 23]: Loss 0.00533
Validation: Loss 0.00615 Accuracy 1.00000
Validation: Loss 0.00614 Accuracy 1.00000
Epoch [ 24]: Loss 0.00574
Epoch [ 24]: Loss 0.00573
Epoch [ 24]: Loss 0.00557
Epoch [ 24]: Loss 0.00536
Epoch [ 24]: Loss 0.00622
Epoch [ 24]: Loss 0.00524
Epoch [ 24]: Loss 0.00519
Validation: Loss 0.00579 Accuracy 1.00000
Validation: Loss 0.00577 Accuracy 1.00000
Epoch [ 25]: Loss 0.00547
Epoch [ 25]: Loss 0.00519
Epoch [ 25]: Loss 0.00507
Epoch [ 25]: Loss 0.00544
Epoch [ 25]: Loss 0.00568
Epoch [ 25]: Loss 0.00508
Epoch [ 25]: Loss 0.00471
Validation: Loss 0.00546 Accuracy 1.00000
Validation: Loss 0.00545 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56152
Epoch [  1]: Loss 0.51135
Epoch [  1]: Loss 0.47369
Epoch [  1]: Loss 0.45134
Epoch [  1]: Loss 0.41875
Epoch [  1]: Loss 0.41115
Epoch [  1]: Loss 0.38459
Validation: Loss 0.38126 Accuracy 1.00000
Validation: Loss 0.37176 Accuracy 1.00000
Epoch [  2]: Loss 0.36216
Epoch [  2]: Loss 0.34810
Epoch [  2]: Loss 0.32734
Epoch [  2]: Loss 0.31887
Epoch [  2]: Loss 0.30620
Epoch [  2]: Loss 0.29313
Epoch [  2]: Loss 0.26615
Validation: Loss 0.26659 Accuracy 1.00000
Validation: Loss 0.26028 Accuracy 1.00000
Epoch [  3]: Loss 0.25885
Epoch [  3]: Loss 0.24435
Epoch [  3]: Loss 0.23333
Epoch [  3]: Loss 0.22244
Epoch [  3]: Loss 0.20810
Epoch [  3]: Loss 0.19997
Epoch [  3]: Loss 0.18157
Validation: Loss 0.18451 Accuracy 1.00000
Validation: Loss 0.18101 Accuracy 1.00000
Epoch [  4]: Loss 0.18049
Epoch [  4]: Loss 0.16932
Epoch [  4]: Loss 0.16371
Epoch [  4]: Loss 0.15550
Epoch [  4]: Loss 0.14694
Epoch [  4]: Loss 0.14089
Epoch [  4]: Loss 0.13504
Validation: Loss 0.13142 Accuracy 1.00000
Validation: Loss 0.12891 Accuracy 1.00000
Epoch [  5]: Loss 0.12704
Epoch [  5]: Loss 0.12201
Epoch [  5]: Loss 0.11937
Epoch [  5]: Loss 0.11059
Epoch [  5]: Loss 0.10553
Epoch [  5]: Loss 0.10194
Epoch [  5]: Loss 0.09634
Validation: Loss 0.09584 Accuracy 1.00000
Validation: Loss 0.09368 Accuracy 1.00000
Epoch [  6]: Loss 0.09143
Epoch [  6]: Loss 0.09040
Epoch [  6]: Loss 0.08453
Epoch [  6]: Loss 0.08110
Epoch [  6]: Loss 0.07696
Epoch [  6]: Loss 0.07703
Epoch [  6]: Loss 0.06825
Validation: Loss 0.07126 Accuracy 1.00000
Validation: Loss 0.06924 Accuracy 1.00000
Epoch [  7]: Loss 0.06977
Epoch [  7]: Loss 0.06600
Epoch [  7]: Loss 0.06306
Epoch [  7]: Loss 0.05917
Epoch [  7]: Loss 0.05780
Epoch [  7]: Loss 0.05524
Epoch [  7]: Loss 0.05186
Validation: Loss 0.05365 Accuracy 1.00000
Validation: Loss 0.05180 Accuracy 1.00000
Epoch [  8]: Loss 0.05150
Epoch [  8]: Loss 0.04984
Epoch [  8]: Loss 0.04647
Epoch [  8]: Loss 0.04526
Epoch [  8]: Loss 0.04193
Epoch [  8]: Loss 0.04218
Epoch [  8]: Loss 0.04029
Validation: Loss 0.04074 Accuracy 1.00000
Validation: Loss 0.03907 Accuracy 1.00000
Epoch [  9]: Loss 0.03702
Epoch [  9]: Loss 0.03895
Epoch [  9]: Loss 0.03590
Epoch [  9]: Loss 0.03244
Epoch [  9]: Loss 0.03297
Epoch [  9]: Loss 0.03232
Epoch [  9]: Loss 0.02928
Validation: Loss 0.03139 Accuracy 1.00000
Validation: Loss 0.02991 Accuracy 1.00000
Epoch [ 10]: Loss 0.02910
Epoch [ 10]: Loss 0.02884
Epoch [ 10]: Loss 0.02706
Epoch [ 10]: Loss 0.02716
Epoch [ 10]: Loss 0.02491
Epoch [ 10]: Loss 0.02373
Epoch [ 10]: Loss 0.02521
Validation: Loss 0.02483 Accuracy 1.00000
Validation: Loss 0.02354 Accuracy 1.00000
Epoch [ 11]: Loss 0.02133
Epoch [ 11]: Loss 0.02185
Epoch [ 11]: Loss 0.02317
Epoch [ 11]: Loss 0.02072
Epoch [ 11]: Loss 0.02003
Epoch [ 11]: Loss 0.02067
Epoch [ 11]: Loss 0.01993
Validation: Loss 0.02025 Accuracy 1.00000
Validation: Loss 0.01916 Accuracy 1.00000
Epoch [ 12]: Loss 0.01973
Epoch [ 12]: Loss 0.01756
Epoch [ 12]: Loss 0.01650
Epoch [ 12]: Loss 0.01687
Epoch [ 12]: Loss 0.01800
Epoch [ 12]: Loss 0.01595
Epoch [ 12]: Loss 0.01844
Validation: Loss 0.01700 Accuracy 1.00000
Validation: Loss 0.01608 Accuracy 1.00000
Epoch [ 13]: Loss 0.01633
Epoch [ 13]: Loss 0.01592
Epoch [ 13]: Loss 0.01486
Epoch [ 13]: Loss 0.01465
Epoch [ 13]: Loss 0.01384
Epoch [ 13]: Loss 0.01358
Epoch [ 13]: Loss 0.01391
Validation: Loss 0.01464 Accuracy 1.00000
Validation: Loss 0.01384 Accuracy 1.00000
Epoch [ 14]: Loss 0.01312
Epoch [ 14]: Loss 0.01257
Epoch [ 14]: Loss 0.01316
Epoch [ 14]: Loss 0.01301
Epoch [ 14]: Loss 0.01301
Epoch [ 14]: Loss 0.01267
Epoch [ 14]: Loss 0.01121
Validation: Loss 0.01290 Accuracy 1.00000
Validation: Loss 0.01219 Accuracy 1.00000
Epoch [ 15]: Loss 0.01196
Epoch [ 15]: Loss 0.01134
Epoch [ 15]: Loss 0.01121
Epoch [ 15]: Loss 0.01131
Epoch [ 15]: Loss 0.01209
Epoch [ 15]: Loss 0.01063
Epoch [ 15]: Loss 0.01070
Validation: Loss 0.01154 Accuracy 1.00000
Validation: Loss 0.01090 Accuracy 1.00000
Epoch [ 16]: Loss 0.01107
Epoch [ 16]: Loss 0.01056
Epoch [ 16]: Loss 0.01046
Epoch [ 16]: Loss 0.01053
Epoch [ 16]: Loss 0.00978
Epoch [ 16]: Loss 0.00929
Epoch [ 16]: Loss 0.00933
Validation: Loss 0.01044 Accuracy 1.00000
Validation: Loss 0.00986 Accuracy 1.00000
Epoch [ 17]: Loss 0.00976
Epoch [ 17]: Loss 0.00938
Epoch [ 17]: Loss 0.00957
Epoch [ 17]: Loss 0.00962
Epoch [ 17]: Loss 0.00862
Epoch [ 17]: Loss 0.00905
Epoch [ 17]: Loss 0.00822
Validation: Loss 0.00954 Accuracy 1.00000
Validation: Loss 0.00900 Accuracy 1.00000
Epoch [ 18]: Loss 0.00881
Epoch [ 18]: Loss 0.00902
Epoch [ 18]: Loss 0.00868
Epoch [ 18]: Loss 0.00802
Epoch [ 18]: Loss 0.00821
Epoch [ 18]: Loss 0.00812
Epoch [ 18]: Loss 0.00905
Validation: Loss 0.00877 Accuracy 1.00000
Validation: Loss 0.00827 Accuracy 1.00000
Epoch [ 19]: Loss 0.00797
Epoch [ 19]: Loss 0.00765
Epoch [ 19]: Loss 0.00797
Epoch [ 19]: Loss 0.00767
Epoch [ 19]: Loss 0.00793
Epoch [ 19]: Loss 0.00774
Epoch [ 19]: Loss 0.00787
Validation: Loss 0.00811 Accuracy 1.00000
Validation: Loss 0.00764 Accuracy 1.00000
Epoch [ 20]: Loss 0.00791
Epoch [ 20]: Loss 0.00724
Epoch [ 20]: Loss 0.00698
Epoch [ 20]: Loss 0.00682
Epoch [ 20]: Loss 0.00696
Epoch [ 20]: Loss 0.00772
Epoch [ 20]: Loss 0.00647
Validation: Loss 0.00753 Accuracy 1.00000
Validation: Loss 0.00709 Accuracy 1.00000
Epoch [ 21]: Loss 0.00688
Epoch [ 21]: Loss 0.00665
Epoch [ 21]: Loss 0.00666
Epoch [ 21]: Loss 0.00658
Epoch [ 21]: Loss 0.00700
Epoch [ 21]: Loss 0.00676
Epoch [ 21]: Loss 0.00602
Validation: Loss 0.00702 Accuracy 1.00000
Validation: Loss 0.00661 Accuracy 1.00000
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00651
Epoch [ 22]: Loss 0.00678
Epoch [ 22]: Loss 0.00624
Epoch [ 22]: Loss 0.00624
Epoch [ 22]: Loss 0.00586
Epoch [ 22]: Loss 0.00597
Validation: Loss 0.00657 Accuracy 1.00000
Validation: Loss 0.00619 Accuracy 1.00000
Epoch [ 23]: Loss 0.00637
Epoch [ 23]: Loss 0.00614
Epoch [ 23]: Loss 0.00582
Epoch [ 23]: Loss 0.00537
Epoch [ 23]: Loss 0.00601
Epoch [ 23]: Loss 0.00569
Epoch [ 23]: Loss 0.00548
Validation: Loss 0.00617 Accuracy 1.00000
Validation: Loss 0.00580 Accuracy 1.00000
Epoch [ 24]: Loss 0.00521
Epoch [ 24]: Loss 0.00552
Epoch [ 24]: Loss 0.00509
Epoch [ 24]: Loss 0.00596
Epoch [ 24]: Loss 0.00569
Epoch [ 24]: Loss 0.00558
Epoch [ 24]: Loss 0.00592
Validation: Loss 0.00581 Accuracy 1.00000
Validation: Loss 0.00546 Accuracy 1.00000
Epoch [ 25]: Loss 0.00531
Epoch [ 25]: Loss 0.00511
Epoch [ 25]: Loss 0.00546
Epoch [ 25]: Loss 0.00522
Epoch [ 25]: Loss 0.00511
Epoch [ 25]: Loss 0.00509
Epoch [ 25]: Loss 0.00498
Validation: Loss 0.00548 Accuracy 1.00000
Validation: Loss 0.00515 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself, but only the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model back.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
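
Once loaded, the parameters and states can be plugged back into a freshly constructed model. This is a hedged sketch (the dummy input is ours, not part of the tutorial), assuming the layer sizes match those used in training:

```julia
# Sketch: rebuild the model structure and run inference with the loaded
# parameters and states.
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 1)  # one dummy sequence: 2 features × 50 steps × batch of 1
y_pred, _ = model(x, ps_trained, Lux.testmode(st_trained))
# `y_pred` is a 1-element vector with the predicted probability for the sequence
```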

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 550.54.15, originally for CUDA 12.4

CUDA libraries: 
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/DaGeB/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.