Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
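
As a quick sanity check, we can peek at one batch to confirm the shapes described above (a small sketch, not part of the training pipeline; the exact sizes assume the default keyword arguments):

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128) – features × sequence_length × batchsize
size(y)  # (128,) – one binary label per sequence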

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead, we will compose it from the built-in Lux.LSTMCell and Lux.Dense layers.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
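
Because we declared the fieldnames in AbstractExplicitContainerLayer, Lux.setup automatically populates the parameters and states as NamedTuples keyed by those fields. A minimal sketch to verify this:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)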

We could instead use the built-in Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – rather than defining the forward pass below, but let's still write it out for the sake of illustration.
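
For reference, the off-the-shelf equivalent might look like the following sketch. Note that Recurrence by default returns only the last output of the sequence, and that this Chain outputs a 1 × batch_size matrix rather than the vector our custom model returns:

julia
model_alt = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))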

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
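
Continuing the setup sketch from above, a single forward call on a dummy batch would look like this (shapes follow the 2 × seq_len × batch_size convention):

julia
x = rand(Float32, 2, 50, 4)
y, st_new = model(x, ps, st)
size(y)  # (4,) – one probability per sequence in the batch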

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
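
For comparison, the numerically stabler logit-based variant mentioned above would drop the sigmoid from the final Dense layer and compute the loss directly on the logits. A minimal sketch, assuming NNlib (a dependency of Lux) for logsigmoid:

julia
using NNlib: logsigmoid

# Equivalent to binarycrossentropy(sigmoid.(y_logits), y_true), but remains
# stable for large-magnitude logits
function logitbinarycrossentropy(y_logits, y_true)
    return mean(@. (1 - y_true) * y_logits - logsigmoid(y_logits))
end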

Training the Model

julia
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
Epoch [  1]: Loss 0.56235
Epoch [  1]: Loss 0.51151
Epoch [  1]: Loss 0.47914
Epoch [  1]: Loss 0.45635
Epoch [  1]: Loss 0.42339
Epoch [  1]: Loss 0.41773
Epoch [  1]: Loss 0.41768
Validation: Loss 0.37153 Accuracy 1.00000
Validation: Loss 0.36875 Accuracy 1.00000
Epoch [  2]: Loss 0.37936
Epoch [  2]: Loss 0.34991
Epoch [  2]: Loss 0.34274
Epoch [  2]: Loss 0.31910
Epoch [  2]: Loss 0.30310
Epoch [  2]: Loss 0.29096
Epoch [  2]: Loss 0.28280
Validation: Loss 0.26254 Accuracy 1.00000
Validation: Loss 0.26106 Accuracy 1.00000
Epoch [  3]: Loss 0.26050
Epoch [  3]: Loss 0.25222
Epoch [  3]: Loss 0.23802
Epoch [  3]: Loss 0.22420
Epoch [  3]: Loss 0.21580
Epoch [  3]: Loss 0.20403
Epoch [  3]: Loss 0.19697
Validation: Loss 0.18550 Accuracy 1.00000
Validation: Loss 0.18480 Accuracy 1.00000
Epoch [  4]: Loss 0.18422
Epoch [  4]: Loss 0.17635
Epoch [  4]: Loss 0.17018
Epoch [  4]: Loss 0.16104
Epoch [  4]: Loss 0.15308
Epoch [  4]: Loss 0.14689
Epoch [  4]: Loss 0.14509
Validation: Loss 0.13456 Accuracy 1.00000
Validation: Loss 0.13406 Accuracy 1.00000
Epoch [  5]: Loss 0.13665
Epoch [  5]: Loss 0.12755
Epoch [  5]: Loss 0.12420
Epoch [  5]: Loss 0.11720
Epoch [  5]: Loss 0.11334
Epoch [  5]: Loss 0.10736
Epoch [  5]: Loss 0.10463
Validation: Loss 0.09955 Accuracy 1.00000
Validation: Loss 0.09906 Accuracy 1.00000
Epoch [  6]: Loss 0.09854
Epoch [  6]: Loss 0.09558
Epoch [  6]: Loss 0.09159
Epoch [  6]: Loss 0.08865
Epoch [  6]: Loss 0.08344
Epoch [  6]: Loss 0.08169
Epoch [  6]: Loss 0.07743
Validation: Loss 0.07425 Accuracy 1.00000
Validation: Loss 0.07366 Accuracy 1.00000
Epoch [  7]: Loss 0.07425
Epoch [  7]: Loss 0.07175
Epoch [  7]: Loss 0.06786
Epoch [  7]: Loss 0.06666
Epoch [  7]: Loss 0.06266
Epoch [  7]: Loss 0.05979
Epoch [  7]: Loss 0.06205
Validation: Loss 0.05576 Accuracy 1.00000
Validation: Loss 0.05512 Accuracy 1.00000
Epoch [  8]: Loss 0.05682
Epoch [  8]: Loss 0.05305
Epoch [  8]: Loss 0.05003
Epoch [  8]: Loss 0.05069
Epoch [  8]: Loss 0.04792
Epoch [  8]: Loss 0.04655
Epoch [  8]: Loss 0.04070
Validation: Loss 0.04199 Accuracy 1.00000
Validation: Loss 0.04138 Accuracy 1.00000
Epoch [  9]: Loss 0.04330
Epoch [  9]: Loss 0.04062
Epoch [  9]: Loss 0.03796
Epoch [  9]: Loss 0.03726
Epoch [  9]: Loss 0.03676
Epoch [  9]: Loss 0.03459
Epoch [  9]: Loss 0.03161
Validation: Loss 0.03193 Accuracy 1.00000
Validation: Loss 0.03137 Accuracy 1.00000
Epoch [ 10]: Loss 0.03298
Epoch [ 10]: Loss 0.03151
Epoch [ 10]: Loss 0.02842
Epoch [ 10]: Loss 0.02961
Epoch [ 10]: Loss 0.02920
Epoch [ 10]: Loss 0.02529
Epoch [ 10]: Loss 0.02528
Validation: Loss 0.02498 Accuracy 1.00000
Validation: Loss 0.02450 Accuracy 1.00000
Epoch [ 11]: Loss 0.02410
Epoch [ 11]: Loss 0.02415
Epoch [ 11]: Loss 0.02397
Epoch [ 11]: Loss 0.02241
Epoch [ 11]: Loss 0.02284
Epoch [ 11]: Loss 0.02217
Epoch [ 11]: Loss 0.02287
Validation: Loss 0.02024 Accuracy 1.00000
Validation: Loss 0.01983 Accuracy 1.00000
Epoch [ 12]: Loss 0.02098
Epoch [ 12]: Loss 0.01865
Epoch [ 12]: Loss 0.01899
Epoch [ 12]: Loss 0.01814
Epoch [ 12]: Loss 0.01906
Epoch [ 12]: Loss 0.01883
Epoch [ 12]: Loss 0.01847
Validation: Loss 0.01691 Accuracy 1.00000
Validation: Loss 0.01657 Accuracy 1.00000
Epoch [ 13]: Loss 0.01665
Epoch [ 13]: Loss 0.01855
Epoch [ 13]: Loss 0.01664
Epoch [ 13]: Loss 0.01553
Epoch [ 13]: Loss 0.01510
Epoch [ 13]: Loss 0.01501
Epoch [ 13]: Loss 0.01356
Validation: Loss 0.01451 Accuracy 1.00000
Validation: Loss 0.01421 Accuracy 1.00000
Epoch [ 14]: Loss 0.01486
Epoch [ 14]: Loss 0.01377
Epoch [ 14]: Loss 0.01505
Epoch [ 14]: Loss 0.01415
Epoch [ 14]: Loss 0.01321
Epoch [ 14]: Loss 0.01317
Epoch [ 14]: Loss 0.01225
Validation: Loss 0.01273 Accuracy 1.00000
Validation: Loss 0.01247 Accuracy 1.00000
Epoch [ 15]: Loss 0.01264
Epoch [ 15]: Loss 0.01310
Epoch [ 15]: Loss 0.01270
Epoch [ 15]: Loss 0.01257
Epoch [ 15]: Loss 0.01237
Epoch [ 15]: Loss 0.01100
Epoch [ 15]: Loss 0.01082
Validation: Loss 0.01135 Accuracy 1.00000
Validation: Loss 0.01112 Accuracy 1.00000
Epoch [ 16]: Loss 0.01180
Epoch [ 16]: Loss 0.01108
Epoch [ 16]: Loss 0.01129
Epoch [ 16]: Loss 0.01075
Epoch [ 16]: Loss 0.01080
Epoch [ 16]: Loss 0.01062
Epoch [ 16]: Loss 0.01077
Validation: Loss 0.01024 Accuracy 1.00000
Validation: Loss 0.01003 Accuracy 1.00000
Epoch [ 17]: Loss 0.01040
Epoch [ 17]: Loss 0.01017
Epoch [ 17]: Loss 0.01020
Epoch [ 17]: Loss 0.00910
Epoch [ 17]: Loss 0.00994
Epoch [ 17]: Loss 0.00990
Epoch [ 17]: Loss 0.01107
Validation: Loss 0.00932 Accuracy 1.00000
Validation: Loss 0.00912 Accuracy 1.00000
Epoch [ 18]: Loss 0.00994
Epoch [ 18]: Loss 0.00918
Epoch [ 18]: Loss 0.00968
Epoch [ 18]: Loss 0.00866
Epoch [ 18]: Loss 0.00829
Epoch [ 18]: Loss 0.00889
Epoch [ 18]: Loss 0.00955
Validation: Loss 0.00853 Accuracy 1.00000
Validation: Loss 0.00836 Accuracy 1.00000
Epoch [ 19]: Loss 0.00852
Epoch [ 19]: Loss 0.00859
Epoch [ 19]: Loss 0.00828
Epoch [ 19]: Loss 0.00842
Epoch [ 19]: Loss 0.00822
Epoch [ 19]: Loss 0.00831
Epoch [ 19]: Loss 0.00805
Validation: Loss 0.00787 Accuracy 1.00000
Validation: Loss 0.00770 Accuracy 1.00000
Epoch [ 20]: Loss 0.00774
Epoch [ 20]: Loss 0.00779
Epoch [ 20]: Loss 0.00795
Epoch [ 20]: Loss 0.00773
Epoch [ 20]: Loss 0.00719
Epoch [ 20]: Loss 0.00795
Epoch [ 20]: Loss 0.00788
Validation: Loss 0.00728 Accuracy 1.00000
Validation: Loss 0.00713 Accuracy 1.00000
Epoch [ 21]: Loss 0.00777
Epoch [ 21]: Loss 0.00665
Epoch [ 21]: Loss 0.00737
Epoch [ 21]: Loss 0.00700
Epoch [ 21]: Loss 0.00680
Epoch [ 21]: Loss 0.00756
Epoch [ 21]: Loss 0.00670
Validation: Loss 0.00677 Accuracy 1.00000
Validation: Loss 0.00663 Accuracy 1.00000
Epoch [ 22]: Loss 0.00723
Epoch [ 22]: Loss 0.00626
Epoch [ 22]: Loss 0.00680
Epoch [ 22]: Loss 0.00719
Epoch [ 22]: Loss 0.00595
Epoch [ 22]: Loss 0.00717
Epoch [ 22]: Loss 0.00471
Validation: Loss 0.00632 Accuracy 1.00000
Validation: Loss 0.00619 Accuracy 1.00000
Epoch [ 23]: Loss 0.00638
Epoch [ 23]: Loss 0.00623
Epoch [ 23]: Loss 0.00629
Epoch [ 23]: Loss 0.00600
Epoch [ 23]: Loss 0.00635
Epoch [ 23]: Loss 0.00619
Epoch [ 23]: Loss 0.00643
Validation: Loss 0.00593 Accuracy 1.00000
Validation: Loss 0.00580 Accuracy 1.00000
Epoch [ 24]: Loss 0.00583
Epoch [ 24]: Loss 0.00629
Epoch [ 24]: Loss 0.00521
Epoch [ 24]: Loss 0.00583
Epoch [ 24]: Loss 0.00642
Epoch [ 24]: Loss 0.00582
Epoch [ 24]: Loss 0.00508
Validation: Loss 0.00557 Accuracy 1.00000
Validation: Loss 0.00545 Accuracy 1.00000
Epoch [ 25]: Loss 0.00581
Epoch [ 25]: Loss 0.00549
Epoch [ 25]: Loss 0.00570
Epoch [ 25]: Loss 0.00548
Epoch [ 25]: Loss 0.00502
Epoch [ 25]: Loss 0.00563
Epoch [ 25]: Loss 0.00552
Validation: Loss 0.00525 Accuracy 1.00000
Validation: Loss 0.00514 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
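
Since only the parameters and states were saved, using them in a fresh session means reconstructing the model first and then plugging the loaded values back in. A minimal inference sketch:

julia
model = SpiralClassifier(2, 8, 1)
x = rand(Float32, 2, 50, 4)
y_pred, _ = model(x, ps_trained, Lux.testmode(st_trained))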

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.328 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.