Training a Simple LSTM

In this tutorial, we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial, you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

julia
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create an MLUtils.DataLoader. Our dataloader will yield batches of size 2 × seq_len × batch_size, and we need to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
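
As a quick sanity check (a sketch; the exact batch shapes follow from the 80/20 split and the batch size of 128), we can peek at the first training batch to confirm the layout:

julia
train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence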

Creating a Classifier

We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use the built-in Lux.LSTMCell and Lux.Dense layers.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
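
To see the automatic population mentioned above in action (a quick sketch), Lux.setup returns parameters and states keyed by the fieldnames we declared on the container layer:

julia
ps, st = Lux.setup(Xoshiro(0), SpiralClassifier(2, 8, 1))
keys(ps)  # (:lstm_cell, :classifier)
keys(st)  # (:lstm_cell, :classifier)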

We can use the default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following, but let's still write it out for the sake of the tutorial.
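
For reference, the built-in version could look roughly like this (a sketch; with the default return_sequence=false, Recurrence runs the whole sequence through the cell and returns only the final output):

julia
model_builtin = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))
# Note: this returns a 1 × batch_size matrix, whereas our custom model returns a vector.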

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
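
As a quick check that the forward pass works end-to-end (a sketch with made-up shapes), a single call maps a 2 × seq_len × batch_size array to one probability per sequence:

julia
model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Float32, 2, 50, 4)  # 2 features, 50 time steps, batch of 4
y, st_ = model(x, ps, st)
size(y)  # (4,)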

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
function xlogy(x, y)
    # Compute `x * log(y)` with the convention that the result is zero when `x` is zero
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    # Shift the predictions by a small epsilon to avoid taking `log(0)`
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
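
For completeness, the logit-space variant mentioned above might look like this (a sketch; it assumes the final Dense layer drops the sigmoid so the model outputs raw logits, and uses logsigmoid from NNlib, which Lux re-exports):

julia
# Numerically stable BCE on raw logits, using log(1 - σ(x)) = logσ(x) - x
logitbinarycrossentropy(y_logits, y_true) = mean(@. (1 - y_true) * y_logits -
                                                    logsigmoid(y_logits))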

Training the Model

julia
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()

    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)

    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
Epoch [  1]: Loss 0.56101
Epoch [  1]: Loss 0.51091
Epoch [  1]: Loss 0.47121
Epoch [  1]: Loss 0.43796
Epoch [  1]: Loss 0.42349
Epoch [  1]: Loss 0.40773
Epoch [  1]: Loss 0.37226
Validation: Loss 0.36677 Accuracy 1.00000
Validation: Loss 0.37385 Accuracy 1.00000
Epoch [  2]: Loss 0.36885
Epoch [  2]: Loss 0.35466
Epoch [  2]: Loss 0.32474
Epoch [  2]: Loss 0.32151
Epoch [  2]: Loss 0.29703
Epoch [  2]: Loss 0.27942
Epoch [  2]: Loss 0.26821
Validation: Loss 0.25659 Accuracy 1.00000
Validation: Loss 0.26113 Accuracy 1.00000
Epoch [  3]: Loss 0.25569
Epoch [  3]: Loss 0.24345
Epoch [  3]: Loss 0.23027
Epoch [  3]: Loss 0.22043
Epoch [  3]: Loss 0.20938
Epoch [  3]: Loss 0.19640
Epoch [  3]: Loss 0.18183
Validation: Loss 0.17826 Accuracy 1.00000
Validation: Loss 0.18072 Accuracy 1.00000
Epoch [  4]: Loss 0.17793
Epoch [  4]: Loss 0.17098
Epoch [  4]: Loss 0.16128
Epoch [  4]: Loss 0.15452
Epoch [  4]: Loss 0.14622
Epoch [  4]: Loss 0.13597
Epoch [  4]: Loss 0.13270
Validation: Loss 0.12648 Accuracy 1.00000
Validation: Loss 0.12810 Accuracy 1.00000
Epoch [  5]: Loss 0.12737
Epoch [  5]: Loss 0.11910
Epoch [  5]: Loss 0.11677
Epoch [  5]: Loss 0.11067
Epoch [  5]: Loss 0.10335
Epoch [  5]: Loss 0.09998
Epoch [  5]: Loss 0.09298
Validation: Loss 0.09160 Accuracy 1.00000
Validation: Loss 0.09303 Accuracy 1.00000
Epoch [  6]: Loss 0.09089
Epoch [  6]: Loss 0.08865
Epoch [  6]: Loss 0.08224
Epoch [  6]: Loss 0.08064
Epoch [  6]: Loss 0.07780
Epoch [  6]: Loss 0.07219
Epoch [  6]: Loss 0.07096
Validation: Loss 0.06746 Accuracy 1.00000
Validation: Loss 0.06888 Accuracy 1.00000
Epoch [  7]: Loss 0.06640
Epoch [  7]: Loss 0.06523
Epoch [  7]: Loss 0.06241
Epoch [  7]: Loss 0.05900
Epoch [  7]: Loss 0.05759
Epoch [  7]: Loss 0.05381
Epoch [  7]: Loss 0.05162
Validation: Loss 0.05029 Accuracy 1.00000
Validation: Loss 0.05164 Accuracy 1.00000
Epoch [  8]: Loss 0.05039
Epoch [  8]: Loss 0.04830
Epoch [  8]: Loss 0.04522
Epoch [  8]: Loss 0.04514
Epoch [  8]: Loss 0.04287
Epoch [  8]: Loss 0.04094
Epoch [  8]: Loss 0.03716
Validation: Loss 0.03787 Accuracy 1.00000
Validation: Loss 0.03913 Accuracy 1.00000
Epoch [  9]: Loss 0.03928
Epoch [  9]: Loss 0.03605
Epoch [  9]: Loss 0.03240
Epoch [  9]: Loss 0.03398
Epoch [  9]: Loss 0.03332
Epoch [  9]: Loss 0.03119
Epoch [  9]: Loss 0.02879
Validation: Loss 0.02900 Accuracy 1.00000
Validation: Loss 0.03016 Accuracy 1.00000
Epoch [ 10]: Loss 0.02834
Epoch [ 10]: Loss 0.02752
Epoch [ 10]: Loss 0.02742
Epoch [ 10]: Loss 0.02608
Epoch [ 10]: Loss 0.02613
Epoch [ 10]: Loss 0.02349
Epoch [ 10]: Loss 0.02362
Validation: Loss 0.02287 Accuracy 1.00000
Validation: Loss 0.02389 Accuracy 1.00000
Epoch [ 11]: Loss 0.02331
Epoch [ 11]: Loss 0.02291
Epoch [ 11]: Loss 0.02148
Epoch [ 11]: Loss 0.02079
Epoch [ 11]: Loss 0.01826
Epoch [ 11]: Loss 0.01984
Epoch [ 11]: Loss 0.02010
Validation: Loss 0.01865 Accuracy 1.00000
Validation: Loss 0.01954 Accuracy 1.00000
Epoch [ 12]: Loss 0.01824
Epoch [ 12]: Loss 0.01719
Epoch [ 12]: Loss 0.01833
Epoch [ 12]: Loss 0.01758
Epoch [ 12]: Loss 0.01649
Epoch [ 12]: Loss 0.01662
Epoch [ 12]: Loss 0.01603
Validation: Loss 0.01569 Accuracy 1.00000
Validation: Loss 0.01644 Accuracy 1.00000
Epoch [ 13]: Loss 0.01599
Epoch [ 13]: Loss 0.01448
Epoch [ 13]: Loss 0.01440
Epoch [ 13]: Loss 0.01481
Epoch [ 13]: Loss 0.01498
Epoch [ 13]: Loss 0.01412
Epoch [ 13]: Loss 0.01379
Validation: Loss 0.01356 Accuracy 1.00000
Validation: Loss 0.01422 Accuracy 1.00000
Epoch [ 14]: Loss 0.01288
Epoch [ 14]: Loss 0.01321
Epoch [ 14]: Loss 0.01373
Epoch [ 14]: Loss 0.01300
Epoch [ 14]: Loss 0.01219
Epoch [ 14]: Loss 0.01243
Epoch [ 14]: Loss 0.01184
Validation: Loss 0.01197 Accuracy 1.00000
Validation: Loss 0.01257 Accuracy 1.00000
Epoch [ 15]: Loss 0.01173
Epoch [ 15]: Loss 0.01194
Epoch [ 15]: Loss 0.01166
Epoch [ 15]: Loss 0.01138
Epoch [ 15]: Loss 0.01102
Epoch [ 15]: Loss 0.01080
Epoch [ 15]: Loss 0.01119
Validation: Loss 0.01072 Accuracy 1.00000
Validation: Loss 0.01124 Accuracy 1.00000
Epoch [ 16]: Loss 0.01015
Epoch [ 16]: Loss 0.01107
Epoch [ 16]: Loss 0.00988
Epoch [ 16]: Loss 0.01008
Epoch [ 16]: Loss 0.01070
Epoch [ 16]: Loss 0.00995
Epoch [ 16]: Loss 0.00924
Validation: Loss 0.00970 Accuracy 1.00000
Validation: Loss 0.01019 Accuracy 1.00000
Epoch [ 17]: Loss 0.00925
Epoch [ 17]: Loss 0.00958
Epoch [ 17]: Loss 0.00931
Epoch [ 17]: Loss 0.00931
Epoch [ 17]: Loss 0.00901
Epoch [ 17]: Loss 0.00930
Epoch [ 17]: Loss 0.01000
Validation: Loss 0.00886 Accuracy 1.00000
Validation: Loss 0.00930 Accuracy 1.00000
Epoch [ 18]: Loss 0.00850
Epoch [ 18]: Loss 0.00875
Epoch [ 18]: Loss 0.00850
Epoch [ 18]: Loss 0.00889
Epoch [ 18]: Loss 0.00832
Epoch [ 18]: Loss 0.00855
Epoch [ 18]: Loss 0.00732
Validation: Loss 0.00814 Accuracy 1.00000
Validation: Loss 0.00855 Accuracy 1.00000
Epoch [ 19]: Loss 0.00834
Epoch [ 19]: Loss 0.00794
Epoch [ 19]: Loss 0.00769
Epoch [ 19]: Loss 0.00763
Epoch [ 19]: Loss 0.00787
Epoch [ 19]: Loss 0.00776
Epoch [ 19]: Loss 0.00760
Validation: Loss 0.00753 Accuracy 1.00000
Validation: Loss 0.00791 Accuracy 1.00000
Epoch [ 20]: Loss 0.00747
Epoch [ 20]: Loss 0.00728
Epoch [ 20]: Loss 0.00735
Epoch [ 20]: Loss 0.00766
Epoch [ 20]: Loss 0.00714
Epoch [ 20]: Loss 0.00693
Epoch [ 20]: Loss 0.00670
Validation: Loss 0.00699 Accuracy 1.00000
Validation: Loss 0.00735 Accuracy 1.00000
Epoch [ 21]: Loss 0.00710
Epoch [ 21]: Loss 0.00656
Epoch [ 21]: Loss 0.00754
Epoch [ 21]: Loss 0.00682
Epoch [ 21]: Loss 0.00642
Epoch [ 21]: Loss 0.00620
Epoch [ 21]: Loss 0.00679
Validation: Loss 0.00652 Accuracy 1.00000
Validation: Loss 0.00686 Accuracy 1.00000
Epoch [ 22]: Loss 0.00682
Epoch [ 22]: Loss 0.00624
Epoch [ 22]: Loss 0.00649
Epoch [ 22]: Loss 0.00604
Epoch [ 22]: Loss 0.00608
Epoch [ 22]: Loss 0.00615
Epoch [ 22]: Loss 0.00671
Validation: Loss 0.00610 Accuracy 1.00000
Validation: Loss 0.00641 Accuracy 1.00000
Epoch [ 23]: Loss 0.00675
Epoch [ 23]: Loss 0.00600
Epoch [ 23]: Loss 0.00550
Epoch [ 23]: Loss 0.00579
Epoch [ 23]: Loss 0.00597
Epoch [ 23]: Loss 0.00548
Epoch [ 23]: Loss 0.00608
Validation: Loss 0.00572 Accuracy 1.00000
Validation: Loss 0.00602 Accuracy 1.00000
Epoch [ 24]: Loss 0.00571
Epoch [ 24]: Loss 0.00566
Epoch [ 24]: Loss 0.00547
Epoch [ 24]: Loss 0.00563
Epoch [ 24]: Loss 0.00577
Epoch [ 24]: Loss 0.00507
Epoch [ 24]: Loss 0.00576
Validation: Loss 0.00538 Accuracy 1.00000
Validation: Loss 0.00567 Accuracy 1.00000
Epoch [ 25]: Loss 0.00511
Epoch [ 25]: Loss 0.00529
Epoch [ 25]: Loss 0.00564
Epoch [ 25]: Loss 0.00527
Epoch [ 25]: Loss 0.00520
Epoch [ 25]: Loss 0.00497
Epoch [ 25]: Loss 0.00497
Validation: Loss 0.00508 Accuracy 1.00000
Validation: Loss 0.00534 Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" {compress = true} ps_trained st_trained

Let's try loading the model.

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
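
To run inference with the restored parameters, we can rebuild the model and switch the states to test mode (a sketch; x_sample stands in for a real 2 × seq_len × batch_size input):

julia
model = SpiralClassifier(2, 8, 1)
st_inference = Lux.testmode(st_trained)
x_sample = rand(Float32, 2, 50, 1)  # hypothetical input
y_pred, _ = model(x_sample, ps_trained, st_inference)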

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-2/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.297 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.