
Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
Precompiling Lux...
    399.0 ms  ✓ Future
    331.7 ms  ✓ CEnum
    493.0 ms  ✓ Statistics
    326.3 ms  ✓ Reexport
    362.5 ms  ✓ Zlib_jll
    364.8 ms  ✓ StaticArraysCore
    379.3 ms  ✓ DiffResults
    477.9 ms  ✓ ArrayInterface → ArrayInterfaceStaticArraysCoreExt
   2138.5 ms  ✓ Hwloc
   2644.5 ms  ✓ WeightInitializers
   1524.4 ms  ✓ Setfield
    954.3 ms  ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
    450.4 ms  ✓ LuxCore → LuxCoreSetfieldExt
   3637.7 ms  ✓ ForwardDiff
   7258.5 ms  ✓ StaticArrays
    598.4 ms  ✓ StaticArrays → StaticArraysStatisticsExt
    620.1 ms  ✓ StaticArrays → StaticArraysChainRulesCoreExt
    630.7 ms  ✓ ConstructionBase → ConstructionBaseStaticArraysExt
    595.9 ms  ✓ Adapt → AdaptStaticArraysExt
    664.9 ms  ✓ StaticArrayInterface → StaticArrayInterfaceStaticArraysExt
    894.4 ms  ✓ ForwardDiff → ForwardDiffStaticArraysExt
   4009.2 ms  ✓ KernelAbstractions
    662.3 ms  ✓ KernelAbstractions → LinearAlgebraExt
    717.4 ms  ✓ KernelAbstractions → EnzymeExt
   5199.2 ms  ✓ NNlib
    865.2 ms  ✓ NNlib → NNlibEnzymeCoreExt
    946.1 ms  ✓ NNlib → NNlibForwardDiffExt
   5848.9 ms  ✓ LuxLib
   9170.3 ms  ✓ Lux
  29 dependencies successfully precompiled in 37 seconds. 91 already precompiled.
Precompiling LuxCUDA...
    294.6 ms  ✓ IteratorInterfaceExtensions
    351.9 ms  ✓ ExprTools
    538.2 ms  ✓ AbstractFFTs
    419.3 ms  ✓ SuiteSparse_jll
    549.9 ms  ✓ Serialization
    474.7 ms  ✓ OrderedCollections
    288.2 ms  ✓ DataValueInterfaces
    347.0 ms  ✓ DataAPI
    326.5 ms  ✓ TableTraits
   2299.1 ms  ✓ FixedPointNumbers
   2726.5 ms  ✓ TimerOutputs
   3750.6 ms  ✓ SparseArrays
   6688.7 ms  ✓ LLVM
    460.6 ms  ✓ PooledArrays
   3990.7 ms  ✓ Test
    434.2 ms  ✓ Missings
   1744.7 ms  ✓ DataStructures
    860.1 ms  ✓ Tables
    655.7 ms  ✓ Statistics → SparseArraysExt
    958.8 ms  ✓ KernelAbstractions → SparseArraysExt
   2273.5 ms  ✓ ColorTypes
   1894.9 ms  ✓ UnsafeAtomics → UnsafeAtomicsLLVM
   2159.9 ms  ✓ GPUArrays
   1333.5 ms  ✓ AbstractFFTs → AbstractFFTsTestExt
    504.7 ms  ✓ SortingAlgorithms
   1343.5 ms  ✓ LLVM → BFloat16sExt
   4418.7 ms  ✓ Colors
   1388.4 ms  ✓ NVTX
  20641.1 ms  ✓ PrettyTables
  27790.8 ms  ✓ GPUCompiler
  47430.6 ms  ✓ DataFrames
  51963.9 ms  ✓ CUDA
   5130.2 ms  ✓ Atomix → AtomixCUDAExt
   8690.9 ms  ✓ cuDNN
   5936.7 ms  ✓ LuxCUDA
  35 dependencies successfully precompiled in 154 seconds. 65 already precompiled.
Precompiling MLDataDevicesGPUArraysExt...
   1469.2 ms  ✓ MLDataDevices → MLDataDevicesGPUArraysExt
  1 dependency successfully precompiled in 2 seconds. 42 already precompiled.
Precompiling WeightInitializersGPUArraysExt...
   1558.6 ms  ✓ WeightInitializers → WeightInitializersGPUArraysExt
  1 dependency successfully precompiled in 2 seconds. 46 already precompiled.
Precompiling ArrayInterfaceSparseArraysExt...
    590.3 ms  ✓ ArrayInterface → ArrayInterfaceSparseArraysExt
  1 dependency successfully precompiled in 1 seconds. 7 already precompiled.
Precompiling ChainRulesCoreSparseArraysExt...
    643.5 ms  ✓ ChainRulesCore → ChainRulesCoreSparseArraysExt
  1 dependency successfully precompiled in 1 seconds. 11 already precompiled.
Precompiling MLDataDevicesSparseArraysExt...
    647.4 ms  ✓ MLDataDevices → MLDataDevicesSparseArraysExt
  1 dependency successfully precompiled in 1 seconds. 17 already precompiled.
Precompiling AbstractFFTsChainRulesCoreExt...
    429.9 ms  ✓ AbstractFFTs → AbstractFFTsChainRulesCoreExt
  1 dependency successfully precompiled in 0 seconds. 9 already precompiled.
Precompiling ArrayInterfaceCUDAExt...
   5012.5 ms  ✓ ArrayInterface → ArrayInterfaceCUDAExt
  1 dependency successfully precompiled in 5 seconds. 101 already precompiled.
Precompiling NNlibCUDAExt...
   5067.2 ms  ✓ CUDA → ChainRulesCoreExt
   5372.8 ms  ✓ NNlib → NNlibCUDAExt
  2 dependencies successfully precompiled in 6 seconds. 102 already precompiled.
Precompiling MLDataDevicesCUDAExt...
   5194.5 ms  ✓ MLDataDevices → MLDataDevicesCUDAExt
  1 dependency successfully precompiled in 6 seconds. 104 already precompiled.
Precompiling LuxLibCUDAExt...
   5269.6 ms  ✓ CUDA → SpecialFunctionsExt
   5322.2 ms  ✓ CUDA → EnzymeCoreExt
   5864.1 ms  ✓ LuxLib → LuxLibCUDAExt
  3 dependencies successfully precompiled in 6 seconds. 167 already precompiled.
Precompiling WeightInitializersCUDAExt...
   5090.8 ms  ✓ WeightInitializers → WeightInitializersCUDAExt
  1 dependency successfully precompiled in 5 seconds. 109 already precompiled.
Precompiling NNlibCUDACUDNNExt...
   5660.9 ms  ✓ NNlib → NNlibCUDACUDNNExt
  1 dependency successfully precompiled in 6 seconds. 106 already precompiled.
Precompiling MLDataDevicescuDNNExt...
   5046.1 ms  ✓ MLDataDevices → MLDataDevicescuDNNExt
  1 dependency successfully precompiled in 5 seconds. 107 already precompiled.
Precompiling LuxLibcuDNNExt...
   5806.6 ms  ✓ LuxLib → LuxLibcuDNNExt
  1 dependency successfully precompiled in 6 seconds. 174 already precompiled.
Precompiling JLD2...
  33796.3 ms  ✓ JLD2
  1 dependency successfully precompiled in 34 seconds. 31 already precompiled.
Precompiling MLUtils...
    646.5 ms  ✓ Accessors → AccessorsTestExt
   1194.8 ms  ✓ SplittablesBase
    607.4 ms  ✓ InverseFunctions → InverseFunctionsTestExt
    541.1 ms  ✓ BangBang → BangBangTablesExt
    759.5 ms  ✓ BangBang → BangBangStaticArraysExt
   2044.6 ms  ✓ Distributed
    743.7 ms  ✓ Accessors → AccessorsStaticArraysExt
   2400.3 ms  ✓ StatsBase
   2887.4 ms  ✓ Transducers
    728.0 ms  ✓ Transducers → TransducersAdaptExt
   4996.4 ms  ✓ FLoops
   6171.6 ms  ✓ MLUtils
  12 dependencies successfully precompiled in 16 seconds. 86 already precompiled.
Precompiling BangBangDataFramesExt...
   1668.9 ms  ✓ BangBang → BangBangDataFramesExt
  1 dependency successfully precompiled in 2 seconds. 46 already precompiled.
Precompiling TransducersDataFramesExt...
   1398.5 ms  ✓ Transducers → TransducersDataFramesExt
  1 dependency successfully precompiled in 2 seconds. 61 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
   1609.5 ms  ✓ MLDataDevices → MLDataDevicesMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 102 already precompiled.
Precompiling LuxMLUtilsExt...
   2423.1 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 3 seconds. 177 already precompiled.
Precompiling Zygote...
    398.2 ms  ✓ FillArrays → FillArraysStatisticsExt
    589.8 ms  ✓ SuiteSparse
    669.2 ms  ✓ FillArrays → FillArraysSparseArraysExt
    746.3 ms  ✓ StructArrays
    602.1 ms  ✓ SparseInverseSubset
    414.6 ms  ✓ StructArrays → StructArraysAdaptExt
    399.2 ms  ✓ StructArrays → StructArraysGPUArraysCoreExt
    723.8 ms  ✓ StructArrays → StructArraysSparseArraysExt
   5291.4 ms  ✓ ChainRules
  34392.4 ms  ✓ Zygote
  10 dependencies successfully precompiled in 42 seconds. 76 already precompiled.
Precompiling AccessorsStructArraysExt...
    469.5 ms  ✓ Accessors → AccessorsStructArraysExt
  1 dependency successfully precompiled in 1 seconds. 16 already precompiled.
Precompiling BangBangStructArraysExt...
    496.1 ms  ✓ BangBang → BangBangStructArraysExt
  1 dependency successfully precompiled in 1 seconds. 22 already precompiled.
Precompiling StructArraysStaticArraysExt...
    655.8 ms  ✓ StructArrays → StructArraysStaticArraysExt
  1 dependency successfully precompiled in 1 seconds. 18 already precompiled.
Precompiling ArrayInterfaceChainRulesExt...
    823.9 ms  ✓ ArrayInterface → ArrayInterfaceChainRulesExt
  1 dependency successfully precompiled in 1 seconds. 39 already precompiled.
Precompiling MLDataDevicesChainRulesExt...
    856.4 ms  ✓ MLDataDevices → MLDataDevicesChainRulesExt
  1 dependency successfully precompiled in 1 seconds. 40 already precompiled.
Precompiling MLDataDevicesZygoteExt...
   1656.4 ms  ✓ MLDataDevices → MLDataDevicesZygoteExt
  1 dependency successfully precompiled in 2 seconds. 93 already precompiled.
Precompiling LuxZygoteExt...
   2881.0 ms  ✓ Lux → LuxZygoteExt
  1 dependency successfully precompiled in 3 seconds. 163 already precompiled.
Precompiling ZygoteColorsExt...
   1830.7 ms  ✓ Zygote → ZygoteColorsExt
  1 dependency successfully precompiled in 2 seconds. 89 already precompiled.

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us batches of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary label indicating whether it is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
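Before building the model, it helps to see this batch layout concretely. The following standalone sketch uses a plain Julia array with hypothetical sizes (2 features, 50 time steps, batch of 128, matching the defaults above) to show how a sequence batch decomposes into per-time-step slices:

```julia
# A batch shaped (features, sequence_length, batch_size), as produced by the dataloader.
x = rand(Float32, 2, 50, 128)

# Slicing along dims=2 walks the sequence one time step at a time,
# yielding 50 matrices of size 2 × 128 (one column per batch element):
steps = collect(eachslice(x; dims=2))
@assert length(steps) == 50
@assert size(first(steps)) == (2, 128)
```

The model below iterates over exactly these time-step slices when unrolling the LSTM.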

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the layers from scratch, but rather compose the model from the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##230".SpiralClassifier

We could use the built-in Recurrence(LSTMCell(in_dims => hidden_dims)) block instead of defining the forward pass below, but we will write it out manually for illustration.
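For reference, a minimal sketch of that built-in alternative, assuming Lux is loaded (by default Recurrence returns only the output at the final time step, which is exactly what the classifier head needs):

```julia
using Lux, Random

# Equivalent model built from stock blocks: Recurrence unrolls the LSTMCell
# over the sequence dimension, and Dense maps the final hidden output to a
# probability.
model = Chain(
    Recurrence(LSTMCell(2 => 8)),
    Dense(8 => 1, sigmoid),
)

ps, st = Lux.setup(Xoshiro(0), model)
y, _ = model(rand(Float32, 2, 50, 3), ps, st)   # one probability per batch element
```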

Now we need to define the behavior of the classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
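A hypothetical smoke test of the compact model, assuming Lux is loaded and SpiralClassifierCompact is defined as above (the sizes are made up for illustration):

```julia
using Lux, Random

# 4 spirals, each with 2 features and 50 time steps.
model = SpiralClassifierCompact(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)

x = rand(Float32, 2, 50, 4)
y, st_new = model(x, ps, st)
@assert length(y) == 4   # vec(...) in @return yields one probability per sequence
```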

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically it is recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
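As a quick sanity check of the metric, here is a standalone sketch with made-up predictions (the helper definitions are repeated so the snippet runs on its own):

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = [0.9f0, 0.2f0, 0.7f0, 0.4f0]   # predicted probabilities
y_true = [1.0f0, 0.0f0, 0.0f0, 0.0f0]   # ground-truth labels

# Thresholding at 0.5 gives [1, 0, 1, 0]; three of four match the labels.
accuracy(y_pred, y_true)   # 0.75
```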

Training the Model

julia
function main(model_type)
    dev = gpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.61762
Epoch [  1]: Loss 0.60114
Epoch [  1]: Loss 0.56895
Epoch [  1]: Loss 0.53797
Epoch [  1]: Loss 0.51823
Epoch [  1]: Loss 0.49943
Epoch [  1]: Loss 0.49180
Validation: Loss 0.46457 Accuracy 1.00000
Validation: Loss 0.46646 Accuracy 1.00000
Epoch [  2]: Loss 0.47841
Epoch [  2]: Loss 0.45443
Epoch [  2]: Loss 0.43614
Epoch [  2]: Loss 0.42116
Epoch [  2]: Loss 0.40509
Epoch [  2]: Loss 0.40059
Epoch [  2]: Loss 0.40221
Validation: Loss 0.36677 Accuracy 1.00000
Validation: Loss 0.36869 Accuracy 1.00000
Epoch [  3]: Loss 0.36301
Epoch [  3]: Loss 0.36096
Epoch [  3]: Loss 0.35169
Epoch [  3]: Loss 0.33348
Epoch [  3]: Loss 0.32519
Epoch [  3]: Loss 0.30763
Epoch [  3]: Loss 0.29216
Validation: Loss 0.28121 Accuracy 1.00000
Validation: Loss 0.28338 Accuracy 1.00000
Epoch [  4]: Loss 0.28277
Epoch [  4]: Loss 0.27310
Epoch [  4]: Loss 0.26352
Epoch [  4]: Loss 0.26458
Epoch [  4]: Loss 0.24668
Epoch [  4]: Loss 0.22651
Epoch [  4]: Loss 0.24597
Validation: Loss 0.21097 Accuracy 1.00000
Validation: Loss 0.21336 Accuracy 1.00000
Epoch [  5]: Loss 0.21921
Epoch [  5]: Loss 0.21288
Epoch [  5]: Loss 0.20089
Epoch [  5]: Loss 0.18622
Epoch [  5]: Loss 0.18467
Epoch [  5]: Loss 0.17351
Epoch [  5]: Loss 0.15245
Validation: Loss 0.15547 Accuracy 1.00000
Validation: Loss 0.15796 Accuracy 1.00000
Epoch [  6]: Loss 0.14904
Epoch [  6]: Loss 0.16017
Epoch [  6]: Loss 0.14675
Epoch [  6]: Loss 0.13241
Epoch [  6]: Loss 0.13204
Epoch [  6]: Loss 0.13922
Epoch [  6]: Loss 0.15164
Validation: Loss 0.11348 Accuracy 1.00000
Validation: Loss 0.11579 Accuracy 1.00000
Epoch [  7]: Loss 0.13264
Epoch [  7]: Loss 0.10320
Epoch [  7]: Loss 0.11710
Epoch [  7]: Loss 0.09475
Epoch [  7]: Loss 0.09289
Epoch [  7]: Loss 0.09447
Epoch [  7]: Loss 0.08533
Validation: Loss 0.08112 Accuracy 1.00000
Validation: Loss 0.08292 Accuracy 1.00000
Epoch [  8]: Loss 0.08555
Epoch [  8]: Loss 0.07707
Epoch [  8]: Loss 0.08340
Epoch [  8]: Loss 0.07285
Epoch [  8]: Loss 0.06320
Epoch [  8]: Loss 0.06598
Epoch [  8]: Loss 0.06602
Validation: Loss 0.05668 Accuracy 1.00000
Validation: Loss 0.05789 Accuracy 1.00000
Epoch [  9]: Loss 0.05925
Epoch [  9]: Loss 0.05730
Epoch [  9]: Loss 0.05547
Epoch [  9]: Loss 0.04906
Epoch [  9]: Loss 0.04984
Epoch [  9]: Loss 0.04445
Epoch [  9]: Loss 0.04935
Validation: Loss 0.04204 Accuracy 1.00000
Validation: Loss 0.04284 Accuracy 1.00000
Epoch [ 10]: Loss 0.04265
Epoch [ 10]: Loss 0.04230
Epoch [ 10]: Loss 0.03959
Epoch [ 10]: Loss 0.04091
Epoch [ 10]: Loss 0.03845
Epoch [ 10]: Loss 0.03706
Epoch [ 10]: Loss 0.03961
Validation: Loss 0.03397 Accuracy 1.00000
Validation: Loss 0.03461 Accuracy 1.00000
Epoch [ 11]: Loss 0.03564
Epoch [ 11]: Loss 0.03114
Epoch [ 11]: Loss 0.03468
Epoch [ 11]: Loss 0.03287
Epoch [ 11]: Loss 0.03233
Epoch [ 11]: Loss 0.03160
Epoch [ 11]: Loss 0.03439
Validation: Loss 0.02878 Accuracy 1.00000
Validation: Loss 0.02934 Accuracy 1.00000
Epoch [ 12]: Loss 0.03170
Epoch [ 12]: Loss 0.02814
Epoch [ 12]: Loss 0.02901
Epoch [ 12]: Loss 0.02843
Epoch [ 12]: Loss 0.02558
Epoch [ 12]: Loss 0.02779
Epoch [ 12]: Loss 0.02728
Validation: Loss 0.02501 Accuracy 1.00000
Validation: Loss 0.02550 Accuracy 1.00000
Epoch [ 13]: Loss 0.02457
Epoch [ 13]: Loss 0.02605
Epoch [ 13]: Loss 0.02644
Epoch [ 13]: Loss 0.02472
Epoch [ 13]: Loss 0.02223
Epoch [ 13]: Loss 0.02430
Epoch [ 13]: Loss 0.02809
Validation: Loss 0.02211 Accuracy 1.00000
Validation: Loss 0.02255 Accuracy 1.00000
Epoch [ 14]: Loss 0.02093
Epoch [ 14]: Loss 0.02390
Epoch [ 14]: Loss 0.02233
Epoch [ 14]: Loss 0.02314
Epoch [ 14]: Loss 0.01939
Epoch [ 14]: Loss 0.02298
Epoch [ 14]: Loss 0.02155
Validation: Loss 0.01978 Accuracy 1.00000
Validation: Loss 0.02018 Accuracy 1.00000
Epoch [ 15]: Loss 0.02103
Epoch [ 15]: Loss 0.01929
Epoch [ 15]: Loss 0.01901
Epoch [ 15]: Loss 0.02047
Epoch [ 15]: Loss 0.02000
Epoch [ 15]: Loss 0.01989
Epoch [ 15]: Loss 0.01744
Validation: Loss 0.01785 Accuracy 1.00000
Validation: Loss 0.01822 Accuracy 1.00000
Epoch [ 16]: Loss 0.01937
Epoch [ 16]: Loss 0.01578
Epoch [ 16]: Loss 0.01914
Epoch [ 16]: Loss 0.01807
Epoch [ 16]: Loss 0.01926
Epoch [ 16]: Loss 0.01729
Epoch [ 16]: Loss 0.01395
Validation: Loss 0.01622 Accuracy 1.00000
Validation: Loss 0.01657 Accuracy 1.00000
Epoch [ 17]: Loss 0.01746
Epoch [ 17]: Loss 0.01738
Epoch [ 17]: Loss 0.01625
Epoch [ 17]: Loss 0.01603
Epoch [ 17]: Loss 0.01629
Epoch [ 17]: Loss 0.01549
Epoch [ 17]: Loss 0.01443
Validation: Loss 0.01483 Accuracy 1.00000
Validation: Loss 0.01515 Accuracy 1.00000
Epoch [ 18]: Loss 0.01657
Epoch [ 18]: Loss 0.01493
Epoch [ 18]: Loss 0.01632
Epoch [ 18]: Loss 0.01386
Epoch [ 18]: Loss 0.01495
Epoch [ 18]: Loss 0.01367
Epoch [ 18]: Loss 0.01487
Validation: Loss 0.01363 Accuracy 1.00000
Validation: Loss 0.01393 Accuracy 1.00000
Epoch [ 19]: Loss 0.01376
Epoch [ 19]: Loss 0.01388
Epoch [ 19]: Loss 0.01490
Epoch [ 19]: Loss 0.01303
Epoch [ 19]: Loss 0.01406
Epoch [ 19]: Loss 0.01311
Epoch [ 19]: Loss 0.01528
Validation: Loss 0.01258 Accuracy 1.00000
Validation: Loss 0.01286 Accuracy 1.00000
Epoch [ 20]: Loss 0.01283
Epoch [ 20]: Loss 0.01356
Epoch [ 20]: Loss 0.01267
Epoch [ 20]: Loss 0.01240
Epoch [ 20]: Loss 0.01299
Epoch [ 20]: Loss 0.01261
Epoch [ 20]: Loss 0.01173
Validation: Loss 0.01163 Accuracy 1.00000
Validation: Loss 0.01190 Accuracy 1.00000
Epoch [ 21]: Loss 0.01220
Epoch [ 21]: Loss 0.01300
Epoch [ 21]: Loss 0.01228
Epoch [ 21]: Loss 0.01217
Epoch [ 21]: Loss 0.01110
Epoch [ 21]: Loss 0.01087
Epoch [ 21]: Loss 0.00953
Validation: Loss 0.01072 Accuracy 1.00000
Validation: Loss 0.01097 Accuracy 1.00000
Epoch [ 22]: Loss 0.01038
Epoch [ 22]: Loss 0.01069
Epoch [ 22]: Loss 0.01024
Epoch [ 22]: Loss 0.01167
Epoch [ 22]: Loss 0.01092
Epoch [ 22]: Loss 0.01117
Epoch [ 22]: Loss 0.01109
Validation: Loss 0.00977 Accuracy 1.00000
Validation: Loss 0.01000 Accuracy 1.00000
Epoch [ 23]: Loss 0.00984
Epoch [ 23]: Loss 0.01037
Epoch [ 23]: Loss 0.00974
Epoch [ 23]: Loss 0.00957
Epoch [ 23]: Loss 0.01016
Epoch [ 23]: Loss 0.00937
Epoch [ 23]: Loss 0.00884
Validation: Loss 0.00870 Accuracy 1.00000
Validation: Loss 0.00890 Accuracy 1.00000
Epoch [ 24]: Loss 0.00944
Epoch [ 24]: Loss 0.00822
Epoch [ 24]: Loss 0.00928
Epoch [ 24]: Loss 0.00865
Epoch [ 24]: Loss 0.00774
Epoch [ 24]: Loss 0.00881
Epoch [ 24]: Loss 0.00875
Validation: Loss 0.00775 Accuracy 1.00000
Validation: Loss 0.00792 Accuracy 1.00000
Epoch [ 25]: Loss 0.00804
Epoch [ 25]: Loss 0.00770
Epoch [ 25]: Loss 0.00795
Epoch [ 25]: Loss 0.00783
Epoch [ 25]: Loss 0.00747
Epoch [ 25]: Loss 0.00796
Epoch [ 25]: Loss 0.00715
Validation: Loss 0.00707 Accuracy 1.00000
Validation: Loss 0.00722 Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.62686
Epoch [  1]: Loss 0.60238
Epoch [  1]: Loss 0.56560
Epoch [  1]: Loss 0.54008
Epoch [  1]: Loss 0.51887
Epoch [  1]: Loss 0.50023
Epoch [  1]: Loss 0.47535
Validation: Loss 0.46081 Accuracy 1.00000
Validation: Loss 0.46424 Accuracy 1.00000
Epoch [  2]: Loss 0.46765
Epoch [  2]: Loss 0.45766
Epoch [  2]: Loss 0.44123
Epoch [  2]: Loss 0.43015
Epoch [  2]: Loss 0.41003
Epoch [  2]: Loss 0.39269
Epoch [  2]: Loss 0.39828
Validation: Loss 0.36128 Accuracy 1.00000
Validation: Loss 0.36546 Accuracy 1.00000
Epoch [  3]: Loss 0.36435
Epoch [  3]: Loss 0.35550
Epoch [  3]: Loss 0.34530
Epoch [  3]: Loss 0.33401
Epoch [  3]: Loss 0.32831
Epoch [  3]: Loss 0.31457
Epoch [  3]: Loss 0.29039
Validation: Loss 0.27475 Accuracy 1.00000
Validation: Loss 0.27942 Accuracy 1.00000
Epoch [  4]: Loss 0.28417
Epoch [  4]: Loss 0.27505
Epoch [  4]: Loss 0.25364
Epoch [  4]: Loss 0.26571
Epoch [  4]: Loss 0.24322
Epoch [  4]: Loss 0.24181
Epoch [  4]: Loss 0.20724
Validation: Loss 0.20460 Accuracy 1.00000
Validation: Loss 0.20933 Accuracy 1.00000
Epoch [  5]: Loss 0.21660
Epoch [  5]: Loss 0.19363
Epoch [  5]: Loss 0.20366
Epoch [  5]: Loss 0.19392
Epoch [  5]: Loss 0.18396
Epoch [  5]: Loss 0.17375
Epoch [  5]: Loss 0.18596
Validation: Loss 0.15007 Accuracy 1.00000
Validation: Loss 0.15442 Accuracy 1.00000
Epoch [  6]: Loss 0.16112
Epoch [  6]: Loss 0.15323
Epoch [  6]: Loss 0.15114
Epoch [  6]: Loss 0.14477
Epoch [  6]: Loss 0.13492
Epoch [  6]: Loss 0.12403
Epoch [  6]: Loss 0.10891
Validation: Loss 0.10897 Accuracy 1.00000
Validation: Loss 0.11253 Accuracy 1.00000
Epoch [  7]: Loss 0.11893
Epoch [  7]: Loss 0.11821
Epoch [  7]: Loss 0.10115
Epoch [  7]: Loss 0.09920
Epoch [  7]: Loss 0.09080
Epoch [  7]: Loss 0.09665
Epoch [  7]: Loss 0.10272
Validation: Loss 0.07781 Accuracy 1.00000
Validation: Loss 0.08046 Accuracy 1.00000
Epoch [  8]: Loss 0.08218
Epoch [  8]: Loss 0.07534
Epoch [  8]: Loss 0.07830
Epoch [  8]: Loss 0.07305
Epoch [  8]: Loss 0.06916
Epoch [  8]: Loss 0.06497
Epoch [  8]: Loss 0.06468
Validation: Loss 0.05437 Accuracy 1.00000
Validation: Loss 0.05613 Accuracy 1.00000
Epoch [  9]: Loss 0.06475
Epoch [  9]: Loss 0.06051
Epoch [  9]: Loss 0.05020
Epoch [  9]: Loss 0.04801
Epoch [  9]: Loss 0.04612
Epoch [  9]: Loss 0.04395
Epoch [  9]: Loss 0.05028
Validation: Loss 0.04056 Accuracy 1.00000
Validation: Loss 0.04178 Accuracy 1.00000
Epoch [ 10]: Loss 0.04429
Epoch [ 10]: Loss 0.04477
Epoch [ 10]: Loss 0.04020
Epoch [ 10]: Loss 0.04198
Epoch [ 10]: Loss 0.03483
Epoch [ 10]: Loss 0.03405
Epoch [ 10]: Loss 0.04023
Validation: Loss 0.03280 Accuracy 1.00000
Validation: Loss 0.03378 Accuracy 1.00000
Epoch [ 11]: Loss 0.03449
Epoch [ 11]: Loss 0.03468
Epoch [ 11]: Loss 0.03225
Epoch [ 11]: Loss 0.03172
Epoch [ 11]: Loss 0.03406
Epoch [ 11]: Loss 0.03108
Epoch [ 11]: Loss 0.03117
Validation: Loss 0.02780 Accuracy 1.00000
Validation: Loss 0.02865 Accuracy 1.00000
Epoch [ 12]: Loss 0.02827
Epoch [ 12]: Loss 0.03079
Epoch [ 12]: Loss 0.03040
Epoch [ 12]: Loss 0.02774
Epoch [ 12]: Loss 0.02696
Epoch [ 12]: Loss 0.02599
Epoch [ 12]: Loss 0.02700
Validation: Loss 0.02417 Accuracy 1.00000
Validation: Loss 0.02492 Accuracy 1.00000
Epoch [ 13]: Loss 0.02566
Epoch [ 13]: Loss 0.02585
Epoch [ 13]: Loss 0.02599
Epoch [ 13]: Loss 0.02453
Epoch [ 13]: Loss 0.02391
Epoch [ 13]: Loss 0.02329
Epoch [ 13]: Loss 0.02308
Validation: Loss 0.02136 Accuracy 1.00000
Validation: Loss 0.02203 Accuracy 1.00000
Epoch [ 14]: Loss 0.02349
Epoch [ 14]: Loss 0.02217
Epoch [ 14]: Loss 0.02085
Epoch [ 14]: Loss 0.02163
Epoch [ 14]: Loss 0.02223
Epoch [ 14]: Loss 0.02188
Epoch [ 14]: Loss 0.02212
Validation: Loss 0.01911 Accuracy 1.00000
Validation: Loss 0.01973 Accuracy 1.00000
Epoch [ 15]: Loss 0.02175
Epoch [ 15]: Loss 0.01884
Epoch [ 15]: Loss 0.02095
Epoch [ 15]: Loss 0.02015
Epoch [ 15]: Loss 0.02003
Epoch [ 15]: Loss 0.01742
Epoch [ 15]: Loss 0.01929
Validation: Loss 0.01724 Accuracy 1.00000
Validation: Loss 0.01781 Accuracy 1.00000
Epoch [ 16]: Loss 0.01806
Epoch [ 16]: Loss 0.01797
Epoch [ 16]: Loss 0.01884
Epoch [ 16]: Loss 0.01874
Epoch [ 16]: Loss 0.01646
Epoch [ 16]: Loss 0.01731
Epoch [ 16]: Loss 0.01938
Validation: Loss 0.01565 Accuracy 1.00000
Validation: Loss 0.01618 Accuracy 1.00000
Epoch [ 17]: Loss 0.01657
Epoch [ 17]: Loss 0.01726
Epoch [ 17]: Loss 0.01539
Epoch [ 17]: Loss 0.01578
Epoch [ 17]: Loss 0.01659
Epoch [ 17]: Loss 0.01681
Epoch [ 17]: Loss 0.01549
Validation: Loss 0.01429 Accuracy 1.00000
Validation: Loss 0.01478 Accuracy 1.00000
Epoch [ 18]: Loss 0.01527
Epoch [ 18]: Loss 0.01676
Epoch [ 18]: Loss 0.01560
Epoch [ 18]: Loss 0.01445
Epoch [ 18]: Loss 0.01412
Epoch [ 18]: Loss 0.01450
Epoch [ 18]: Loss 0.01208
Validation: Loss 0.01311 Accuracy 1.00000
Validation: Loss 0.01357 Accuracy 1.00000
Epoch [ 19]: Loss 0.01416
Epoch [ 19]: Loss 0.01295
Epoch [ 19]: Loss 0.01384
Epoch [ 19]: Loss 0.01432
Epoch [ 19]: Loss 0.01468
Epoch [ 19]: Loss 0.01312
Epoch [ 19]: Loss 0.01259
Validation: Loss 0.01209 Accuracy 1.00000
Validation: Loss 0.01252 Accuracy 1.00000
Epoch [ 20]: Loss 0.01221
Epoch [ 20]: Loss 0.01318
Epoch [ 20]: Loss 0.01401
Epoch [ 20]: Loss 0.01288
Epoch [ 20]: Loss 0.01131
Epoch [ 20]: Loss 0.01260
Epoch [ 20]: Loss 0.01397
Validation: Loss 0.01118 Accuracy 1.00000
Validation: Loss 0.01158 Accuracy 1.00000
Epoch [ 21]: Loss 0.01163
Epoch [ 21]: Loss 0.01138
Epoch [ 21]: Loss 0.01206
Epoch [ 21]: Loss 0.01201
Epoch [ 21]: Loss 0.01100
Epoch [ 21]: Loss 0.01185
Epoch [ 21]: Loss 0.01461
Validation: Loss 0.01030 Accuracy 1.00000
Validation: Loss 0.01067 Accuracy 1.00000
Epoch [ 22]: Loss 0.01093
Epoch [ 22]: Loss 0.01047
Epoch [ 22]: Loss 0.01073
Epoch [ 22]: Loss 0.01102
Epoch [ 22]: Loss 0.01014
Epoch [ 22]: Loss 0.01100
Epoch [ 22]: Loss 0.01245
Validation: Loss 0.00935 Accuracy 1.00000
Validation: Loss 0.00968 Accuracy 1.00000
Epoch [ 23]: Loss 0.01093
Epoch [ 23]: Loss 0.01038
Epoch [ 23]: Loss 0.00876
Epoch [ 23]: Loss 0.00943
Epoch [ 23]: Loss 0.00960
Epoch [ 23]: Loss 0.00917
Epoch [ 23]: Loss 0.00938
Validation: Loss 0.00831 Accuracy 1.00000
Validation: Loss 0.00860 Accuracy 1.00000
Epoch [ 24]: Loss 0.00848
Epoch [ 24]: Loss 0.00872
Epoch [ 24]: Loss 0.00931
Epoch [ 24]: Loss 0.00804
Epoch [ 24]: Loss 0.00858
Epoch [ 24]: Loss 0.00830
Epoch [ 24]: Loss 0.00884
Validation: Loss 0.00744 Accuracy 1.00000
Validation: Loss 0.00768 Accuracy 1.00000
Epoch [ 25]: Loss 0.00798
Epoch [ 25]: Loss 0.00808
Epoch [ 25]: Loss 0.00749
Epoch [ 25]: Loss 0.00739
Epoch [ 25]: Loss 0.00827
Epoch [ 25]: Loss 0.00707
Epoch [ 25]: Loss 0.00804
Validation: Loss 0.00680 Accuracy 1.00000
Validation: Loss 0.00701 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend saving only the parameters and states, not the model struct itself.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
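A hedged sketch of how the loaded parameters and states could be reused for inference, assuming the SpiralClassifier definition from earlier is available in the session (Lux.testmode switches the states to inference mode, as in the validation loop):

```julia
using JLD2, Lux, Random

@load "trained_model.jld2" ps_trained st_trained

model = SpiralClassifier(2, 8, 1)        # same architecture used during training
st_inf = Lux.testmode(st_trained)        # disable training-only behavior

# Classify a single hypothetical sequence (2 features × 50 steps × 1 batch):
y, _ = model(rand(Float32, 2, 50, 1), ps_trained, st_inf)
```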

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.141 GiB / 4.750 GiB available)

This page was generated using Literate.jl.