
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
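For intuition about the data, here is a hand-rolled (noise-free) parametric spiral; the angle range and variable names below are illustrative assumptions, not MLUtils internals:

```julia
# Sketch only: radius grows with angle, winding direction set by the sign of θ.
θ = range(0.0f0, 4.0f0π; length=50)
anticlockwise = vcat((θ .* cos.(θ))', (θ .* sin.(θ))')    # 2 × 50 matrix of (x, y) points
clockwise     = vcat((θ .* cos.(-θ))', (θ .* sin.(-θ))')  # mirrored winding direction
@assert size(clockwise) == (2, 50)  # matches the 2 × seq_len layout used above
```

Stacking many such matrices along a third dimension gives the 2 × seq_len × batch_size arrays the dataloader yields.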

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
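To see roughly what this buys us: Lux.setup will return parameters nested under exactly those field names. A plain-Julia sketch of the resulting shape (the inner field names and array sizes here are made up for illustration, not the real LSTMCell layout):

```julia
# Hypothetical parameter shapes for a container layer with two sub-layers.
ps = (lstm_cell=(weight_ih=zeros(Float32, 32, 2), bias=zeros(Float32, 32)),
      classifier=(weight=zeros(Float32, 1, 8), bias=zeros(Float32, 1)))
@assert keys(ps) == (:lstm_cell, :classifier)  # one entry per declared field name
```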

We won't define the model from scratch; instead, we will compose it from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the built-in recurrence wrapper – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the loop ourselves, but we will define it manually for the sake of illustration.

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
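The slicing pattern used above can be tried in isolation with Base's eachslice (LuxOps.eachslice behaves the same way but plays nicely with automatic differentiation); the sizes below are illustrative:

```julia
x = rand(Float32, 2, 5, 3)  # features × seq_len × batch_size
x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
@assert size(x_init) == (2, 3)        # first timestep for the whole batch
@assert length(collect(x_rest)) == 4  # the remaining seq_len - 1 timesteps
```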

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
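For reference, binary cross entropy on probabilities can be written out by hand (a sketch only; the ϵ clamp is our own safeguard, not necessarily what BinaryCrossEntropyLoss does internally), and the accuracy helpers can be sanity-checked on made-up predictions:

```julia
# Hand-rolled BCE sketch: mean of -[y log(ŷ) + (1 - y) log(1 - ŷ)].
bce(ŷ, y; ϵ=1.0f-7) = -sum(@. y * log(ŷ + ϵ) + (1 - y) * log(1 - ŷ + ϵ)) / length(ŷ)
bce([0.5f0], [1.0f0])  # ≈ -log(0.5) ≈ 0.6931

# Sanity check for the matches/accuracy helpers defined above.
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
ŷ = Float32[0.9, 0.2, 0.7, 0.4]  # predicted probabilities
y = Float32[1.0, 0.0, 0.0, 0.0]  # true labels; the third prediction is wrong
@assert matches(ŷ, y) == 3
@assert accuracy(ŷ, y) == 0.75
```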

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.62105
Epoch [  1]: Loss 0.59012
Epoch [  1]: Loss 0.57846
Epoch [  1]: Loss 0.53616
Epoch [  1]: Loss 0.51960
Epoch [  1]: Loss 0.50188
Epoch [  1]: Loss 0.47619
Validation: Loss 0.46270 Accuracy 1.00000
Validation: Loss 0.47484 Accuracy 1.00000
Epoch [  2]: Loss 0.46282
Epoch [  2]: Loss 0.46087
Epoch [  2]: Loss 0.44761
Epoch [  2]: Loss 0.42874
Epoch [  2]: Loss 0.41352
Epoch [  2]: Loss 0.39352
Epoch [  2]: Loss 0.36613
Validation: Loss 0.36464 Accuracy 1.00000
Validation: Loss 0.37918 Accuracy 1.00000
Epoch [  3]: Loss 0.38250
Epoch [  3]: Loss 0.35839
Epoch [  3]: Loss 0.34617
Epoch [  3]: Loss 0.32904
Epoch [  3]: Loss 0.33328
Epoch [  3]: Loss 0.29715
Epoch [  3]: Loss 0.29360
Validation: Loss 0.27939 Accuracy 1.00000
Validation: Loss 0.29558 Accuracy 1.00000
Epoch [  4]: Loss 0.29342
Epoch [  4]: Loss 0.27501
Epoch [  4]: Loss 0.27018
Epoch [  4]: Loss 0.25193
Epoch [  4]: Loss 0.25523
Epoch [  4]: Loss 0.23191
Epoch [  4]: Loss 0.20409
Validation: Loss 0.20976 Accuracy 1.00000
Validation: Loss 0.22603 Accuracy 1.00000
Epoch [  5]: Loss 0.22495
Epoch [  5]: Loss 0.20204
Epoch [  5]: Loss 0.20478
Epoch [  5]: Loss 0.19399
Epoch [  5]: Loss 0.18426
Epoch [  5]: Loss 0.17860
Epoch [  5]: Loss 0.15872
Validation: Loss 0.15488 Accuracy 1.00000
Validation: Loss 0.16980 Accuracy 1.00000
Epoch [  6]: Loss 0.16499
Epoch [  6]: Loss 0.15829
Epoch [  6]: Loss 0.13559
Epoch [  6]: Loss 0.15194
Epoch [  6]: Loss 0.14243
Epoch [  6]: Loss 0.12202
Epoch [  6]: Loss 0.14704
Validation: Loss 0.11319 Accuracy 1.00000
Validation: Loss 0.12540 Accuracy 1.00000
Epoch [  7]: Loss 0.12147
Epoch [  7]: Loss 0.12116
Epoch [  7]: Loss 0.10334
Epoch [  7]: Loss 0.09850
Epoch [  7]: Loss 0.10009
Epoch [  7]: Loss 0.09603
Epoch [  7]: Loss 0.09994
Validation: Loss 0.08091 Accuracy 1.00000
Validation: Loss 0.08983 Accuracy 1.00000
Epoch [  8]: Loss 0.08312
Epoch [  8]: Loss 0.09099
Epoch [  8]: Loss 0.07382
Epoch [  8]: Loss 0.07461
Epoch [  8]: Loss 0.06279
Epoch [  8]: Loss 0.06942
Epoch [  8]: Loss 0.06492
Validation: Loss 0.05633 Accuracy 1.00000
Validation: Loss 0.06219 Accuracy 1.00000
Epoch [  9]: Loss 0.06049
Epoch [  9]: Loss 0.05702
Epoch [  9]: Loss 0.05282
Epoch [  9]: Loss 0.05185
Epoch [  9]: Loss 0.05046
Epoch [  9]: Loss 0.04714
Epoch [  9]: Loss 0.04247
Validation: Loss 0.04164 Accuracy 1.00000
Validation: Loss 0.04569 Accuracy 1.00000
Epoch [ 10]: Loss 0.04766
Epoch [ 10]: Loss 0.04016
Epoch [ 10]: Loss 0.03854
Epoch [ 10]: Loss 0.03958
Epoch [ 10]: Loss 0.03905
Epoch [ 10]: Loss 0.03783
Epoch [ 10]: Loss 0.03740
Validation: Loss 0.03355 Accuracy 1.00000
Validation: Loss 0.03683 Accuracy 1.00000
Epoch [ 11]: Loss 0.03473
Epoch [ 11]: Loss 0.03632
Epoch [ 11]: Loss 0.03378
Epoch [ 11]: Loss 0.03240
Epoch [ 11]: Loss 0.03268
Epoch [ 11]: Loss 0.02964
Epoch [ 11]: Loss 0.03159
Validation: Loss 0.02836 Accuracy 1.00000
Validation: Loss 0.03120 Accuracy 1.00000
Epoch [ 12]: Loss 0.03158
Epoch [ 12]: Loss 0.03068
Epoch [ 12]: Loss 0.02767
Epoch [ 12]: Loss 0.02755
Epoch [ 12]: Loss 0.02723
Epoch [ 12]: Loss 0.02627
Epoch [ 12]: Loss 0.02643
Validation: Loss 0.02461 Accuracy 1.00000
Validation: Loss 0.02713 Accuracy 1.00000
Epoch [ 13]: Loss 0.02790
Epoch [ 13]: Loss 0.02456
Epoch [ 13]: Loss 0.02478
Epoch [ 13]: Loss 0.02604
Epoch [ 13]: Loss 0.02477
Epoch [ 13]: Loss 0.02167
Epoch [ 13]: Loss 0.02249
Validation: Loss 0.02173 Accuracy 1.00000
Validation: Loss 0.02400 Accuracy 1.00000
Epoch [ 14]: Loss 0.02234
Epoch [ 14]: Loss 0.02312
Epoch [ 14]: Loss 0.02333
Epoch [ 14]: Loss 0.01972
Epoch [ 14]: Loss 0.02224
Epoch [ 14]: Loss 0.02209
Epoch [ 14]: Loss 0.02009
Validation: Loss 0.01943 Accuracy 1.00000
Validation: Loss 0.02150 Accuracy 1.00000
Epoch [ 15]: Loss 0.02052
Epoch [ 15]: Loss 0.01968
Epoch [ 15]: Loss 0.01956
Epoch [ 15]: Loss 0.01926
Epoch [ 15]: Loss 0.02073
Epoch [ 15]: Loss 0.01973
Epoch [ 15]: Loss 0.01741
Validation: Loss 0.01752 Accuracy 1.00000
Validation: Loss 0.01943 Accuracy 1.00000
Epoch [ 16]: Loss 0.01952
Epoch [ 16]: Loss 0.01795
Epoch [ 16]: Loss 0.01999
Epoch [ 16]: Loss 0.01662
Epoch [ 16]: Loss 0.01575
Epoch [ 16]: Loss 0.01730
Epoch [ 16]: Loss 0.02022
Validation: Loss 0.01590 Accuracy 1.00000
Validation: Loss 0.01768 Accuracy 1.00000
Epoch [ 17]: Loss 0.01623
Epoch [ 17]: Loss 0.01831
Epoch [ 17]: Loss 0.01533
Epoch [ 17]: Loss 0.01709
Epoch [ 17]: Loss 0.01673
Epoch [ 17]: Loss 0.01474
Epoch [ 17]: Loss 0.01500
Validation: Loss 0.01452 Accuracy 1.00000
Validation: Loss 0.01617 Accuracy 1.00000
Epoch [ 18]: Loss 0.01582
Epoch [ 18]: Loss 0.01593
Epoch [ 18]: Loss 0.01487
Epoch [ 18]: Loss 0.01555
Epoch [ 18]: Loss 0.01493
Epoch [ 18]: Loss 0.01323
Epoch [ 18]: Loss 0.01321
Validation: Loss 0.01334 Accuracy 1.00000
Validation: Loss 0.01488 Accuracy 1.00000
Epoch [ 19]: Loss 0.01348
Epoch [ 19]: Loss 0.01426
Epoch [ 19]: Loss 0.01277
Epoch [ 19]: Loss 0.01448
Epoch [ 19]: Loss 0.01456
Epoch [ 19]: Loss 0.01285
Epoch [ 19]: Loss 0.01529
Validation: Loss 0.01232 Accuracy 1.00000
Validation: Loss 0.01377 Accuracy 1.00000
Epoch [ 20]: Loss 0.01322
Epoch [ 20]: Loss 0.01219
Epoch [ 20]: Loss 0.01256
Epoch [ 20]: Loss 0.01366
Epoch [ 20]: Loss 0.01241
Epoch [ 20]: Loss 0.01289
Epoch [ 20]: Loss 0.01181
Validation: Loss 0.01143 Accuracy 1.00000
Validation: Loss 0.01278 Accuracy 1.00000
Epoch [ 21]: Loss 0.01226
Epoch [ 21]: Loss 0.01272
Epoch [ 21]: Loss 0.01131
Epoch [ 21]: Loss 0.01084
Epoch [ 21]: Loss 0.01234
Epoch [ 21]: Loss 0.01221
Epoch [ 21]: Loss 0.01008
Validation: Loss 0.01062 Accuracy 1.00000
Validation: Loss 0.01189 Accuracy 1.00000
Epoch [ 22]: Loss 0.00985
Epoch [ 22]: Loss 0.01096
Epoch [ 22]: Loss 0.01161
Epoch [ 22]: Loss 0.01149
Epoch [ 22]: Loss 0.01158
Epoch [ 22]: Loss 0.01009
Epoch [ 22]: Loss 0.01356
Validation: Loss 0.00987 Accuracy 1.00000
Validation: Loss 0.01104 Accuracy 1.00000
Epoch [ 23]: Loss 0.01111
Epoch [ 23]: Loss 0.01138
Epoch [ 23]: Loss 0.00954
Epoch [ 23]: Loss 0.01012
Epoch [ 23]: Loss 0.00952
Epoch [ 23]: Loss 0.00941
Epoch [ 23]: Loss 0.01128
Validation: Loss 0.00906 Accuracy 1.00000
Validation: Loss 0.01012 Accuracy 1.00000
Epoch [ 24]: Loss 0.00973
Epoch [ 24]: Loss 0.00976
Epoch [ 24]: Loss 0.00935
Epoch [ 24]: Loss 0.00869
Epoch [ 24]: Loss 0.00934
Epoch [ 24]: Loss 0.00861
Epoch [ 24]: Loss 0.01079
Validation: Loss 0.00815 Accuracy 1.00000
Validation: Loss 0.00908 Accuracy 1.00000
Epoch [ 25]: Loss 0.00906
Epoch [ 25]: Loss 0.00833
Epoch [ 25]: Loss 0.00850
Epoch [ 25]: Loss 0.00768
Epoch [ 25]: Loss 0.00831
Epoch [ 25]: Loss 0.00794
Epoch [ 25]: Loss 0.00836
Validation: Loss 0.00726 Accuracy 1.00000
Validation: Loss 0.00805 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.62411
Epoch [  1]: Loss 0.60250
Epoch [  1]: Loss 0.55918
Epoch [  1]: Loss 0.53302
Epoch [  1]: Loss 0.52124
Epoch [  1]: Loss 0.49859
Epoch [  1]: Loss 0.48319
Validation: Loss 0.47542 Accuracy 1.00000
Validation: Loss 0.46135 Accuracy 1.00000
Epoch [  2]: Loss 0.46626
Epoch [  2]: Loss 0.45230
Epoch [  2]: Loss 0.43965
Epoch [  2]: Loss 0.42906
Epoch [  2]: Loss 0.40814
Epoch [  2]: Loss 0.39873
Epoch [  2]: Loss 0.37876
Validation: Loss 0.37879 Accuracy 1.00000
Validation: Loss 0.36238 Accuracy 1.00000
Epoch [  3]: Loss 0.37277
Epoch [  3]: Loss 0.35597
Epoch [  3]: Loss 0.35464
Epoch [  3]: Loss 0.32969
Epoch [  3]: Loss 0.30974
Epoch [  3]: Loss 0.30793
Epoch [  3]: Loss 0.28532
Validation: Loss 0.29386 Accuracy 1.00000
Validation: Loss 0.27579 Accuracy 1.00000
Epoch [  4]: Loss 0.28785
Epoch [  4]: Loss 0.28319
Epoch [  4]: Loss 0.26330
Epoch [  4]: Loss 0.23880
Epoch [  4]: Loss 0.22931
Epoch [  4]: Loss 0.24417
Epoch [  4]: Loss 0.22439
Validation: Loss 0.22371 Accuracy 1.00000
Validation: Loss 0.20525 Accuracy 1.00000
Epoch [  5]: Loss 0.21265
Epoch [  5]: Loss 0.20872
Epoch [  5]: Loss 0.20606
Epoch [  5]: Loss 0.17968
Epoch [  5]: Loss 0.17798
Epoch [  5]: Loss 0.17276
Epoch [  5]: Loss 0.16496
Validation: Loss 0.16736 Accuracy 1.00000
Validation: Loss 0.15021 Accuracy 1.00000
Epoch [  6]: Loss 0.15522
Epoch [  6]: Loss 0.15093
Epoch [  6]: Loss 0.14203
Epoch [  6]: Loss 0.14152
Epoch [  6]: Loss 0.12858
Epoch [  6]: Loss 0.13211
Epoch [  6]: Loss 0.12994
Validation: Loss 0.12318 Accuracy 1.00000
Validation: Loss 0.10892 Accuracy 1.00000
Epoch [  7]: Loss 0.11197
Epoch [  7]: Loss 0.11679
Epoch [  7]: Loss 0.10388
Epoch [  7]: Loss 0.11171
Epoch [  7]: Loss 0.08685
Epoch [  7]: Loss 0.08763
Epoch [  7]: Loss 0.09038
Validation: Loss 0.08800 Accuracy 1.00000
Validation: Loss 0.07745 Accuracy 1.00000
Epoch [  8]: Loss 0.08216
Epoch [  8]: Loss 0.08446
Epoch [  8]: Loss 0.07212
Epoch [  8]: Loss 0.07275
Epoch [  8]: Loss 0.06664
Epoch [  8]: Loss 0.06006
Epoch [  8]: Loss 0.05861
Validation: Loss 0.06108 Accuracy 1.00000
Validation: Loss 0.05411 Accuracy 1.00000
Epoch [  9]: Loss 0.05673
Epoch [  9]: Loss 0.05319
Epoch [  9]: Loss 0.05019
Epoch [  9]: Loss 0.05022
Epoch [  9]: Loss 0.05087
Epoch [  9]: Loss 0.04814
Epoch [  9]: Loss 0.03889
Validation: Loss 0.04542 Accuracy 1.00000
Validation: Loss 0.04053 Accuracy 1.00000
Epoch [ 10]: Loss 0.04298
Epoch [ 10]: Loss 0.04091
Epoch [ 10]: Loss 0.03959
Epoch [ 10]: Loss 0.03840
Epoch [ 10]: Loss 0.03912
Epoch [ 10]: Loss 0.03804
Epoch [ 10]: Loss 0.03007
Validation: Loss 0.03686 Accuracy 1.00000
Validation: Loss 0.03287 Accuracy 1.00000
Epoch [ 11]: Loss 0.03375
Epoch [ 11]: Loss 0.03373
Epoch [ 11]: Loss 0.03673
Epoch [ 11]: Loss 0.03162
Epoch [ 11]: Loss 0.03012
Epoch [ 11]: Loss 0.03026
Epoch [ 11]: Loss 0.03130
Validation: Loss 0.03139 Accuracy 1.00000
Validation: Loss 0.02790 Accuracy 1.00000
Epoch [ 12]: Loss 0.03066
Epoch [ 12]: Loss 0.02882
Epoch [ 12]: Loss 0.03109
Epoch [ 12]: Loss 0.02646
Epoch [ 12]: Loss 0.02504
Epoch [ 12]: Loss 0.02637
Epoch [ 12]: Loss 0.02725
Validation: Loss 0.02736 Accuracy 1.00000
Validation: Loss 0.02426 Accuracy 1.00000
Epoch [ 13]: Loss 0.02541
Epoch [ 13]: Loss 0.02753
Epoch [ 13]: Loss 0.02341
Epoch [ 13]: Loss 0.02444
Epoch [ 13]: Loss 0.02363
Epoch [ 13]: Loss 0.02384
Epoch [ 13]: Loss 0.02080
Validation: Loss 0.02423 Accuracy 1.00000
Validation: Loss 0.02143 Accuracy 1.00000
Epoch [ 14]: Loss 0.02227
Epoch [ 14]: Loss 0.02210
Epoch [ 14]: Loss 0.02270
Epoch [ 14]: Loss 0.02215
Epoch [ 14]: Loss 0.02180
Epoch [ 14]: Loss 0.02108
Epoch [ 14]: Loss 0.01711
Validation: Loss 0.02173 Accuracy 1.00000
Validation: Loss 0.01916 Accuracy 1.00000
Epoch [ 15]: Loss 0.02045
Epoch [ 15]: Loss 0.01908
Epoch [ 15]: Loss 0.02023
Epoch [ 15]: Loss 0.02046
Epoch [ 15]: Loss 0.01981
Epoch [ 15]: Loss 0.01799
Epoch [ 15]: Loss 0.01803
Validation: Loss 0.01965 Accuracy 1.00000
Validation: Loss 0.01728 Accuracy 1.00000
Epoch [ 16]: Loss 0.01815
Epoch [ 16]: Loss 0.01757
Epoch [ 16]: Loss 0.01568
Epoch [ 16]: Loss 0.01723
Epoch [ 16]: Loss 0.01976
Epoch [ 16]: Loss 0.01790
Epoch [ 16]: Loss 0.01833
Validation: Loss 0.01790 Accuracy 1.00000
Validation: Loss 0.01569 Accuracy 1.00000
Epoch [ 17]: Loss 0.01838
Epoch [ 17]: Loss 0.01613
Epoch [ 17]: Loss 0.01591
Epoch [ 17]: Loss 0.01555
Epoch [ 17]: Loss 0.01580
Epoch [ 17]: Loss 0.01546
Epoch [ 17]: Loss 0.01555
Validation: Loss 0.01638 Accuracy 1.00000
Validation: Loss 0.01432 Accuracy 1.00000
Epoch [ 18]: Loss 0.01492
Epoch [ 18]: Loss 0.01474
Epoch [ 18]: Loss 0.01587
Epoch [ 18]: Loss 0.01481
Epoch [ 18]: Loss 0.01435
Epoch [ 18]: Loss 0.01475
Epoch [ 18]: Loss 0.01265
Validation: Loss 0.01508 Accuracy 1.00000
Validation: Loss 0.01315 Accuracy 1.00000
Epoch [ 19]: Loss 0.01359
Epoch [ 19]: Loss 0.01383
Epoch [ 19]: Loss 0.01516
Epoch [ 19]: Loss 0.01326
Epoch [ 19]: Loss 0.01442
Epoch [ 19]: Loss 0.01206
Epoch [ 19]: Loss 0.01218
Validation: Loss 0.01394 Accuracy 1.00000
Validation: Loss 0.01215 Accuracy 1.00000
Epoch [ 20]: Loss 0.01255
Epoch [ 20]: Loss 0.01141
Epoch [ 20]: Loss 0.01402
Epoch [ 20]: Loss 0.01277
Epoch [ 20]: Loss 0.01201
Epoch [ 20]: Loss 0.01296
Epoch [ 20]: Loss 0.01272
Validation: Loss 0.01292 Accuracy 1.00000
Validation: Loss 0.01124 Accuracy 1.00000
Epoch [ 21]: Loss 0.01316
Epoch [ 21]: Loss 0.01102
Epoch [ 21]: Loss 0.01137
Epoch [ 21]: Loss 0.01191
Epoch [ 21]: Loss 0.01167
Epoch [ 21]: Loss 0.01144
Epoch [ 21]: Loss 0.01001
Validation: Loss 0.01192 Accuracy 1.00000
Validation: Loss 0.01038 Accuracy 1.00000
Epoch [ 22]: Loss 0.01180
Epoch [ 22]: Loss 0.01156
Epoch [ 22]: Loss 0.01143
Epoch [ 22]: Loss 0.01034
Epoch [ 22]: Loss 0.01000
Epoch [ 22]: Loss 0.00935
Epoch [ 22]: Loss 0.01109
Validation: Loss 0.01086 Accuracy 1.00000
Validation: Loss 0.00946 Accuracy 1.00000
Epoch [ 23]: Loss 0.00920
Epoch [ 23]: Loss 0.01079
Epoch [ 23]: Loss 0.01024
Epoch [ 23]: Loss 0.01034
Epoch [ 23]: Loss 0.00892
Epoch [ 23]: Loss 0.00905
Epoch [ 23]: Loss 0.00887
Validation: Loss 0.00966 Accuracy 1.00000
Validation: Loss 0.00845 Accuracy 1.00000
Epoch [ 24]: Loss 0.00932
Epoch [ 24]: Loss 0.00890
Epoch [ 24]: Loss 0.00932
Epoch [ 24]: Loss 0.00806
Epoch [ 24]: Loss 0.00886
Epoch [ 24]: Loss 0.00743
Epoch [ 24]: Loss 0.00850
Validation: Loss 0.00857 Accuracy 1.00000
Validation: Loss 0.00754 Accuracy 1.00000
Epoch [ 25]: Loss 0.00735
Epoch [ 25]: Loss 0.00781
Epoch [ 25]: Loss 0.00812
Epoch [ 25]: Loss 0.00814
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00772
Epoch [ 25]: Loss 0.00692
Validation: Loss 0.00779 Accuracy 1.00000
Validation: Loss 0.00688 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
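If you prefer not to add a dependency, a similar round trip can be sketched with the standard-library Serialization module (note that, unlike JLD2, its on-disk format is not guaranteed to be stable across Julia versions):

```julia
using Serialization

ps = (weight=rand(Float32, 2, 2), bias=zeros(Float32, 2))  # stand-in parameters
serialize("params_demo.bin", ps)
ps_loaded = deserialize("params_demo.bin")
@assert ps_loaded.weight == ps.weight  # round trip preserves the values
rm("params_demo.bin")  # clean up the temporary file
```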

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)

This page was generated using Literate.jl.