Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals, and from this data we will create an MLUtils.DataLoader. Our dataloader will yield batches of size 2 × seq_len × batch_size, and the task is to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
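
As a quick sanity check, we can peek at the first batch; the shapes in the comments follow from the construction above.

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): 2 features × 50 timesteps × a batch of 128 sequences
size(y)  # (128,): one binary label per sequence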

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; rather, we will compose it from Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
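
Because we passed the field names to the container type, Lux.setup will hand back parameters and states keyed by those fields. A minimal sketch:

julia
rng = Xoshiro(0)
ps, st = Lux.setup(rng, SpiralClassifier(2, 8, 1))
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)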

We could have used a built-in Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the loop ourselves, but we will implement the recurrence manually for illustration; a sketch of the built-in alternative follows.
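
For reference, the built-in alternative might be composed as below. This is only a sketch: by default, Recurrence runs the cell over the sequence and returns the last time step's output.

julia
# Roughly equivalent in spirit to SpiralClassifier(2, 8, 1); our manual
# version additionally `vec`s the final output.
model_alt = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))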

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states were automatically populated into fields called
    # `lstm_cell` and `classifier`. We use `eachslice` to get the elements in the sequence
    # without copying, and `Iterators.peel` to split out the first element for LSTM
    # initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
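
As a quick sketch, we can run a dummy batch through the forward pass we just defined (shapes chosen to match the dataset):

julia
rng = Xoshiro(0)
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 50, 4)  # 2 features × 50 timesteps × a batch of 4
y, st_new = model(x, ps, st)
size(y)  # (4,): one probability per sequence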

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy, since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy here (see the logit-space sketch after this code block).

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
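
For reference, the more stable logit-space variant mentioned above could look like the following sketch. It assumes the classifier's final Dense layer drops the sigmoid and emits raw logits; logsigmoid comes from NNlib.

julia
using NNlib: logsigmoid

# Stable binary cross-entropy on raw logits: mean((1 - y) * ẑ - logσ(ẑ))
function logitbinarycrossentropy(logits, y_true)
    return mean(@. (1 - y_true) * logits - logsigmoid(logits))
end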

Training the Model

julia
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
Epoch [  1]: Loss 0.56120
Epoch [  1]: Loss 0.50947
Epoch [  1]: Loss 0.47877
Epoch [  1]: Loss 0.44086
Epoch [  1]: Loss 0.42201
Epoch [  1]: Loss 0.41314
Epoch [  1]: Loss 0.39093
Validation: Loss 0.37067 Accuracy 1.00000
Validation: Loss 0.36635 Accuracy 1.00000
Epoch [  2]: Loss 0.36975
Epoch [  2]: Loss 0.35106
Epoch [  2]: Loss 0.32930
Epoch [  2]: Loss 0.31391
Epoch [  2]: Loss 0.30678
Epoch [  2]: Loss 0.28344
Epoch [  2]: Loss 0.27148
Validation: Loss 0.25886 Accuracy 1.00000
Validation: Loss 0.25655 Accuracy 1.00000
Epoch [  3]: Loss 0.25772
Epoch [  3]: Loss 0.24195
Epoch [  3]: Loss 0.23338
Epoch [  3]: Loss 0.22019
Epoch [  3]: Loss 0.20885
Epoch [  3]: Loss 0.20036
Epoch [  3]: Loss 0.18521
Validation: Loss 0.17991 Accuracy 1.00000
Validation: Loss 0.17886 Accuracy 1.00000
Epoch [  4]: Loss 0.17832
Epoch [  4]: Loss 0.17164
Epoch [  4]: Loss 0.16195
Epoch [  4]: Loss 0.15454
Epoch [  4]: Loss 0.14664
Epoch [  4]: Loss 0.14142
Epoch [  4]: Loss 0.13433
Validation: Loss 0.12798 Accuracy 1.00000
Validation: Loss 0.12725 Accuracy 1.00000
Epoch [  5]: Loss 0.12715
Epoch [  5]: Loss 0.12235
Epoch [  5]: Loss 0.11633
Epoch [  5]: Loss 0.11234
Epoch [  5]: Loss 0.10551
Epoch [  5]: Loss 0.10069
Epoch [  5]: Loss 0.09609
Validation: Loss 0.09292 Accuracy 1.00000
Validation: Loss 0.09217 Accuracy 1.00000
Epoch [  6]: Loss 0.09318
Epoch [  6]: Loss 0.08792
Epoch [  6]: Loss 0.08422
Epoch [  6]: Loss 0.08170
Epoch [  6]: Loss 0.07844
Epoch [  6]: Loss 0.07387
Epoch [  6]: Loss 0.06865
Validation: Loss 0.06860 Accuracy 1.00000
Validation: Loss 0.06775 Accuracy 1.00000
Epoch [  7]: Loss 0.06886
Epoch [  7]: Loss 0.06378
Epoch [  7]: Loss 0.06370
Epoch [  7]: Loss 0.06108
Epoch [  7]: Loss 0.05637
Epoch [  7]: Loss 0.05613
Epoch [  7]: Loss 0.04849
Validation: Loss 0.05128 Accuracy 1.00000
Validation: Loss 0.05040 Accuracy 1.00000
Epoch [  8]: Loss 0.04881
Epoch [  8]: Loss 0.05008
Epoch [  8]: Loss 0.04780
Epoch [  8]: Loss 0.04452
Epoch [  8]: Loss 0.04312
Epoch [  8]: Loss 0.04125
Epoch [  8]: Loss 0.04053
Validation: Loss 0.03872 Accuracy 1.00000
Validation: Loss 0.03786 Accuracy 1.00000
Epoch [  9]: Loss 0.03744
Epoch [  9]: Loss 0.03684
Epoch [  9]: Loss 0.03661
Epoch [  9]: Loss 0.03236
Epoch [  9]: Loss 0.03304
Epoch [  9]: Loss 0.03134
Epoch [  9]: Loss 0.03322
Validation: Loss 0.02968 Accuracy 1.00000
Validation: Loss 0.02889 Accuracy 1.00000
Epoch [ 10]: Loss 0.02938
Epoch [ 10]: Loss 0.02904
Epoch [ 10]: Loss 0.02771
Epoch [ 10]: Loss 0.02689
Epoch [ 10]: Loss 0.02510
Epoch [ 10]: Loss 0.02175
Epoch [ 10]: Loss 0.02717
Validation: Loss 0.02336 Accuracy 1.00000
Validation: Loss 0.02268 Accuracy 1.00000
Epoch [ 11]: Loss 0.02256
Epoch [ 11]: Loss 0.02271
Epoch [ 11]: Loss 0.02213
Epoch [ 11]: Loss 0.01976
Epoch [ 11]: Loss 0.01852
Epoch [ 11]: Loss 0.02178
Epoch [ 11]: Loss 0.01973
Validation: Loss 0.01900 Accuracy 1.00000
Validation: Loss 0.01842 Accuracy 1.00000
Epoch [ 12]: Loss 0.01915
Epoch [ 12]: Loss 0.01831
Epoch [ 12]: Loss 0.01794
Epoch [ 12]: Loss 0.01794
Epoch [ 12]: Loss 0.01654
Epoch [ 12]: Loss 0.01550
Epoch [ 12]: Loss 0.01477
Validation: Loss 0.01595 Accuracy 1.00000
Validation: Loss 0.01545 Accuracy 1.00000
Epoch [ 13]: Loss 0.01622
Epoch [ 13]: Loss 0.01614
Epoch [ 13]: Loss 0.01402
Epoch [ 13]: Loss 0.01419
Epoch [ 13]: Loss 0.01409
Epoch [ 13]: Loss 0.01454
Epoch [ 13]: Loss 0.01304
Validation: Loss 0.01377 Accuracy 1.00000
Validation: Loss 0.01334 Accuracy 1.00000
Epoch [ 14]: Loss 0.01470
Epoch [ 14]: Loss 0.01361
Epoch [ 14]: Loss 0.01277
Epoch [ 14]: Loss 0.01230
Epoch [ 14]: Loss 0.01225
Epoch [ 14]: Loss 0.01225
Epoch [ 14]: Loss 0.01049
Validation: Loss 0.01214 Accuracy 1.00000
Validation: Loss 0.01176 Accuracy 1.00000
Epoch [ 15]: Loss 0.01269
Epoch [ 15]: Loss 0.01240
Epoch [ 15]: Loss 0.01107
Epoch [ 15]: Loss 0.01072
Epoch [ 15]: Loss 0.01164
Epoch [ 15]: Loss 0.00994
Epoch [ 15]: Loss 0.01176
Validation: Loss 0.01087 Accuracy 1.00000
Validation: Loss 0.01052 Accuracy 1.00000
Epoch [ 16]: Loss 0.01045
Epoch [ 16]: Loss 0.01022
Epoch [ 16]: Loss 0.01029
Epoch [ 16]: Loss 0.01018
Epoch [ 16]: Loss 0.01054
Epoch [ 16]: Loss 0.01012
Epoch [ 16]: Loss 0.00925
Validation: Loss 0.00984 Accuracy 1.00000
Validation: Loss 0.00952 Accuracy 1.00000
Epoch [ 17]: Loss 0.00989
Epoch [ 17]: Loss 0.00953
Epoch [ 17]: Loss 0.00974
Epoch [ 17]: Loss 0.00907
Epoch [ 17]: Loss 0.00928
Epoch [ 17]: Loss 0.00843
Epoch [ 17]: Loss 0.00921
Validation: Loss 0.00898 Accuracy 1.00000
Validation: Loss 0.00869 Accuracy 1.00000
Epoch [ 18]: Loss 0.00921
Epoch [ 18]: Loss 0.00879
Epoch [ 18]: Loss 0.00824
Epoch [ 18]: Loss 0.00869
Epoch [ 18]: Loss 0.00819
Epoch [ 18]: Loss 0.00812
Epoch [ 18]: Loss 0.00817
Validation: Loss 0.00825 Accuracy 1.00000
Validation: Loss 0.00798 Accuracy 1.00000
Epoch [ 19]: Loss 0.00810
Epoch [ 19]: Loss 0.00791
Epoch [ 19]: Loss 0.00808
Epoch [ 19]: Loss 0.00760
Epoch [ 19]: Loss 0.00796
Epoch [ 19]: Loss 0.00771
Epoch [ 19]: Loss 0.00672
Validation: Loss 0.00763 Accuracy 1.00000
Validation: Loss 0.00737 Accuracy 1.00000
Epoch [ 20]: Loss 0.00717
Epoch [ 20]: Loss 0.00726
Epoch [ 20]: Loss 0.00733
Epoch [ 20]: Loss 0.00751
Epoch [ 20]: Loss 0.00737
Epoch [ 20]: Loss 0.00699
Epoch [ 20]: Loss 0.00704
Validation: Loss 0.00708 Accuracy 1.00000
Validation: Loss 0.00685 Accuracy 1.00000
Epoch [ 21]: Loss 0.00629
Epoch [ 21]: Loss 0.00682
Epoch [ 21]: Loss 0.00705
Epoch [ 21]: Loss 0.00657
Epoch [ 21]: Loss 0.00687
Epoch [ 21]: Loss 0.00696
Epoch [ 21]: Loss 0.00658
Validation: Loss 0.00661 Accuracy 1.00000
Validation: Loss 0.00638 Accuracy 1.00000
Epoch [ 22]: Loss 0.00677
Epoch [ 22]: Loss 0.00633
Epoch [ 22]: Loss 0.00626
Epoch [ 22]: Loss 0.00646
Epoch [ 22]: Loss 0.00611
Epoch [ 22]: Loss 0.00612
Epoch [ 22]: Loss 0.00553
Validation: Loss 0.00618 Accuracy 1.00000
Validation: Loss 0.00597 Accuracy 1.00000
Epoch [ 23]: Loss 0.00580
Epoch [ 23]: Loss 0.00589
Epoch [ 23]: Loss 0.00603
Epoch [ 23]: Loss 0.00596
Epoch [ 23]: Loss 0.00610
Epoch [ 23]: Loss 0.00574
Epoch [ 23]: Loss 0.00558
Validation: Loss 0.00580 Accuracy 1.00000
Validation: Loss 0.00560 Accuracy 1.00000
Epoch [ 24]: Loss 0.00578
Epoch [ 24]: Loss 0.00563
Epoch [ 24]: Loss 0.00585
Epoch [ 24]: Loss 0.00521
Epoch [ 24]: Loss 0.00539
Epoch [ 24]: Loss 0.00540
Epoch [ 24]: Loss 0.00571
Validation: Loss 0.00546 Accuracy 1.00000
Validation: Loss 0.00527 Accuracy 1.00000
Epoch [ 25]: Loss 0.00520
Epoch [ 25]: Loss 0.00566
Epoch [ 25]: Loss 0.00545
Epoch [ 25]: Loss 0.00504
Epoch [ 25]: Loss 0.00494
Epoch [ 25]: Loss 0.00516
Epoch [ 25]: Loss 0.00491
Validation: Loss 0.00515 Accuracy 1.00000
Validation: Loss 0.00497 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct itself and instead save only the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model back.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
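
To run inference with the loaded parameters, rebuild the model with the same hyperparameters and call it in test mode. A sketch on a stand-in batch:

julia
model = SpiralClassifier(2, 8, 1)
x = randn(Xoshiro(0), Float32, 2, 50, 4)  # stand-in batch
y_pred, _ = model(x, ps_trained, Lux.testmode(st_trained))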

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 6 default, 0 interactive, 3 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 6
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.266 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.