
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
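Before moving on, we can peek at one batch to confirm the layout described above (a quick sketch, assuming the function defined above is in scope):

```julia
# Draw a single batch from the training loader and inspect its shape
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # expected: (2, 50, 128) — features × sequence_length × batchsize
size(y)  # expected: (128,) — one binary label per sequence
```

With batchsize=128 and 800 training samples there are seven minibatches per epoch, which matches the seven training loss lines printed per epoch later in the tutorial.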

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; rather, we will use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model manually, but we will write it out here to illustrate the API.
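For reference, a stock-blocks version along those lines could be sketched as follows (an illustration only; `SpiralClassifierRecurrence` is our own name, and it relies on `Recurrence` returning the output at the final time step, its default behavior):

```julia
# Hypothetical alternative: the same classifier built from stock Lux blocks
function SpiralClassifierRecurrence(in_dims, hidden_dims, out_dims)
    return Chain(
        Recurrence(LSTMCell(in_dims => hidden_dims)),  # unroll over the time dimension
        Dense(hidden_dims => out_dims, sigmoid),       # classifier head
        vec)                                           # flatten 1 × batch to a vector
end
```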

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
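To make the call signature concrete, here is a small sketch of initializing and running the model on dummy data (shapes follow the dataset section):

```julia
model = SpiralClassifier(2, 8, 1)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)      # parameters/states come from the container fieldnames
x = rand(rng, Float32, 2, 50, 4)    # 4 dummy sequences of length 50
y, st_new = model(x, ps, st)
size(y)                             # (4,) — one probability per sequence
```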

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
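For completeness, here is what the more stable logit-space loss alluded to above might look like. This is our own sketch (the hypothetical `logitbinarycrossentropy` helper would operate on pre-sigmoid outputs, so the model head would drop the sigmoid):

```julia
using Statistics

# Stable binary cross entropy on logits z (pre-sigmoid), using the identity
# -log σ(z) = max(-z, 0) + log1p(exp(-abs(z)))
function logitbinarycrossentropy(z, y_true)
    return mean(@. (1 - y_true) * z + max(-z, 0) + log1p(exp(-abs(z))))
end

# Sanity check against the naive sigmoid + log formulation
σ(z) = 1 / (1 + exp(-z))
z = [-2.0, 0.0, 3.0]
y = [0.0, 1.0, 1.0]
naive = mean(@. -y * log(σ(z)) - (1 - y) * log(1 - σ(z)))
logitbinarycrossentropy(z, y) ≈ naive  # true
```

The `log1p(exp(-abs(z)))` form never exponentiates a large positive number, so it avoids the overflow that makes the naive formulation blow up for large-magnitude logits.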

Training the Model

julia
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56061
Epoch [  1]: Loss 0.51041
Epoch [  1]: Loss 0.47761
Epoch [  1]: Loss 0.44978
Epoch [  1]: Loss 0.43704
Epoch [  1]: Loss 0.40352
Epoch [  1]: Loss 0.40581
Validation: Loss 0.36414 Accuracy 1.00000
Validation: Loss 0.37612 Accuracy 1.00000
Epoch [  2]: Loss 0.36597
Epoch [  2]: Loss 0.35769
Epoch [  2]: Loss 0.33075
Epoch [  2]: Loss 0.32168
Epoch [  2]: Loss 0.30561
Epoch [  2]: Loss 0.28804
Epoch [  2]: Loss 0.28112
Validation: Loss 0.25651 Accuracy 1.00000
Validation: Loss 0.26338 Accuracy 1.00000
Epoch [  3]: Loss 0.25853
Epoch [  3]: Loss 0.25129
Epoch [  3]: Loss 0.22944
Epoch [  3]: Loss 0.22397
Epoch [  3]: Loss 0.21308
Epoch [  3]: Loss 0.20224
Epoch [  3]: Loss 0.19362
Validation: Loss 0.18085 Accuracy 1.00000
Validation: Loss 0.18404 Accuracy 1.00000
Epoch [  4]: Loss 0.18384
Epoch [  4]: Loss 0.17253
Epoch [  4]: Loss 0.16611
Epoch [  4]: Loss 0.15718
Epoch [  4]: Loss 0.15150
Epoch [  4]: Loss 0.14417
Epoch [  4]: Loss 0.13700
Validation: Loss 0.13015 Accuracy 1.00000
Validation: Loss 0.13211 Accuracy 1.00000
Epoch [  5]: Loss 0.13227
Epoch [  5]: Loss 0.12463
Epoch [  5]: Loss 0.11945
Epoch [  5]: Loss 0.11443
Epoch [  5]: Loss 0.10882
Epoch [  5]: Loss 0.10574
Epoch [  5]: Loss 0.10210
Validation: Loss 0.09490 Accuracy 1.00000
Validation: Loss 0.09691 Accuracy 1.00000
Epoch [  6]: Loss 0.09616
Epoch [  6]: Loss 0.09421
Epoch [  6]: Loss 0.08799
Epoch [  6]: Loss 0.08353
Epoch [  6]: Loss 0.08042
Epoch [  6]: Loss 0.07685
Epoch [  6]: Loss 0.07199
Validation: Loss 0.06992 Accuracy 1.00000
Validation: Loss 0.07210 Accuracy 1.00000
Epoch [  7]: Loss 0.07166
Epoch [  7]: Loss 0.06775
Epoch [  7]: Loss 0.06664
Epoch [  7]: Loss 0.06283
Epoch [  7]: Loss 0.05973
Epoch [  7]: Loss 0.05722
Epoch [  7]: Loss 0.05451
Validation: Loss 0.05192 Accuracy 1.00000
Validation: Loss 0.05419 Accuracy 1.00000
Epoch [  8]: Loss 0.05500
Epoch [  8]: Loss 0.04992
Epoch [  8]: Loss 0.05018
Epoch [  8]: Loss 0.04840
Epoch [  8]: Loss 0.04243
Epoch [  8]: Loss 0.04277
Epoch [  8]: Loss 0.04320
Validation: Loss 0.03883 Accuracy 1.00000
Validation: Loss 0.04100 Accuracy 1.00000
Epoch [  9]: Loss 0.03970
Epoch [  9]: Loss 0.03795
Epoch [  9]: Loss 0.03644
Epoch [  9]: Loss 0.03531
Epoch [  9]: Loss 0.03454
Epoch [  9]: Loss 0.03269
Epoch [  9]: Loss 0.03880
Validation: Loss 0.02946 Accuracy 1.00000
Validation: Loss 0.03146 Accuracy 1.00000
Epoch [ 10]: Loss 0.03066
Epoch [ 10]: Loss 0.03000
Epoch [ 10]: Loss 0.02984
Epoch [ 10]: Loss 0.02857
Epoch [ 10]: Loss 0.02536
Epoch [ 10]: Loss 0.02427
Epoch [ 10]: Loss 0.02346
Validation: Loss 0.02302 Accuracy 1.00000
Validation: Loss 0.02475 Accuracy 1.00000
Epoch [ 11]: Loss 0.02440
Epoch [ 11]: Loss 0.02317
Epoch [ 11]: Loss 0.02321
Epoch [ 11]: Loss 0.02141
Epoch [ 11]: Loss 0.01975
Epoch [ 11]: Loss 0.02140
Epoch [ 11]: Loss 0.02109
Validation: Loss 0.01865 Accuracy 1.00000
Validation: Loss 0.02016 Accuracy 1.00000
Epoch [ 12]: Loss 0.01996
Epoch [ 12]: Loss 0.01858
Epoch [ 12]: Loss 0.01899
Epoch [ 12]: Loss 0.01830
Epoch [ 12]: Loss 0.01788
Epoch [ 12]: Loss 0.01578
Epoch [ 12]: Loss 0.01818
Validation: Loss 0.01561 Accuracy 1.00000
Validation: Loss 0.01692 Accuracy 1.00000
Epoch [ 13]: Loss 0.01629
Epoch [ 13]: Loss 0.01617
Epoch [ 13]: Loss 0.01607
Epoch [ 13]: Loss 0.01503
Epoch [ 13]: Loss 0.01498
Epoch [ 13]: Loss 0.01462
Epoch [ 13]: Loss 0.01388
Validation: Loss 0.01344 Accuracy 1.00000
Validation: Loss 0.01459 Accuracy 1.00000
Epoch [ 14]: Loss 0.01425
Epoch [ 14]: Loss 0.01302
Epoch [ 14]: Loss 0.01451
Epoch [ 14]: Loss 0.01301
Epoch [ 14]: Loss 0.01293
Epoch [ 14]: Loss 0.01279
Epoch [ 14]: Loss 0.01372
Validation: Loss 0.01182 Accuracy 1.00000
Validation: Loss 0.01286 Accuracy 1.00000
Epoch [ 15]: Loss 0.01279
Epoch [ 15]: Loss 0.01245
Epoch [ 15]: Loss 0.01226
Epoch [ 15]: Loss 0.01185
Epoch [ 15]: Loss 0.01176
Epoch [ 15]: Loss 0.01073
Epoch [ 15]: Loss 0.01024
Validation: Loss 0.01055 Accuracy 1.00000
Validation: Loss 0.01149 Accuracy 1.00000
Epoch [ 16]: Loss 0.01112
Epoch [ 16]: Loss 0.01111
Epoch [ 16]: Loss 0.01057
Epoch [ 16]: Loss 0.01043
Epoch [ 16]: Loss 0.01053
Epoch [ 16]: Loss 0.01067
Epoch [ 16]: Loss 0.00923
Validation: Loss 0.00953 Accuracy 1.00000
Validation: Loss 0.01040 Accuracy 1.00000
Epoch [ 17]: Loss 0.00981
Epoch [ 17]: Loss 0.00980
Epoch [ 17]: Loss 0.00989
Epoch [ 17]: Loss 0.01004
Epoch [ 17]: Loss 0.00950
Epoch [ 17]: Loss 0.00942
Epoch [ 17]: Loss 0.00839
Validation: Loss 0.00869 Accuracy 1.00000
Validation: Loss 0.00949 Accuracy 1.00000
Epoch [ 18]: Loss 0.00909
Epoch [ 18]: Loss 0.00898
Epoch [ 18]: Loss 0.00915
Epoch [ 18]: Loss 0.00837
Epoch [ 18]: Loss 0.00935
Epoch [ 18]: Loss 0.00845
Epoch [ 18]: Loss 0.00804
Validation: Loss 0.00797 Accuracy 1.00000
Validation: Loss 0.00872 Accuracy 1.00000
Epoch [ 19]: Loss 0.00818
Epoch [ 19]: Loss 0.00820
Epoch [ 19]: Loss 0.00789
Epoch [ 19]: Loss 0.00840
Epoch [ 19]: Loss 0.00833
Epoch [ 19]: Loss 0.00810
Epoch [ 19]: Loss 0.00756
Validation: Loss 0.00736 Accuracy 1.00000
Validation: Loss 0.00807 Accuracy 1.00000
Epoch [ 20]: Loss 0.00786
Epoch [ 20]: Loss 0.00769
Epoch [ 20]: Loss 0.00781
Epoch [ 20]: Loss 0.00703
Epoch [ 20]: Loss 0.00730
Epoch [ 20]: Loss 0.00757
Epoch [ 20]: Loss 0.00776
Validation: Loss 0.00683 Accuracy 1.00000
Validation: Loss 0.00749 Accuracy 1.00000
Epoch [ 21]: Loss 0.00735
Epoch [ 21]: Loss 0.00727
Epoch [ 21]: Loss 0.00661
Epoch [ 21]: Loss 0.00702
Epoch [ 21]: Loss 0.00723
Epoch [ 21]: Loss 0.00660
Epoch [ 21]: Loss 0.00717
Validation: Loss 0.00636 Accuracy 1.00000
Validation: Loss 0.00698 Accuracy 1.00000
Epoch [ 22]: Loss 0.00636
Epoch [ 22]: Loss 0.00620
Epoch [ 22]: Loss 0.00693
Epoch [ 22]: Loss 0.00665
Epoch [ 22]: Loss 0.00619
Epoch [ 22]: Loss 0.00679
Epoch [ 22]: Loss 0.00712
Validation: Loss 0.00594 Accuracy 1.00000
Validation: Loss 0.00652 Accuracy 1.00000
Epoch [ 23]: Loss 0.00596
Epoch [ 23]: Loss 0.00608
Epoch [ 23]: Loss 0.00673
Epoch [ 23]: Loss 0.00604
Epoch [ 23]: Loss 0.00601
Epoch [ 23]: Loss 0.00593
Epoch [ 23]: Loss 0.00612
Validation: Loss 0.00557 Accuracy 1.00000
Validation: Loss 0.00612 Accuracy 1.00000
Epoch [ 24]: Loss 0.00602
Epoch [ 24]: Loss 0.00628
Epoch [ 24]: Loss 0.00541
Epoch [ 24]: Loss 0.00571
Epoch [ 24]: Loss 0.00533
Epoch [ 24]: Loss 0.00580
Epoch [ 24]: Loss 0.00554
Validation: Loss 0.00523 Accuracy 1.00000
Validation: Loss 0.00575 Accuracy 1.00000
Epoch [ 25]: Loss 0.00548
Epoch [ 25]: Loss 0.00565
Epoch [ 25]: Loss 0.00573
Epoch [ 25]: Loss 0.00553
Epoch [ 25]: Loss 0.00489
Epoch [ 25]: Loss 0.00528
Epoch [ 25]: Loss 0.00507
Validation: Loss 0.00493 Accuracy 1.00000
Validation: Loss 0.00542 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56239
Epoch [  1]: Loss 0.51533
Epoch [  1]: Loss 0.47971
Epoch [  1]: Loss 0.44744
Epoch [  1]: Loss 0.43312
Epoch [  1]: Loss 0.40772
Epoch [  1]: Loss 0.39670
Validation: Loss 0.36821 Accuracy 1.00000
Validation: Loss 0.35638 Accuracy 1.00000
Epoch [  2]: Loss 0.36604
Epoch [  2]: Loss 0.35873
Epoch [  2]: Loss 0.33930
Epoch [  2]: Loss 0.31550
Epoch [  2]: Loss 0.30249
Epoch [  2]: Loss 0.29254
Epoch [  2]: Loss 0.27112
Validation: Loss 0.25846 Accuracy 1.00000
Validation: Loss 0.25211 Accuracy 1.00000
Epoch [  3]: Loss 0.25436
Epoch [  3]: Loss 0.25121
Epoch [  3]: Loss 0.23634
Epoch [  3]: Loss 0.22196
Epoch [  3]: Loss 0.21276
Epoch [  3]: Loss 0.19904
Epoch [  3]: Loss 0.19526
Validation: Loss 0.18116 Accuracy 1.00000
Validation: Loss 0.17837 Accuracy 1.00000
Epoch [  4]: Loss 0.18289
Epoch [  4]: Loss 0.17284
Epoch [  4]: Loss 0.16530
Epoch [  4]: Loss 0.15706
Epoch [  4]: Loss 0.14911
Epoch [  4]: Loss 0.14248
Epoch [  4]: Loss 0.13646
Validation: Loss 0.12987 Accuracy 1.00000
Validation: Loss 0.12835 Accuracy 1.00000
Epoch [  5]: Loss 0.13024
Epoch [  5]: Loss 0.12220
Epoch [  5]: Loss 0.11891
Epoch [  5]: Loss 0.11458
Epoch [  5]: Loss 0.10792
Epoch [  5]: Loss 0.10483
Epoch [  5]: Loss 0.10000
Validation: Loss 0.09455 Accuracy 1.00000
Validation: Loss 0.09275 Accuracy 1.00000
Epoch [  6]: Loss 0.09510
Epoch [  6]: Loss 0.09096
Epoch [  6]: Loss 0.08584
Epoch [  6]: Loss 0.08435
Epoch [  6]: Loss 0.08116
Epoch [  6]: Loss 0.07499
Epoch [  6]: Loss 0.07121
Validation: Loss 0.06971 Accuracy 1.00000
Validation: Loss 0.06761 Accuracy 1.00000
Epoch [  7]: Loss 0.06947
Epoch [  7]: Loss 0.06660
Epoch [  7]: Loss 0.06444
Epoch [  7]: Loss 0.06145
Epoch [  7]: Loss 0.05983
Epoch [  7]: Loss 0.05758
Epoch [  7]: Loss 0.05757
Validation: Loss 0.05188 Accuracy 1.00000
Validation: Loss 0.04960 Accuracy 1.00000
Epoch [  8]: Loss 0.05197
Epoch [  8]: Loss 0.04969
Epoch [  8]: Loss 0.04778
Epoch [  8]: Loss 0.04813
Epoch [  8]: Loss 0.04353
Epoch [  8]: Loss 0.04338
Epoch [  8]: Loss 0.04190
Validation: Loss 0.03895 Accuracy 1.00000
Validation: Loss 0.03676 Accuracy 1.00000
Epoch [  9]: Loss 0.04018
Epoch [  9]: Loss 0.03852
Epoch [  9]: Loss 0.03651
Epoch [  9]: Loss 0.03559
Epoch [  9]: Loss 0.03398
Epoch [  9]: Loss 0.03076
Epoch [  9]: Loss 0.03052
Validation: Loss 0.02967 Accuracy 1.00000
Validation: Loss 0.02767 Accuracy 1.00000
Epoch [ 10]: Loss 0.03162
Epoch [ 10]: Loss 0.02917
Epoch [ 10]: Loss 0.02937
Epoch [ 10]: Loss 0.02729
Epoch [ 10]: Loss 0.02405
Epoch [ 10]: Loss 0.02523
Epoch [ 10]: Loss 0.02200
Validation: Loss 0.02326 Accuracy 1.00000
Validation: Loss 0.02148 Accuracy 1.00000
Epoch [ 11]: Loss 0.02350
Epoch [ 11]: Loss 0.02334
Epoch [ 11]: Loss 0.02236
Epoch [ 11]: Loss 0.02152
Epoch [ 11]: Loss 0.02174
Epoch [ 11]: Loss 0.01943
Epoch [ 11]: Loss 0.02069
Validation: Loss 0.01890 Accuracy 1.00000
Validation: Loss 0.01733 Accuracy 1.00000
Epoch [ 12]: Loss 0.01885
Epoch [ 12]: Loss 0.01813
Epoch [ 12]: Loss 0.01830
Epoch [ 12]: Loss 0.01885
Epoch [ 12]: Loss 0.01781
Epoch [ 12]: Loss 0.01637
Epoch [ 12]: Loss 0.01820
Validation: Loss 0.01586 Accuracy 1.00000
Validation: Loss 0.01448 Accuracy 1.00000
Epoch [ 13]: Loss 0.01541
Epoch [ 13]: Loss 0.01590
Epoch [ 13]: Loss 0.01523
Epoch [ 13]: Loss 0.01542
Epoch [ 13]: Loss 0.01561
Epoch [ 13]: Loss 0.01444
Epoch [ 13]: Loss 0.01485
Validation: Loss 0.01367 Accuracy 1.00000
Validation: Loss 0.01246 Accuracy 1.00000
Epoch [ 14]: Loss 0.01383
Epoch [ 14]: Loss 0.01445
Epoch [ 14]: Loss 0.01419
Epoch [ 14]: Loss 0.01323
Epoch [ 14]: Loss 0.01248
Epoch [ 14]: Loss 0.01222
Epoch [ 14]: Loss 0.01175
Validation: Loss 0.01203 Accuracy 1.00000
Validation: Loss 0.01096 Accuracy 1.00000
Epoch [ 15]: Loss 0.01306
Epoch [ 15]: Loss 0.01095
Epoch [ 15]: Loss 0.01247
Epoch [ 15]: Loss 0.01132
Epoch [ 15]: Loss 0.01176
Epoch [ 15]: Loss 0.01137
Epoch [ 15]: Loss 0.01137
Validation: Loss 0.01076 Accuracy 1.00000
Validation: Loss 0.00979 Accuracy 1.00000
Epoch [ 16]: Loss 0.01076
Epoch [ 16]: Loss 0.01057
Epoch [ 16]: Loss 0.01043
Epoch [ 16]: Loss 0.01058
Epoch [ 16]: Loss 0.01093
Epoch [ 16]: Loss 0.01020
Epoch [ 16]: Loss 0.01115
Validation: Loss 0.00973 Accuracy 1.00000
Validation: Loss 0.00884 Accuracy 1.00000
Epoch [ 17]: Loss 0.00935
Epoch [ 17]: Loss 0.00987
Epoch [ 17]: Loss 0.00975
Epoch [ 17]: Loss 0.00999
Epoch [ 17]: Loss 0.00968
Epoch [ 17]: Loss 0.00916
Epoch [ 17]: Loss 0.00934
Validation: Loss 0.00887 Accuracy 1.00000
Validation: Loss 0.00805 Accuracy 1.00000
Epoch [ 18]: Loss 0.00882
Epoch [ 18]: Loss 0.00959
Epoch [ 18]: Loss 0.00946
Epoch [ 18]: Loss 0.00830
Epoch [ 18]: Loss 0.00864
Epoch [ 18]: Loss 0.00814
Epoch [ 18]: Loss 0.00834
Validation: Loss 0.00814 Accuracy 1.00000
Validation: Loss 0.00738 Accuracy 1.00000
Epoch [ 19]: Loss 0.00862
Epoch [ 19]: Loss 0.00762
Epoch [ 19]: Loss 0.00800
Epoch [ 19]: Loss 0.00820
Epoch [ 19]: Loss 0.00854
Epoch [ 19]: Loss 0.00778
Epoch [ 19]: Loss 0.00749
Validation: Loss 0.00752 Accuracy 1.00000
Validation: Loss 0.00681 Accuracy 1.00000
Epoch [ 20]: Loss 0.00768
Epoch [ 20]: Loss 0.00741
Epoch [ 20]: Loss 0.00791
Epoch [ 20]: Loss 0.00749
Epoch [ 20]: Loss 0.00737
Epoch [ 20]: Loss 0.00722
Epoch [ 20]: Loss 0.00715
Validation: Loss 0.00698 Accuracy 1.00000
Validation: Loss 0.00631 Accuracy 1.00000
Epoch [ 21]: Loss 0.00679
Epoch [ 21]: Loss 0.00689
Epoch [ 21]: Loss 0.00692
Epoch [ 21]: Loss 0.00696
Epoch [ 21]: Loss 0.00713
Epoch [ 21]: Loss 0.00704
Epoch [ 21]: Loss 0.00730
Validation: Loss 0.00650 Accuracy 1.00000
Validation: Loss 0.00587 Accuracy 1.00000
Epoch [ 22]: Loss 0.00659
Epoch [ 22]: Loss 0.00662
Epoch [ 22]: Loss 0.00682
Epoch [ 22]: Loss 0.00653
Epoch [ 22]: Loss 0.00639
Epoch [ 22]: Loss 0.00634
Epoch [ 22]: Loss 0.00548
Validation: Loss 0.00608 Accuracy 1.00000
Validation: Loss 0.00549 Accuracy 1.00000
Epoch [ 23]: Loss 0.00658
Epoch [ 23]: Loss 0.00601
Epoch [ 23]: Loss 0.00585
Epoch [ 23]: Loss 0.00637
Epoch [ 23]: Loss 0.00594
Epoch [ 23]: Loss 0.00585
Epoch [ 23]: Loss 0.00588
Validation: Loss 0.00570 Accuracy 1.00000
Validation: Loss 0.00514 Accuracy 1.00000
Epoch [ 24]: Loss 0.00600
Epoch [ 24]: Loss 0.00603
Epoch [ 24]: Loss 0.00538
Epoch [ 24]: Loss 0.00572
Epoch [ 24]: Loss 0.00565
Epoch [ 24]: Loss 0.00546
Epoch [ 24]: Loss 0.00604
Validation: Loss 0.00537 Accuracy 1.00000
Validation: Loss 0.00484 Accuracy 1.00000
Epoch [ 25]: Loss 0.00549
Epoch [ 25]: Loss 0.00556
Epoch [ 25]: Loss 0.00548
Epoch [ 25]: Loss 0.00565
Epoch [ 25]: Loss 0.00492
Epoch [ 25]: Loss 0.00517
Epoch [ 25]: Loss 0.00561
Validation: Loss 0.00506 Accuracy 1.00000
Validation: Loss 0.00456 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (and any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
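The loaded parameters and states can now drive inference without retraining; a sketch (the dummy input shape follows the dataset section):

```julia
# Rebuild the (stateless) model definition and reuse the loaded ps/st
model = SpiralClassifier(2, 8, 1)
st = Lux.testmode(st_trained)             # switch states to inference mode
x = rand(Xoshiro(1), Float32, 2, 50, 4)   # dummy batch for illustration
y_pred, _ = model(x, ps_trained, st)
y_pred .> 0.5f0                           # hard clockwise/anticlockwise labels
```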

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.422 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.