Training a Simple LSTM

In this tutorial, we will train a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial, you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
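
As a quick sanity check (illustrative; not part of the original pipeline), we can peek at the first training batch to confirm the shapes described above:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): 2 coordinates × sequence_length × batchsize
size(y)  # (128,)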

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are automatically populated, and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the cell and head from scratch; instead, we will use the built-in Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
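
Declaring the field names pays off immediately: Lux.setup returns parameter and state NamedTuples keyed by exactly those names. A minimal illustrative check:

julia
ps, st = Lux.setup(Xoshiro(0), SpiralClassifier(2, 8, 1))
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)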

We could have used the default Lux block – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of writing the forward pass that follows, but we will do it manually for the sake of the tutorial.
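
For reference, here is a sketch of that built-in alternative (same dimensions assumed; by default, Recurrence iterates the cell over the time dimension and returns only the output at the final time step):

julia
alt_model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))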

Now we need to define how the classifier behaves when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `Lux._eachslice` (an internal helper akin to `eachslice`) to get
    # the elements in the sequence without copying, and `Iterators.peel` to split out the
    # first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
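
A one-off forward pass (illustrative, with a dummy batch) confirms that `vec` turns the 1 × batch_size classifier output into a vector of per-sequence probabilities:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
y, _ = model(randn(Float32, 2, 50, 7), ps, st)
size(y)  # (7,): one probability per sequence in the batch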

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
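
For reference, a sketch of the more numerically stable alternative mentioned above: keep the classifier head linear (drop the sigmoid) and compute the loss directly from the logits, using the log-sigmoid logσ from NNlib (re-exported by Lux):

julia
# Sketch only: assumes the model outputs raw logits rather than probabilities.
logitbinarycrossentropy(ŷ, y) = mean(@. (1 - y) * ŷ - logσ(ŷ))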

Training the Model

julia
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
Epoch [  1]: Loss 0.56263
Epoch [  1]: Loss 0.50622
Epoch [  1]: Loss 0.46754
Epoch [  1]: Loss 0.45518
Epoch [  1]: Loss 0.43237
Epoch [  1]: Loss 0.40357
Epoch [  1]: Loss 0.37433
Validation: Loss 0.36939 Accuracy 1.00000
Validation: Loss 0.38321 Accuracy 1.00000
Epoch [  2]: Loss 0.37481
Epoch [  2]: Loss 0.35556
Epoch [  2]: Loss 0.32647
Epoch [  2]: Loss 0.32097
Epoch [  2]: Loss 0.29943
Epoch [  2]: Loss 0.28412
Epoch [  2]: Loss 0.26005
Validation: Loss 0.25919 Accuracy 1.00000
Validation: Loss 0.26821 Accuracy 1.00000
Epoch [  3]: Loss 0.26588
Epoch [  3]: Loss 0.24159
Epoch [  3]: Loss 0.23038
Epoch [  3]: Loss 0.22152
Epoch [  3]: Loss 0.20744
Epoch [  3]: Loss 0.20136
Epoch [  3]: Loss 0.18646
Validation: Loss 0.18102 Accuracy 1.00000
Validation: Loss 0.18637 Accuracy 1.00000
Epoch [  4]: Loss 0.17623
Epoch [  4]: Loss 0.16670
Epoch [  4]: Loss 0.16627
Epoch [  4]: Loss 0.15707
Epoch [  4]: Loss 0.15225
Epoch [  4]: Loss 0.14029
Epoch [  4]: Loss 0.13672
Validation: Loss 0.12918 Accuracy 1.00000
Validation: Loss 0.13293 Accuracy 1.00000
Epoch [  5]: Loss 0.12726
Epoch [  5]: Loss 0.12789
Epoch [  5]: Loss 0.11814
Epoch [  5]: Loss 0.10943
Epoch [  5]: Loss 0.10736
Epoch [  5]: Loss 0.10008
Epoch [  5]: Loss 0.09878
Validation: Loss 0.09405 Accuracy 1.00000
Validation: Loss 0.09701 Accuracy 1.00000
Epoch [  6]: Loss 0.09387
Epoch [  6]: Loss 0.08977
Epoch [  6]: Loss 0.08546
Epoch [  6]: Loss 0.08282
Epoch [  6]: Loss 0.07942
Epoch [  6]: Loss 0.07283
Epoch [  6]: Loss 0.07509
Validation: Loss 0.06955 Accuracy 1.00000
Validation: Loss 0.07227 Accuracy 1.00000
Epoch [  7]: Loss 0.06925
Epoch [  7]: Loss 0.06594
Epoch [  7]: Loss 0.06383
Epoch [  7]: Loss 0.05994
Epoch [  7]: Loss 0.05758
Epoch [  7]: Loss 0.05775
Epoch [  7]: Loss 0.05475
Validation: Loss 0.05197 Accuracy 1.00000
Validation: Loss 0.05449 Accuracy 1.00000
Epoch [  8]: Loss 0.05228
Epoch [  8]: Loss 0.05002
Epoch [  8]: Loss 0.04655
Epoch [  8]: Loss 0.04395
Epoch [  8]: Loss 0.04409
Epoch [  8]: Loss 0.04332
Epoch [  8]: Loss 0.04139
Validation: Loss 0.03915 Accuracy 1.00000
Validation: Loss 0.04145 Accuracy 1.00000
Epoch [  9]: Loss 0.03956
Epoch [  9]: Loss 0.03666
Epoch [  9]: Loss 0.03553
Epoch [  9]: Loss 0.03384
Epoch [  9]: Loss 0.03446
Epoch [  9]: Loss 0.03219
Epoch [  9]: Loss 0.02961
Validation: Loss 0.02991 Accuracy 1.00000
Validation: Loss 0.03203 Accuracy 1.00000
Epoch [ 10]: Loss 0.03070
Epoch [ 10]: Loss 0.02812
Epoch [ 10]: Loss 0.02747
Epoch [ 10]: Loss 0.02620
Epoch [ 10]: Loss 0.02516
Epoch [ 10]: Loss 0.02568
Epoch [ 10]: Loss 0.02363
Validation: Loss 0.02351 Accuracy 1.00000
Validation: Loss 0.02541 Accuracy 1.00000
Epoch [ 11]: Loss 0.02379
Epoch [ 11]: Loss 0.02360
Epoch [ 11]: Loss 0.02179
Epoch [ 11]: Loss 0.02033
Epoch [ 11]: Loss 0.02039
Epoch [ 11]: Loss 0.01950
Epoch [ 11]: Loss 0.02115
Validation: Loss 0.01913 Accuracy 1.00000
Validation: Loss 0.02080 Accuracy 1.00000
Epoch [ 12]: Loss 0.01892
Epoch [ 12]: Loss 0.01786
Epoch [ 12]: Loss 0.01832
Epoch [ 12]: Loss 0.01683
Epoch [ 12]: Loss 0.01617
Epoch [ 12]: Loss 0.01823
Epoch [ 12]: Loss 0.01779
Validation: Loss 0.01608 Accuracy 1.00000
Validation: Loss 0.01752 Accuracy 1.00000
Epoch [ 13]: Loss 0.01481
Epoch [ 13]: Loss 0.01561
Epoch [ 13]: Loss 0.01593
Epoch [ 13]: Loss 0.01432
Epoch [ 13]: Loss 0.01573
Epoch [ 13]: Loss 0.01389
Epoch [ 13]: Loss 0.01494
Validation: Loss 0.01388 Accuracy 1.00000
Validation: Loss 0.01514 Accuracy 1.00000
Epoch [ 14]: Loss 0.01335
Epoch [ 14]: Loss 0.01329
Epoch [ 14]: Loss 0.01391
Epoch [ 14]: Loss 0.01310
Epoch [ 14]: Loss 0.01329
Epoch [ 14]: Loss 0.01209
Epoch [ 14]: Loss 0.01129
Validation: Loss 0.01222 Accuracy 1.00000
Validation: Loss 0.01335 Accuracy 1.00000
Epoch [ 15]: Loss 0.01293
Epoch [ 15]: Loss 0.01176
Epoch [ 15]: Loss 0.01128
Epoch [ 15]: Loss 0.01138
Epoch [ 15]: Loss 0.01119
Epoch [ 15]: Loss 0.01150
Epoch [ 15]: Loss 0.01013
Validation: Loss 0.01094 Accuracy 1.00000
Validation: Loss 0.01196 Accuracy 1.00000
Epoch [ 16]: Loss 0.01111
Epoch [ 16]: Loss 0.01155
Epoch [ 16]: Loss 0.01045
Epoch [ 16]: Loss 0.01038
Epoch [ 16]: Loss 0.00986
Epoch [ 16]: Loss 0.00926
Epoch [ 16]: Loss 0.01058
Validation: Loss 0.00990 Accuracy 1.00000
Validation: Loss 0.01085 Accuracy 1.00000
Epoch [ 17]: Loss 0.00961
Epoch [ 17]: Loss 0.00939
Epoch [ 17]: Loss 0.01024
Epoch [ 17]: Loss 0.00949
Epoch [ 17]: Loss 0.00934
Epoch [ 17]: Loss 0.00900
Epoch [ 17]: Loss 0.00861
Validation: Loss 0.00904 Accuracy 1.00000
Validation: Loss 0.00990 Accuracy 1.00000
Epoch [ 18]: Loss 0.00919
Epoch [ 18]: Loss 0.00851
Epoch [ 18]: Loss 0.00921
Epoch [ 18]: Loss 0.00895
Epoch [ 18]: Loss 0.00814
Epoch [ 18]: Loss 0.00790
Epoch [ 18]: Loss 0.00914
Validation: Loss 0.00831 Accuracy 1.00000
Validation: Loss 0.00911 Accuracy 1.00000
Epoch [ 19]: Loss 0.00879
Epoch [ 19]: Loss 0.00821
Epoch [ 19]: Loss 0.00771
Epoch [ 19]: Loss 0.00754
Epoch [ 19]: Loss 0.00763
Epoch [ 19]: Loss 0.00819
Epoch [ 19]: Loss 0.00725
Validation: Loss 0.00767 Accuracy 1.00000
Validation: Loss 0.00842 Accuracy 1.00000
Epoch [ 20]: Loss 0.00766
Epoch [ 20]: Loss 0.00682
Epoch [ 20]: Loss 0.00756
Epoch [ 20]: Loss 0.00719
Epoch [ 20]: Loss 0.00725
Epoch [ 20]: Loss 0.00777
Epoch [ 20]: Loss 0.00760
Validation: Loss 0.00713 Accuracy 1.00000
Validation: Loss 0.00783 Accuracy 1.00000
Epoch [ 21]: Loss 0.00663
Epoch [ 21]: Loss 0.00737
Epoch [ 21]: Loss 0.00708
Epoch [ 21]: Loss 0.00655
Epoch [ 21]: Loss 0.00692
Epoch [ 21]: Loss 0.00675
Epoch [ 21]: Loss 0.00652
Validation: Loss 0.00664 Accuracy 1.00000
Validation: Loss 0.00730 Accuracy 1.00000
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00571
Epoch [ 22]: Loss 0.00629
Epoch [ 22]: Loss 0.00674
Epoch [ 22]: Loss 0.00678
Epoch [ 22]: Loss 0.00690
Epoch [ 22]: Loss 0.00588
Validation: Loss 0.00621 Accuracy 1.00000
Validation: Loss 0.00684 Accuracy 1.00000
Epoch [ 23]: Loss 0.00587
Epoch [ 23]: Loss 0.00625
Epoch [ 23]: Loss 0.00642
Epoch [ 23]: Loss 0.00557
Epoch [ 23]: Loss 0.00585
Epoch [ 23]: Loss 0.00597
Epoch [ 23]: Loss 0.00643
Validation: Loss 0.00583 Accuracy 1.00000
Validation: Loss 0.00642 Accuracy 1.00000
Epoch [ 24]: Loss 0.00560
Epoch [ 24]: Loss 0.00568
Epoch [ 24]: Loss 0.00568
Epoch [ 24]: Loss 0.00549
Epoch [ 24]: Loss 0.00585
Epoch [ 24]: Loss 0.00556
Epoch [ 24]: Loss 0.00554
Validation: Loss 0.00549 Accuracy 1.00000
Validation: Loss 0.00604 Accuracy 1.00000
Epoch [ 25]: Loss 0.00557
Epoch [ 25]: Loss 0.00522
Epoch [ 25]: Loss 0.00536
Epoch [ 25]: Loss 0.00513
Epoch [ 25]: Loss 0.00540
Epoch [ 25]: Loss 0.00509
Epoch [ 25]: Loss 0.00569
Validation: Loss 0.00517 Accuracy 1.00000
Validation: Loss 0.00570 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
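
As a quick check (illustrative, with a dummy batch), the loaded values plug straight back into a freshly constructed model:

julia
model = SpiralClassifier(2, 8, 1)
x = randn(Float32, 2, 50, 4)  # dummy batch: 2 coordinates × 50 steps × 4 sequences
y_pred, _ = model(x, ps_trained, Lux.testmode(st_trained))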

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.141 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.