
Training a Simple LSTM

In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise (label 0) or anticlockwise (label 1).

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
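To make the 2 × seq_len × batch_size layout concrete, here is a minimal, dependency-free sketch. toy_spiral is a hypothetical stand-in for MLUtils.Datasets.make_spiral (its geometry and noise level are assumptions, not MLUtils's behaviour), while the reshape/cat stacking mirrors the step in get_dataloaders. Slicing the result along dimension 2 yields one 2 × batch_size matrix per time step, which is the shape the LSTM cell consumes in the forward pass.

```julia
using Random

# Hypothetical spiral generator: a 2 × n matrix of noisy (x, y) points whose radius
# grows with the angle; `clockwise` flips the sign of the y coordinate.
function toy_spiral(rng, n; clockwise=true, noise=0.05f0)
    θ = range(0.0f0, 3.0f0 * π; length=n)
    r = θ ./ (3.0f0 * π)
    s = clockwise ? -1.0f0 : 1.0f0
    pts = vcat((r .* cos.(θ))', (s .* r .* sin.(θ))')
    return pts .+ noise .* randn(rng, Float32, 2, n)
end

rng = Xoshiro(0)
seq_len, n_per_class = 50, 4

# First half clockwise (label 0), second half anticlockwise (label 1), as in the tutorial
spirals = [toy_spiral(rng, seq_len; clockwise=(i <= n_per_class)) for i in 1:(2n_per_class)]
x = cat([reshape(s, 2, seq_len, 1) for s in spirals]...; dims=3)  # 2 × seq_len × batch
y = vcat(zeros(Float32, n_per_class), ones(Float32, n_per_class))

# Each time step is a 2 × batch matrix (a view, no copy)
steps = collect(eachslice(x; dims=2))
```

Here size(x) is (2, 50, 8) and each element of steps is a 2 × 8 view into x.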

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated, so we don't have to define Lux.initialparameters and Lux.initialstates ourselves.

To understand more about container layers, please look at Container Layer.
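Conceptually, a container layer builds its parameters as a NamedTuple with one entry per listed field name, delegating to each child layer. The sketch below imitates that idea in plain Julia; ToyDense, ToyContainer, and toy_params are hypothetical names, and this is a simplified illustration of the idea, not Lux's actual implementation of Lux.initialparameters.

```julia
using Random

# Hypothetical leaf layer that knows how to create its own parameters
struct ToyDense
    in_dims::Int
    out_dims::Int
end
toy_params(rng, l::ToyDense) =
    (weight=randn(rng, Float32, l.out_dims, l.in_dims), bias=zeros(Float32, l.out_dims))

# Hypothetical container mirroring SpiralClassifier's two fields
struct ToyContainer{A, B}
    lstm_cell::A
    classifier::B
end

# Build one NamedTuple entry per listed field name by recursing into each child
function toy_params(rng, c::ToyContainer; names=(:lstm_cell, :classifier))
    return NamedTuple{names}(map(n -> toy_params(rng, getfield(c, n)), names))
end

ps = toy_params(Xoshiro(0), ToyContainer(ToyDense(2, 8), ToyDense(8, 1)))
```

The resulting ps has the keys (:lstm_cell, :classifier), which is why the forward pass can refer to ps.lstm_cell and ps.classifier.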

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch; instead we use the built-in Lux.LSTMCell and Lux.Dense layers.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier

We could use the default Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the forward pass below, but we will write it out by hand for the sake of illustration.

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
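For readers who want to see what each call to the LSTM cell computes, here is a dependency-free sketch of the standard LSTM update folded over the sequence dimension, returning the hidden state after the last step. ToyLSTM, lstm_step, and run_sequence are hypothetical names; the gate ordering (input, forget, candidate, output) is an assumption, and Lux.LSTMCell's internal parameter layout may differ. Starting from an explicit zero carry plays the role of the tutorial's first, carry-free call to the cell, assuming zero-initialized hidden and memory states.

```julia
using Random

logistic(x) = 1 / (1 + exp(-x))

# Hypothetical LSTM parameters; the four gate blocks are stacked row-wise in the
# order input, forget, candidate, output (a common textbook layout).
struct ToyLSTM
    Wx::Matrix{Float32}  # (4 * hidden) × in_dims
    Wh::Matrix{Float32}  # (4 * hidden) × hidden
    b::Vector{Float32}   # 4 * hidden
end
ToyLSTM(rng, in_dims, hidden) = ToyLSTM(0.1f0 .* randn(rng, Float32, 4hidden, in_dims),
    0.1f0 .* randn(rng, Float32, 4hidden, hidden), zeros(Float32, 4hidden))

# One time step: x is in_dims × batch; h and c are hidden × batch
function lstm_step(cell::ToyLSTM, x, (h, c))
    n = size(h, 1)
    z = cell.Wx * x .+ cell.Wh * h .+ cell.b
    i = logistic.(z[1:n, :])              # input gate
    f = logistic.(z[(n + 1):(2n), :])     # forget gate
    g = tanh.(z[(2n + 1):(3n), :])        # candidate memory
    o = logistic.(z[(3n + 1):(4n), :])    # output gate
    c_new = f .* c .+ i .* g
    h_new = o .* tanh.(c_new)
    return h_new, (h_new, c_new)
end

# Fold the cell over dimension 2 of an in_dims × seq_len × batch array,
# starting from a zero carry, and return the last hidden state
function run_sequence(cell::ToyLSTM, x::AbstractArray{<:Real, 3})
    hidden, batch = size(cell.Wh, 2), size(x, 3)
    carry = (zeros(Float32, hidden, batch), zeros(Float32, hidden, batch))
    y = carry[1]
    for xt in eachslice(x; dims=2)
        y, carry = lstm_step(cell, xt, carry)
    end
    return y
end
```

In SpiralClassifier, the analogous hidden × batch output is what gets passed to the Dense(hidden_dims => out_dims, sigmoid) head.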

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    # Clamp predictions away from 0 and 1 so that both log terms stay finite
    ϵ = eps(eltype(y_pred))
    y_pred = clamp.(y_pred, ϵ, 1 - ϵ)
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
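To illustrate why a logit-based loss is more numerically stable, the sketch below compares a naive probability-space binary cross-entropy with one computed directly from logits through a stable log-sigmoid. logsigmoid, logit_bce, and naive_bce are hypothetical helpers written for this example, not the definitions used by Lux or any other library. For a confident but wrong logit of 40 in Float32, the sigmoid rounds to exactly 1, so the naive loss evaluates log(0) and returns Inf, while the logit version returns about 40.

```julia
using Statistics

# Numerically stable log(sigmoid(z)); each branch only exponentiates a non-positive number
logsigmoid(z) = z >= 0 ? -log1p(exp(-z)) : z - log1p(exp(z))

# Binary cross-entropy from logits, using log(1 - sigmoid(z)) == logsigmoid(-z)
logit_bce(z, y) = mean(@. -y * logsigmoid(z) - (1 - y) * logsigmoid(-z))

# Naive version: convert to probabilities first, then take logs
function naive_bce(z, y)
    p = @. 1 / (1 + exp(-z))
    return mean(@. -y * log(p) - (1 - y) * log(1 - p))
end
```

For moderate logits the two agree; they only diverge once the sigmoid saturates.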

Training the Model

julia
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
Epoch [  1]: Loss 0.56463
Epoch [  1]: Loss 0.50721
Epoch [  1]: Loss 0.46992
Epoch [  1]: Loss 0.44816
Epoch [  1]: Loss 0.42660
Epoch [  1]: Loss 0.40369
Epoch [  1]: Loss 0.40457
Validation: Loss 0.37170 Accuracy 1.00000
Validation: Loss 0.37710 Accuracy 1.00000
Epoch [  2]: Loss 0.37565
Epoch [  2]: Loss 0.34933
Epoch [  2]: Loss 0.32690
Epoch [  2]: Loss 0.32020
Epoch [  2]: Loss 0.30072
Epoch [  2]: Loss 0.29085
Epoch [  2]: Loss 0.25987
Validation: Loss 0.26075 Accuracy 1.00000
Validation: Loss 0.26443 Accuracy 1.00000
Epoch [  3]: Loss 0.26233
Epoch [  3]: Loss 0.24794
Epoch [  3]: Loss 0.23020
Epoch [  3]: Loss 0.22393
Epoch [  3]: Loss 0.20936
Epoch [  3]: Loss 0.19822
Epoch [  3]: Loss 0.19010
Validation: Loss 0.18216 Accuracy 1.00000
Validation: Loss 0.18443 Accuracy 1.00000
Epoch [  4]: Loss 0.17958
Epoch [  4]: Loss 0.17405
Epoch [  4]: Loss 0.16180
Epoch [  4]: Loss 0.15636
Epoch [  4]: Loss 0.15067
Epoch [  4]: Loss 0.14136
Epoch [  4]: Loss 0.13588
Validation: Loss 0.13000 Accuracy 1.00000
Validation: Loss 0.13157 Accuracy 1.00000
Epoch [  5]: Loss 0.12789
Epoch [  5]: Loss 0.12253
Epoch [  5]: Loss 0.11746
Epoch [  5]: Loss 0.11366
Epoch [  5]: Loss 0.10759
Epoch [  5]: Loss 0.10393
Epoch [  5]: Loss 0.09666
Validation: Loss 0.09469 Accuracy 1.00000
Validation: Loss 0.09604 Accuracy 1.00000
Epoch [  6]: Loss 0.09401
Epoch [  6]: Loss 0.09065
Epoch [  6]: Loss 0.08637
Epoch [  6]: Loss 0.08198
Epoch [  6]: Loss 0.07886
Epoch [  6]: Loss 0.07415
Epoch [  6]: Loss 0.07407
Validation: Loss 0.07004 Accuracy 1.00000
Validation: Loss 0.07127 Accuracy 1.00000
Epoch [  7]: Loss 0.06808
Epoch [  7]: Loss 0.06647
Epoch [  7]: Loss 0.06360
Epoch [  7]: Loss 0.05931
Epoch [  7]: Loss 0.05855
Epoch [  7]: Loss 0.05807
Epoch [  7]: Loss 0.05811
Validation: Loss 0.05238 Accuracy 1.00000
Validation: Loss 0.05348 Accuracy 1.00000
Epoch [  8]: Loss 0.05238
Epoch [  8]: Loss 0.05103
Epoch [  8]: Loss 0.04656
Epoch [  8]: Loss 0.04507
Epoch [  8]: Loss 0.04239
Epoch [  8]: Loss 0.04274
Epoch [  8]: Loss 0.04218
Validation: Loss 0.03947 Accuracy 1.00000
Validation: Loss 0.04042 Accuracy 1.00000
Epoch [  9]: Loss 0.04073
Epoch [  9]: Loss 0.03685
Epoch [  9]: Loss 0.03457
Epoch [  9]: Loss 0.03385
Epoch [  9]: Loss 0.03354
Epoch [  9]: Loss 0.03252
Epoch [  9]: Loss 0.03013
Validation: Loss 0.03021 Accuracy 1.00000
Validation: Loss 0.03105 Accuracy 1.00000
Epoch [ 10]: Loss 0.02917
Epoch [ 10]: Loss 0.02908
Epoch [ 10]: Loss 0.02776
Epoch [ 10]: Loss 0.02556
Epoch [ 10]: Loss 0.02717
Epoch [ 10]: Loss 0.02490
Epoch [ 10]: Loss 0.02249
Validation: Loss 0.02382 Accuracy 1.00000
Validation: Loss 0.02456 Accuracy 1.00000
Epoch [ 11]: Loss 0.02467
Epoch [ 11]: Loss 0.02294
Epoch [ 11]: Loss 0.02209
Epoch [ 11]: Loss 0.02046
Epoch [ 11]: Loss 0.02095
Epoch [ 11]: Loss 0.01954
Epoch [ 11]: Loss 0.01759
Validation: Loss 0.01944 Accuracy 1.00000
Validation: Loss 0.02009 Accuracy 1.00000
Epoch [ 12]: Loss 0.01899
Epoch [ 12]: Loss 0.01879
Epoch [ 12]: Loss 0.01721
Epoch [ 12]: Loss 0.01761
Epoch [ 12]: Loss 0.01707
Epoch [ 12]: Loss 0.01766
Epoch [ 12]: Loss 0.01605
Validation: Loss 0.01640 Accuracy 1.00000
Validation: Loss 0.01697 Accuracy 1.00000
Epoch [ 13]: Loss 0.01670
Epoch [ 13]: Loss 0.01623
Epoch [ 13]: Loss 0.01391
Epoch [ 13]: Loss 0.01418
Epoch [ 13]: Loss 0.01538
Epoch [ 13]: Loss 0.01497
Epoch [ 13]: Loss 0.01379
Validation: Loss 0.01419 Accuracy 1.00000
Validation: Loss 0.01470 Accuracy 1.00000
Epoch [ 14]: Loss 0.01313
Epoch [ 14]: Loss 0.01393
Epoch [ 14]: Loss 0.01407
Epoch [ 14]: Loss 0.01244
Epoch [ 14]: Loss 0.01258
Epoch [ 14]: Loss 0.01296
Epoch [ 14]: Loss 0.01401
Validation: Loss 0.01252 Accuracy 1.00000
Validation: Loss 0.01298 Accuracy 1.00000
Epoch [ 15]: Loss 0.01239
Epoch [ 15]: Loss 0.01169
Epoch [ 15]: Loss 0.01255
Epoch [ 15]: Loss 0.01130
Epoch [ 15]: Loss 0.01130
Epoch [ 15]: Loss 0.01156
Epoch [ 15]: Loss 0.01020
Validation: Loss 0.01121 Accuracy 1.00000
Validation: Loss 0.01162 Accuracy 1.00000
Epoch [ 16]: Loss 0.01029
Epoch [ 16]: Loss 0.01038
Epoch [ 16]: Loss 0.01079
Epoch [ 16]: Loss 0.01082
Epoch [ 16]: Loss 0.01064
Epoch [ 16]: Loss 0.01021
Epoch [ 16]: Loss 0.01097
Validation: Loss 0.01015 Accuracy 1.00000
Validation: Loss 0.01052 Accuracy 1.00000
Epoch [ 17]: Loss 0.00937
Epoch [ 17]: Loss 0.00912
Epoch [ 17]: Loss 0.01075
Epoch [ 17]: Loss 0.00960
Epoch [ 17]: Loss 0.00964
Epoch [ 17]: Loss 0.00905
Epoch [ 17]: Loss 0.00925
Validation: Loss 0.00926 Accuracy 1.00000
Validation: Loss 0.00961 Accuracy 1.00000
Epoch [ 18]: Loss 0.00918
Epoch [ 18]: Loss 0.00851
Epoch [ 18]: Loss 0.00864
Epoch [ 18]: Loss 0.00886
Epoch [ 18]: Loss 0.00821
Epoch [ 18]: Loss 0.00917
Epoch [ 18]: Loss 0.00873
Validation: Loss 0.00851 Accuracy 1.00000
Validation: Loss 0.00883 Accuracy 1.00000
Epoch [ 19]: Loss 0.00835
Epoch [ 19]: Loss 0.00818
Epoch [ 19]: Loss 0.00756
Epoch [ 19]: Loss 0.00771
Epoch [ 19]: Loss 0.00826
Epoch [ 19]: Loss 0.00815
Epoch [ 19]: Loss 0.00879
Validation: Loss 0.00786 Accuracy 1.00000
Validation: Loss 0.00816 Accuracy 1.00000
Epoch [ 20]: Loss 0.00812
Epoch [ 20]: Loss 0.00746
Epoch [ 20]: Loss 0.00751
Epoch [ 20]: Loss 0.00708
Epoch [ 20]: Loss 0.00723
Epoch [ 20]: Loss 0.00757
Epoch [ 20]: Loss 0.00675
Validation: Loss 0.00729 Accuracy 1.00000
Validation: Loss 0.00757 Accuracy 1.00000
Epoch [ 21]: Loss 0.00723
Epoch [ 21]: Loss 0.00699
Epoch [ 21]: Loss 0.00672
Epoch [ 21]: Loss 0.00660
Epoch [ 21]: Loss 0.00700
Epoch [ 21]: Loss 0.00696
Epoch [ 21]: Loss 0.00738
Validation: Loss 0.00680 Accuracy 1.00000
Validation: Loss 0.00706 Accuracy 1.00000
Epoch [ 22]: Loss 0.00704
Epoch [ 22]: Loss 0.00612
Epoch [ 22]: Loss 0.00660
Epoch [ 22]: Loss 0.00641
Epoch [ 22]: Loss 0.00634
Epoch [ 22]: Loss 0.00632
Epoch [ 22]: Loss 0.00649
Validation: Loss 0.00636 Accuracy 1.00000
Validation: Loss 0.00661 Accuracy 1.00000
Epoch [ 23]: Loss 0.00626
Epoch [ 23]: Loss 0.00625
Epoch [ 23]: Loss 0.00652
Epoch [ 23]: Loss 0.00609
Epoch [ 23]: Loss 0.00574
Epoch [ 23]: Loss 0.00583
Epoch [ 23]: Loss 0.00480
Validation: Loss 0.00597 Accuracy 1.00000
Validation: Loss 0.00620 Accuracy 1.00000
Epoch [ 24]: Loss 0.00638
Epoch [ 24]: Loss 0.00600
Epoch [ 24]: Loss 0.00557
Epoch [ 24]: Loss 0.00522
Epoch [ 24]: Loss 0.00549
Epoch [ 24]: Loss 0.00566
Epoch [ 24]: Loss 0.00509
Validation: Loss 0.00562 Accuracy 1.00000
Validation: Loss 0.00584 Accuracy 1.00000
Epoch [ 25]: Loss 0.00594
Epoch [ 25]: Loss 0.00543
Epoch [ 25]: Loss 0.00527
Epoch [ 25]: Loss 0.00531
Epoch [ 25]: Loss 0.00510
Epoch [ 25]: Loss 0.00521
Epoch [ 25]: Loss 0.00510
Validation: Loss 0.00530 Accuracy 1.00000
Validation: Loss 0.00551 Accuracy 1.00000
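The training loop above delegates parameter updates to Optimisers.jl's Adam(0.01f0) through Lux.Experimental.apply_gradients. As a rough sketch of what one such update does, here is the Adam rule written out by hand for a single parameter vector. ToyAdam and adam_update! are hypothetical names, β1 = 0.9, β2 = 0.999, ϵ = 1e-8 are the commonly used defaults and are assumptions here, and Optimisers.jl's implementation details may differ.

```julia
# Hypothetical hand-written Adam state for one parameter vector
mutable struct ToyAdam
    η::Float32
    β1::Float32
    β2::Float32
    ϵ::Float32
    m::Vector{Float32}  # running mean of gradients
    v::Vector{Float32}  # running mean of squared gradients
    t::Int              # number of updates so far
end
ToyAdam(n; η=0.01f0) = ToyAdam(η, 0.9f0, 0.999f0, 1.0f-8, zeros(Float32, n), zeros(Float32, n), 0)

# One Adam step: update the moment estimates, correct their bias, then step against the gradient
function adam_update!(opt::ToyAdam, θ::Vector{Float32}, g::Vector{Float32})
    opt.t += 1
    opt.m .= opt.β1 .* opt.m .+ (1 - opt.β1) .* g
    opt.v .= opt.β2 .* opt.v .+ (1 - opt.β2) .* g .^ 2
    m̂ = opt.m ./ (1 - opt.β1^opt.t)
    v̂ = opt.v ./ (1 - opt.β2^opt.t)
    θ .-= opt.η .* m̂ ./ (sqrt.(v̂) .+ opt.ϵ)
    return θ
end

# Usage: minimise sum((θ .- 3) .^ 2), whose gradient is 2 .* (θ .- 3)
θ = zeros(Float32, 2)
opt = ToyAdam(length(θ))
for _ in 1:2000
    adam_update!(opt, θ, 2 .* (θ .- 3))
end
```

Because the step is normalised by sqrt(v̂), its size is roughly bounded by η regardless of the gradient's scale, which is part of why a single learning rate like 0.01f0 works across layers here.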

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that main already transfers the parameters and states to the CPU before we save them. Additionally, we recommend that you don't save the model struct itself; save only the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the parameters and states back:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
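As an example of the "any other serialization library" option, the sketch below round-trips a parameter/state pair with Julia's Serialization standard library. The ps and st NamedTuples are made-up stand-ins with the same nesting style as ps_trained and st_trained, not the tutorial's actual trained values. Serialization's documentation warns that files generally cannot be read back by a different Julia version or system image, so it suits short-lived checkpoints better than long-term storage.

```julia
using Serialization

# Made-up stand-ins for ps_trained / st_trained: nested NamedTuples of CPU arrays
ps = (lstm_cell=(weight=rand(Float32, 32, 10), bias=zeros(Float32, 32)),
      classifier=(weight=rand(Float32, 1, 8), bias=zeros(Float32, 1)))
st = (lstm_cell=NamedTuple(), classifier=NamedTuple())

path = tempname() * ".jls"
serialize(path, (ps_trained=ps, st_trained=st))  # write both to one file
loaded = deserialize(path)                        # read them back
```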

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.161 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.