Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
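To sanity-check the shapes described above, we can peek at a single minibatch (a sketch that reuses the `get_dataloaders` function just defined; the dimension comments assume the default `batchsize=128` and `sequence_length=50`):

```julia
train_loader, val_loader = get_dataloaders()
# Grab the first minibatch and inspect its dimensions
x, y = first(train_loader)
size(x)    # expected (2, 50, 128): 2 coordinates × sequence_length × batchsize
length(y)  # expected 128 binary labels (0.0f0 or 1.0f0)
```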

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the model below, but let's still write it out for illustration.
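For reference, that Recurrence-based alternative could be sketched as follows (a sketch, assuming Lux's `Recurrence` wrapper, which by default returns only the last time step's output):

```julia
using Lux

# Recurrence runs the LSTMCell over the time dimension and returns the
# final hidden output, which the Dense head then maps to a probability.
function SpiralClassifierStock(in_dims, hidden_dims, out_dims)
    return Chain(Recurrence(LSTMCell(in_dims => hidden_dims)),
        Dense(hidden_dims => out_dims, sigmoid), vec)
end
```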

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
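With the call overload in place, a forward pass can be sketched as follows (assuming the `SpiralClassifier` definitions above; the input shapes here are illustrative):

```julia
using Lux, Random

model = SpiralClassifier(2, 8, 1)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)

x = rand(rng, Float32, 2, 50, 4)  # 2 features × 50 timesteps × batch of 4
y, st = model(x, ps, st)
size(y)  # (4,): one sigmoid probability per sequence in the batch
```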

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
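For completeness, here is a sketch of the numerically stable logit-space variant mentioned above. It consumes raw logits (the model's pre-sigmoid output), so it avoids the `eps` shift needed in `binarycrossentropy`. The `logsigmoid` helper is a hypothetical definition introduced here, not part of the tutorial's code.

```julia
using Statistics

# log(sigmoid(x)) computed without overflow or underflow
logsigmoid(x) = min(x, zero(x)) - log1p(exp(-abs(x)))

# Binary cross-entropy on logits z: (1 - y) * z - log(sigmoid(z)),
# algebraically equal to -y*log(σ(z)) - (1 - y)*log(1 - σ(z))
function logitbinarycrossentropy(logits, y_true)
    return mean(@. (1 - y_true) * logits - logsigmoid(logits))
end

logitbinarycrossentropy([0.0f0], [1.0f0])  # ≈ log(2) ≈ 0.6931472f0
```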

Training the Model

julia
function main(model_type)
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.56216
Epoch [  1]: Loss 0.50947
Epoch [  1]: Loss 0.47099
Epoch [  1]: Loss 0.45396
Epoch [  1]: Loss 0.43113
Epoch [  1]: Loss 0.40741
Epoch [  1]: Loss 0.37072
Validation: Loss 0.36481 Accuracy 1.00000
Validation: Loss 0.36788 Accuracy 1.00000
Epoch [  2]: Loss 0.36044
Epoch [  2]: Loss 0.36075
Epoch [  2]: Loss 0.34252
Epoch [  2]: Loss 0.31277
Epoch [  2]: Loss 0.29768
Epoch [  2]: Loss 0.28613
Epoch [  2]: Loss 0.27515
Validation: Loss 0.25598 Accuracy 1.00000
Validation: Loss 0.25828 Accuracy 1.00000
Epoch [  3]: Loss 0.25474
Epoch [  3]: Loss 0.24399
Epoch [  3]: Loss 0.23025
Epoch [  3]: Loss 0.22009
Epoch [  3]: Loss 0.21229
Epoch [  3]: Loss 0.20447
Epoch [  3]: Loss 0.19513
Validation: Loss 0.17891 Accuracy 1.00000
Validation: Loss 0.18034 Accuracy 1.00000
Epoch [  4]: Loss 0.17998
Epoch [  4]: Loss 0.16859
Epoch [  4]: Loss 0.16317
Epoch [  4]: Loss 0.15694
Epoch [  4]: Loss 0.14959
Epoch [  4]: Loss 0.14291
Epoch [  4]: Loss 0.13386
Validation: Loss 0.12768 Accuracy 1.00000
Validation: Loss 0.12861 Accuracy 1.00000
Epoch [  5]: Loss 0.12948
Epoch [  5]: Loss 0.12322
Epoch [  5]: Loss 0.11806
Epoch [  5]: Loss 0.11145
Epoch [  5]: Loss 0.10643
Epoch [  5]: Loss 0.10282
Epoch [  5]: Loss 0.09695
Validation: Loss 0.09275 Accuracy 1.00000
Validation: Loss 0.09353 Accuracy 1.00000
Epoch [  6]: Loss 0.09352
Epoch [  6]: Loss 0.08778
Epoch [  6]: Loss 0.08703
Epoch [  6]: Loss 0.08224
Epoch [  6]: Loss 0.07932
Epoch [  6]: Loss 0.07489
Epoch [  6]: Loss 0.07420
Validation: Loss 0.06824 Accuracy 1.00000
Validation: Loss 0.06899 Accuracy 1.00000
Epoch [  7]: Loss 0.06777
Epoch [  7]: Loss 0.06662
Epoch [  7]: Loss 0.06366
Epoch [  7]: Loss 0.06216
Epoch [  7]: Loss 0.05804
Epoch [  7]: Loss 0.05632
Epoch [  7]: Loss 0.05337
Validation: Loss 0.05078 Accuracy 1.00000
Validation: Loss 0.05146 Accuracy 1.00000
Epoch [  8]: Loss 0.04993
Epoch [  8]: Loss 0.05121
Epoch [  8]: Loss 0.04801
Epoch [  8]: Loss 0.04541
Epoch [  8]: Loss 0.04433
Epoch [  8]: Loss 0.04130
Epoch [  8]: Loss 0.04095
Validation: Loss 0.03810 Accuracy 1.00000
Validation: Loss 0.03871 Accuracy 1.00000
Epoch [  9]: Loss 0.03899
Epoch [  9]: Loss 0.03676
Epoch [  9]: Loss 0.03531
Epoch [  9]: Loss 0.03542
Epoch [  9]: Loss 0.03254
Epoch [  9]: Loss 0.03253
Epoch [  9]: Loss 0.03199
Validation: Loss 0.02901 Accuracy 1.00000
Validation: Loss 0.02955 Accuracy 1.00000
Epoch [ 10]: Loss 0.03114
Epoch [ 10]: Loss 0.02920
Epoch [ 10]: Loss 0.02646
Epoch [ 10]: Loss 0.02603
Epoch [ 10]: Loss 0.02513
Epoch [ 10]: Loss 0.02436
Epoch [ 10]: Loss 0.02790
Validation: Loss 0.02274 Accuracy 1.00000
Validation: Loss 0.02320 Accuracy 1.00000
Epoch [ 11]: Loss 0.02425
Epoch [ 11]: Loss 0.02198
Epoch [ 11]: Loss 0.02290
Epoch [ 11]: Loss 0.02161
Epoch [ 11]: Loss 0.01992
Epoch [ 11]: Loss 0.01920
Epoch [ 11]: Loss 0.01940
Validation: Loss 0.01845 Accuracy 1.00000
Validation: Loss 0.01885 Accuracy 1.00000
Epoch [ 12]: Loss 0.01696
Epoch [ 12]: Loss 0.01915
Epoch [ 12]: Loss 0.01794
Epoch [ 12]: Loss 0.01814
Epoch [ 12]: Loss 0.01747
Epoch [ 12]: Loss 0.01730
Epoch [ 12]: Loss 0.01502
Validation: Loss 0.01549 Accuracy 1.00000
Validation: Loss 0.01583 Accuracy 1.00000
Epoch [ 13]: Loss 0.01611
Epoch [ 13]: Loss 0.01563
Epoch [ 13]: Loss 0.01509
Epoch [ 13]: Loss 0.01577
Epoch [ 13]: Loss 0.01455
Epoch [ 13]: Loss 0.01378
Epoch [ 13]: Loss 0.01286
Validation: Loss 0.01337 Accuracy 1.00000
Validation: Loss 0.01367 Accuracy 1.00000
Epoch [ 14]: Loss 0.01389
Epoch [ 14]: Loss 0.01403
Epoch [ 14]: Loss 0.01379
Epoch [ 14]: Loss 0.01241
Epoch [ 14]: Loss 0.01249
Epoch [ 14]: Loss 0.01240
Epoch [ 14]: Loss 0.01188
Validation: Loss 0.01179 Accuracy 1.00000
Validation: Loss 0.01205 Accuracy 1.00000
Epoch [ 15]: Loss 0.01126
Epoch [ 15]: Loss 0.01172
Epoch [ 15]: Loss 0.01174
Epoch [ 15]: Loss 0.01150
Epoch [ 15]: Loss 0.01173
Epoch [ 15]: Loss 0.01172
Epoch [ 15]: Loss 0.01193
Validation: Loss 0.01055 Accuracy 1.00000
Validation: Loss 0.01080 Accuracy 1.00000
Epoch [ 16]: Loss 0.01118
Epoch [ 16]: Loss 0.01075
Epoch [ 16]: Loss 0.01037
Epoch [ 16]: Loss 0.01066
Epoch [ 16]: Loss 0.01070
Epoch [ 16]: Loss 0.00937
Epoch [ 16]: Loss 0.00964
Validation: Loss 0.00955 Accuracy 1.00000
Validation: Loss 0.00977 Accuracy 1.00000
Epoch [ 17]: Loss 0.00976
Epoch [ 17]: Loss 0.00966
Epoch [ 17]: Loss 0.00944
Epoch [ 17]: Loss 0.00927
Epoch [ 17]: Loss 0.00950
Epoch [ 17]: Loss 0.00946
Epoch [ 17]: Loss 0.00908
Validation: Loss 0.00871 Accuracy 1.00000
Validation: Loss 0.00892 Accuracy 1.00000
Epoch [ 18]: Loss 0.00862
Epoch [ 18]: Loss 0.00898
Epoch [ 18]: Loss 0.00896
Epoch [ 18]: Loss 0.00844
Epoch [ 18]: Loss 0.00816
Epoch [ 18]: Loss 0.00898
Epoch [ 18]: Loss 0.00880
Validation: Loss 0.00800 Accuracy 1.00000
Validation: Loss 0.00819 Accuracy 1.00000
Epoch [ 19]: Loss 0.00871
Epoch [ 19]: Loss 0.00848
Epoch [ 19]: Loss 0.00818
Epoch [ 19]: Loss 0.00797
Epoch [ 19]: Loss 0.00786
Epoch [ 19]: Loss 0.00674
Epoch [ 19]: Loss 0.00858
Validation: Loss 0.00739 Accuracy 1.00000
Validation: Loss 0.00757 Accuracy 1.00000
Epoch [ 20]: Loss 0.00742
Epoch [ 20]: Loss 0.00726
Epoch [ 20]: Loss 0.00791
Epoch [ 20]: Loss 0.00712
Epoch [ 20]: Loss 0.00734
Epoch [ 20]: Loss 0.00731
Epoch [ 20]: Loss 0.00779
Validation: Loss 0.00686 Accuracy 1.00000
Validation: Loss 0.00702 Accuracy 1.00000
Epoch [ 21]: Loss 0.00708
Epoch [ 21]: Loss 0.00719
Epoch [ 21]: Loss 0.00706
Epoch [ 21]: Loss 0.00685
Epoch [ 21]: Loss 0.00659
Epoch [ 21]: Loss 0.00644
Epoch [ 21]: Loss 0.00742
Validation: Loss 0.00639 Accuracy 1.00000
Validation: Loss 0.00654 Accuracy 1.00000
Epoch [ 22]: Loss 0.00697
Epoch [ 22]: Loss 0.00655
Epoch [ 22]: Loss 0.00606
Epoch [ 22]: Loss 0.00657
Epoch [ 22]: Loss 0.00613
Epoch [ 22]: Loss 0.00620
Epoch [ 22]: Loss 0.00675
Validation: Loss 0.00597 Accuracy 1.00000
Validation: Loss 0.00611 Accuracy 1.00000
Epoch [ 23]: Loss 0.00601
Epoch [ 23]: Loss 0.00576
Epoch [ 23]: Loss 0.00563
Epoch [ 23]: Loss 0.00650
Epoch [ 23]: Loss 0.00609
Epoch [ 23]: Loss 0.00600
Epoch [ 23]: Loss 0.00638
Validation: Loss 0.00560 Accuracy 1.00000
Validation: Loss 0.00574 Accuracy 1.00000
Epoch [ 24]: Loss 0.00577
Epoch [ 24]: Loss 0.00583
Epoch [ 24]: Loss 0.00534
Epoch [ 24]: Loss 0.00586
Epoch [ 24]: Loss 0.00589
Epoch [ 24]: Loss 0.00538
Epoch [ 24]: Loss 0.00499
Validation: Loss 0.00527 Accuracy 1.00000
Validation: Loss 0.00540 Accuracy 1.00000
Epoch [ 25]: Loss 0.00562
Epoch [ 25]: Loss 0.00504
Epoch [ 25]: Loss 0.00506
Epoch [ 25]: Loss 0.00509
Epoch [ 25]: Loss 0.00578
Epoch [ 25]: Loss 0.00528
Epoch [ 25]: Loss 0.00549
Validation: Loss 0.00497 Accuracy 1.00000
Validation: Loss 0.00509 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56216
Epoch [  1]: Loss 0.50868
Epoch [  1]: Loss 0.47372
Epoch [  1]: Loss 0.44439
Epoch [  1]: Loss 0.42985
Epoch [  1]: Loss 0.40872
Epoch [  1]: Loss 0.36998
Validation: Loss 0.36843 Accuracy 1.00000
Validation: Loss 0.37179 Accuracy 1.00000
Epoch [  2]: Loss 0.36474
Epoch [  2]: Loss 0.34381
Epoch [  2]: Loss 0.34225
Epoch [  2]: Loss 0.32119
Epoch [  2]: Loss 0.29820
Epoch [  2]: Loss 0.28367
Epoch [  2]: Loss 0.26663
Validation: Loss 0.25806 Accuracy 1.00000
Validation: Loss 0.26088 Accuracy 1.00000
Epoch [  3]: Loss 0.25824
Epoch [  3]: Loss 0.24314
Epoch [  3]: Loss 0.23112
Epoch [  3]: Loss 0.21876
Epoch [  3]: Loss 0.21075
Epoch [  3]: Loss 0.19981
Epoch [  3]: Loss 0.18853
Validation: Loss 0.17962 Accuracy 1.00000
Validation: Loss 0.18157 Accuracy 1.00000
Epoch [  4]: Loss 0.18040
Epoch [  4]: Loss 0.17081
Epoch [  4]: Loss 0.16053
Epoch [  4]: Loss 0.15302
Epoch [  4]: Loss 0.14704
Epoch [  4]: Loss 0.14122
Epoch [  4]: Loss 0.13906
Validation: Loss 0.12760 Accuracy 1.00000
Validation: Loss 0.12901 Accuracy 1.00000
Epoch [  5]: Loss 0.12808
Epoch [  5]: Loss 0.12097
Epoch [  5]: Loss 0.11625
Epoch [  5]: Loss 0.10962
Epoch [  5]: Loss 0.10535
Epoch [  5]: Loss 0.10234
Epoch [  5]: Loss 0.10095
Validation: Loss 0.09255 Accuracy 1.00000
Validation: Loss 0.09366 Accuracy 1.00000
Epoch [  6]: Loss 0.09217
Epoch [  6]: Loss 0.08949
Epoch [  6]: Loss 0.08399
Epoch [  6]: Loss 0.08128
Epoch [  6]: Loss 0.07689
Epoch [  6]: Loss 0.07450
Epoch [  6]: Loss 0.07211
Validation: Loss 0.06819 Accuracy 1.00000
Validation: Loss 0.06912 Accuracy 1.00000
Epoch [  7]: Loss 0.06926
Epoch [  7]: Loss 0.06516
Epoch [  7]: Loss 0.06327
Epoch [  7]: Loss 0.06017
Epoch [  7]: Loss 0.05699
Epoch [  7]: Loss 0.05319
Epoch [  7]: Loss 0.05490
Validation: Loss 0.05080 Accuracy 1.00000
Validation: Loss 0.05159 Accuracy 1.00000
Epoch [  8]: Loss 0.04940
Epoch [  8]: Loss 0.04820
Epoch [  8]: Loss 0.04708
Epoch [  8]: Loss 0.04425
Epoch [  8]: Loss 0.04362
Epoch [  8]: Loss 0.04149
Epoch [  8]: Loss 0.04321
Validation: Loss 0.03819 Accuracy 1.00000
Validation: Loss 0.03887 Accuracy 1.00000
Epoch [  9]: Loss 0.03912
Epoch [  9]: Loss 0.03695
Epoch [  9]: Loss 0.03401
Epoch [  9]: Loss 0.03360
Epoch [  9]: Loss 0.03163
Epoch [  9]: Loss 0.03180
Epoch [  9]: Loss 0.03070
Validation: Loss 0.02914 Accuracy 1.00000
Validation: Loss 0.02973 Accuracy 1.00000
Epoch [ 10]: Loss 0.03030
Epoch [ 10]: Loss 0.02597
Epoch [ 10]: Loss 0.02902
Epoch [ 10]: Loss 0.02525
Epoch [ 10]: Loss 0.02521
Epoch [ 10]: Loss 0.02366
Epoch [ 10]: Loss 0.02336
Validation: Loss 0.02290 Accuracy 1.00000
Validation: Loss 0.02341 Accuracy 1.00000
Epoch [ 11]: Loss 0.02286
Epoch [ 11]: Loss 0.02164
Epoch [ 11]: Loss 0.02158
Epoch [ 11]: Loss 0.01980
Epoch [ 11]: Loss 0.02143
Epoch [ 11]: Loss 0.01945
Epoch [ 11]: Loss 0.01864
Validation: Loss 0.01866 Accuracy 1.00000
Validation: Loss 0.01910 Accuracy 1.00000
Epoch [ 12]: Loss 0.01724
Epoch [ 12]: Loss 0.01858
Epoch [ 12]: Loss 0.01719
Epoch [ 12]: Loss 0.01757
Epoch [ 12]: Loss 0.01665
Epoch [ 12]: Loss 0.01684
Epoch [ 12]: Loss 0.01681
Validation: Loss 0.01570 Accuracy 1.00000
Validation: Loss 0.01609 Accuracy 1.00000
Epoch [ 13]: Loss 0.01583
Epoch [ 13]: Loss 0.01480
Epoch [ 13]: Loss 0.01407
Epoch [ 13]: Loss 0.01583
Epoch [ 13]: Loss 0.01396
Epoch [ 13]: Loss 0.01424
Epoch [ 13]: Loss 0.01335
Validation: Loss 0.01356 Accuracy 1.00000
Validation: Loss 0.01390 Accuracy 1.00000
Epoch [ 14]: Loss 0.01397
Epoch [ 14]: Loss 0.01308
Epoch [ 14]: Loss 0.01396
Epoch [ 14]: Loss 0.01217
Epoch [ 14]: Loss 0.01212
Epoch [ 14]: Loss 0.01224
Epoch [ 14]: Loss 0.01079
Validation: Loss 0.01196 Accuracy 1.00000
Validation: Loss 0.01226 Accuracy 1.00000
Epoch [ 15]: Loss 0.01181
Epoch [ 15]: Loss 0.01158
Epoch [ 15]: Loss 0.01114
Epoch [ 15]: Loss 0.01218
Epoch [ 15]: Loss 0.01163
Epoch [ 15]: Loss 0.01041
Epoch [ 15]: Loss 0.00963
Validation: Loss 0.01071 Accuracy 1.00000
Validation: Loss 0.01098 Accuracy 1.00000
Epoch [ 16]: Loss 0.01042
Epoch [ 16]: Loss 0.01030
Epoch [ 16]: Loss 0.01024
Epoch [ 16]: Loss 0.01022
Epoch [ 16]: Loss 0.01050
Epoch [ 16]: Loss 0.01015
Epoch [ 16]: Loss 0.00866
Validation: Loss 0.00971 Accuracy 1.00000
Validation: Loss 0.00997 Accuracy 1.00000
Epoch [ 17]: Loss 0.01032
Epoch [ 17]: Loss 0.00935
Epoch [ 17]: Loss 0.00956
Epoch [ 17]: Loss 0.00872
Epoch [ 17]: Loss 0.00930
Epoch [ 17]: Loss 0.00887
Epoch [ 17]: Loss 0.00841
Validation: Loss 0.00887 Accuracy 1.00000
Validation: Loss 0.00910 Accuracy 1.00000
Epoch [ 18]: Loss 0.00958
Epoch [ 18]: Loss 0.00849
Epoch [ 18]: Loss 0.00897
Epoch [ 18]: Loss 0.00809
Epoch [ 18]: Loss 0.00799
Epoch [ 18]: Loss 0.00831
Epoch [ 18]: Loss 0.00751
Validation: Loss 0.00816 Accuracy 1.00000
Validation: Loss 0.00838 Accuracy 1.00000
Epoch [ 19]: Loss 0.00837
Epoch [ 19]: Loss 0.00796
Epoch [ 19]: Loss 0.00783
Epoch [ 19]: Loss 0.00791
Epoch [ 19]: Loss 0.00720
Epoch [ 19]: Loss 0.00788
Epoch [ 19]: Loss 0.00783
Validation: Loss 0.00754 Accuracy 1.00000
Validation: Loss 0.00774 Accuracy 1.00000
Epoch [ 20]: Loss 0.00731
Epoch [ 20]: Loss 0.00755
Epoch [ 20]: Loss 0.00772
Epoch [ 20]: Loss 0.00784
Epoch [ 20]: Loss 0.00662
Epoch [ 20]: Loss 0.00671
Epoch [ 20]: Loss 0.00702
Validation: Loss 0.00700 Accuracy 1.00000
Validation: Loss 0.00719 Accuracy 1.00000
Epoch [ 21]: Loss 0.00712
Epoch [ 21]: Loss 0.00690
Epoch [ 21]: Loss 0.00672
Epoch [ 21]: Loss 0.00594
Epoch [ 21]: Loss 0.00671
Epoch [ 21]: Loss 0.00702
Epoch [ 21]: Loss 0.00752
Validation: Loss 0.00653 Accuracy 1.00000
Validation: Loss 0.00671 Accuracy 1.00000
Epoch [ 22]: Loss 0.00670
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00674
Epoch [ 22]: Loss 0.00633
Epoch [ 22]: Loss 0.00591
Epoch [ 22]: Loss 0.00625
Epoch [ 22]: Loss 0.00563
Validation: Loss 0.00610 Accuracy 1.00000
Validation: Loss 0.00627 Accuracy 1.00000
Epoch [ 23]: Loss 0.00648
Epoch [ 23]: Loss 0.00603
Epoch [ 23]: Loss 0.00580
Epoch [ 23]: Loss 0.00591
Epoch [ 23]: Loss 0.00577
Epoch [ 23]: Loss 0.00554
Epoch [ 23]: Loss 0.00577
Validation: Loss 0.00573 Accuracy 1.00000
Validation: Loss 0.00588 Accuracy 1.00000
Epoch [ 24]: Loss 0.00549
Epoch [ 24]: Loss 0.00543
Epoch [ 24]: Loss 0.00569
Epoch [ 24]: Loss 0.00584
Epoch [ 24]: Loss 0.00532
Epoch [ 24]: Loss 0.00559
Epoch [ 24]: Loss 0.00542
Validation: Loss 0.00539 Accuracy 1.00000
Validation: Loss 0.00554 Accuracy 1.00000
Epoch [ 25]: Loss 0.00541
Epoch [ 25]: Loss 0.00557
Epoch [ 25]: Loss 0.00539
Epoch [ 25]: Loss 0.00482
Epoch [ 25]: Loss 0.00526
Epoch [ 25]: Loss 0.00513
Epoch [ 25]: Loss 0.00453
Validation: Loss 0.00508 Accuracy 1.00000
Validation: Loss 0.00523 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself, only the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
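As a quick sanity check, the restored parameters and states can be used for inference directly (a sketch reusing the model definition and data shapes from earlier; `Lux.testmode` switches the states to inference behavior):

```julia
using Lux, Random

model = SpiralClassifier(2, 8, 1)        # same architecture as was trained
x = rand(Xoshiro(0), Float32, 2, 50, 1)  # one dummy sequence
y, _ = model(x, ps_trained, Lux.testmode(st_trained))
# y[1] is the predicted probability that the sequence is anticlockwise
```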

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.3
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.359 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.