Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Become familiar with the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl (or Enzyme.jl via Reactant.jl).

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and load it with using Zygote.

using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. The dataloader yields sequences of size 2 × seq_len × batch_size, and the task is to predict a binary label indicating whether each sequence is clockwise or anticlockwise.

function get_dataloaders(; dataset_size = 1000, sequence_length = 50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels: 0 for the first half (clockwise), 1 for the second half (anticlockwise)
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    # Each spiral becomes a 2 × sequence_length × 1 array
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
            for d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
            for d in data[((dataset_size ÷ 2) + 1):end]
    ]
    # Stack along the third (batch) dimension
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, shuffle = true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize = 128, shuffle = true, partial = false),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false, partial = false),
    )
end
get_dataloaders (generic function with 1 method)
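To make the tensor layout concrete, here is a small stand-alone sketch of the reshape-and-concatenate step in get_dataloaders. It assumes, as the slicing above implies, that each generated sample is a 2 × (2 · seq_len) matrix whose first seq_len columns form one spiral and whose remaining columns form the other. Random points stand in for MLUtils.Datasets.make_spiral here, since only the shapes matter.

```julia
# Stand-in data: n samples, each a 2 × (2 * seq_len) matrix of random points
# (a hypothetical substitute for MLUtils.Datasets.make_spiral; only shapes matter).
seq_len, n = 50, 10
data = [randn(Float32, 2, 2 * seq_len) for _ in 1:n]

# First half of the samples: keep the first seq_len columns as a 2 × seq_len × 1 array.
first_half = [reshape(d[:, 1:seq_len], :, seq_len, 1) for d in data[1:(n ÷ 2)]]
# Second half of the samples: keep the last seq_len columns instead.
second_half = [reshape(d[:, (seq_len + 1):end], :, seq_len, 1) for d in data[(n ÷ 2 + 1):end]]

# Concatenate along dimension 3 (the batch dimension): 2 × seq_len × n.
x_data = cat(first_half..., second_half...; dims = 3)
labels = vcat(zeros(Float32, n ÷ 2), ones(Float32, n ÷ 2))

println(size(x_data))  # (2, 50, 10)
```

Each x_data[:, :, i] is one full sequence of 2-D points, which is the per-sample slice the DataLoader batches along the last dimension.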

Creating a Classifier

We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the field names lstm_cell and classifier to the type so that the parameters and states are populated automatically and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.

To learn more about container layers, see Container Layer.

struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
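Conceptually, the field names tell the container which sub-layers to visit when building its parameters and states. The toy sketch below (ToyDense, ToyContainer, and toy_initparams are hypothetical names, not Lux's implementation) shows the resulting shape: a NamedTuple keyed by the field names, which is why the forward pass can access ps.lstm_cell and ps.classifier.

```julia
# Hypothetical stand-in for a leaf layer with its own parameters (not a Lux type).
struct ToyDense
    in_dims::Int
    out_dims::Int
end

toy_initparams(d::ToyDense) =
    (weight = zeros(Float32, d.out_dims, d.in_dims), bias = zeros(Float32, d.out_dims))

# Hypothetical container mirroring SpiralClassifier's two fields
# (a ToyDense stands in for the LSTM cell as well).
struct ToyContainer{L, C}
    lstm_cell::L
    classifier::C
end

# Visit each named field and collect its parameters into a NamedTuple with the same keys.
function toy_initparams(c::ToyContainer; fields = (:lstm_cell, :classifier))
    return NamedTuple{fields}(map(f -> toy_initparams(getfield(c, f)), fields))
end

model = ToyContainer(ToyDense(2, 8), ToyDense(8, 1))
ps = toy_initparams(model)
println(keys(ps))  # (:lstm_cell, :classifier)
```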

We won't define the model from scratch; instead we compose it from the Lux.LSTMCell and Lux.Dense layers.

function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end

We could use the built-in Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of writing the forward pass below, but we define it manually here to show how the recurrent API works.

Now we need to define the behavior of the Classifier when it is invoked.

function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple
    ) where {T}
    # First we run the sequence through the LSTM cell. The first call to the LSTM cell
    # creates the initial hidden state. Note that the parameters and states are
    # automatically populated into a field called `lstm_cell`. We use `eachslice` to get
    # the elements of the sequence without copying, and `Iterators.peel` to split out the
    # first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry`, we pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence, pass the final output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, return the updated state
    st = merge(st, (classifier = st_classifier, lstm_cell = st_lstm))
    return vec(y), st
end
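The peel-then-loop pattern itself does not depend on Lux. The sketch below replays it in plain Julia: Base's eachslice(x; dims = 2) stands in for LuxOps.eachslice(x, Val(2)), and toy_cell is a hypothetical single-layer tanh recurrence standing in for LSTMCell. Like the cell above, it accepts either an input alone (creating a zero initial carry) or an (input, carry) tuple, and returns (output, carry).

```julia
const HIDDEN = 3
const W = fill(0.1f0, HIDDEN, 2)       # input-to-hidden weights (toy values)
const U = fill(0.2f0, HIDDEN, HIDDEN)  # hidden-to-hidden weights (toy values)

# Hypothetical tanh recurrence (not an LSTM): takes (input, carry), returns (output, carry).
function toy_cell(input::Tuple)
    x, h = input
    h_new = tanh.(W * x .+ U * h)
    return h_new, h_new
end

# First call: no carry yet, so start from a zero hidden state.
toy_cell(x::AbstractMatrix) = toy_cell((x, zeros(Float32, HIDDEN, size(x, 2))))

function run_sequence(x::AbstractArray{T, 3}) where {T}
    # Each slice along dim 2 is a 2 × batch view of one time step (no copy).
    x_init, x_rest = Iterators.peel(eachslice(x; dims = 2))
    y, carry = toy_cell(x_init)
    for xt in x_rest
        y, carry = toy_cell((xt, carry))
    end
    return y  # output of the last time step: HIDDEN × batch
end

x = ones(Float32, 2, 5, 4)  # 2 features × 5 time steps × 4 sequences
println(size(run_sequence(x)))  # (3, 4)
```

Only the last step's output reaches the classifier, which is why the model returns a single prediction per sequence.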

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro handles the boilerplate code automatically, so we recommend it for defining custom layers.

function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
        # Same forward pass as above; @compact threads the parameters and states for us
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binary cross-entropy loss. Typically, logitbinarycrossentropy is recommended because it is more numerically stable (it works on raw logits instead of sigmoid outputs), but for simplicity we will use binarycrossentropy on the sigmoid outputs of the classifier.

const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    # Return the loss, the updated state, and any extra statistics
    return loss, st_, (; y_pred = ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
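To see why the logit form is preferred, the sketch below compares two plain-Julia versions of the loss for a single prediction (logistic, bce_prob, and bce_logit are illustrative helpers, not Lux functions). bce_logit rewrites -[y log σ(z) + (1 - y) log(1 - σ(z))] as max(z, 0) - y z + log(1 + exp(-|z|)), which is algebraically the same but never evaluates log(0). Implementations that work on probabilities typically guard against this with a small epsilon instead, which keeps the loss finite but caps it for saturated predictions.

```julia
logistic(z) = 1 / (1 + exp(-z))

# Binary cross-entropy on a probability p, written naively (no epsilon clamping).
bce_prob(p, y) = -(y * log(p) + (1 - y) * log(1 - p))

# Same loss computed from the logit z; log1p(exp(-|z|)) keeps it finite.
bce_logit(z, y) = max(z, zero(z)) - y * z + log1p(exp(-abs(z)))

# Moderate logit: both forms agree.
println(bce_prob(logistic(0.5), 1.0), " ", bce_logit(0.5, 1.0))

# Saturated logit in Float32: logistic(40f0) rounds to exactly 1.0f0,
# so the probability form takes log(0) and returns Inf.
z, y = 40.0f0, 0.0f0
println(bce_prob(logistic(z), y))  # Inf
println(bce_logit(z, y))           # 40.0
```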

Training the Model

function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() |> dev

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    # Compile the forward pass with Reactant on a Reactant device; otherwise use the model as-is
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    # Enzyme (through Reactant) on a Reactant device, Zygote otherwise
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (total_acc / total_samples)
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
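The loops in main accumulate loss * length(y) and accuracy * length(y) and divide by the total sample count, which gives a per-sample average rather than an average of batch averages. With partial = false every batch holds exactly 128 samples (the 800 training samples form 6 full batches and the 200 validation samples form 1), so the two coincide here; the difference only appears when batch sizes vary. The plain-Julia sketch below reuses the tutorial's matches and accuracy definitions on two hypothetical batches of different sizes; weighted_accuracy is an illustrative helper.

```julia
# Same definitions as in the tutorial.
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# Per-sample average over batches, mirroring the accumulation in `main`.
function weighted_accuracy(batches)
    total_acc, total_samples = 0.0, 0
    for (ŷ, y) in batches
        total_acc += accuracy(ŷ, y) * length(y)
        total_samples += length(y)
    end
    return total_acc / total_samples
end

# Two hypothetical batches of different sizes: (predictions, targets).
batches = [
    (Float32[0.9, 0.2, 0.7], Float32[1, 0, 0]),  # 2 of 3 correct
    (Float32[0.1, 0.8], Float32[0, 1]),          # 2 of 2 correct
]

println(weighted_accuracy(batches))                      # 0.8 (4 of 5 samples)
println(sum(accuracy(ŷ, y) for (ŷ, y) in batches) / 2)  # unweighted mean of batch accuracies
```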

ps_trained, st_trained = main(SpiralClassifier)
Epoch [  1]: Loss 0.65806
Validation:	Loss 0.58377	Accuracy 1.00000
Epoch [  2]: Loss 0.53889
Validation:	Loss 0.48049	Accuracy 1.00000
Epoch [  3]: Loss 0.44765
Validation:	Loss 0.39908	Accuracy 1.00000
Epoch [  4]: Loss 0.37392
Validation:	Loss 0.33268	Accuracy 1.00000
Epoch [  5]: Loss 0.31215
Validation:	Loss 0.27409	Accuracy 1.00000
Epoch [  6]: Loss 0.25143
Validation:	Loss 0.20997	Accuracy 1.00000
Epoch [  7]: Loss 0.18228
Validation:	Loss 0.15024	Accuracy 1.00000
Epoch [  8]: Loss 0.13617
Validation:	Loss 0.11492	Accuracy 1.00000
Epoch [  9]: Loss 0.10375
Validation:	Loss 0.08718	Accuracy 1.00000
Epoch [ 10]: Loss 0.07819
Validation:	Loss 0.06623	Accuracy 1.00000
Epoch [ 11]: Loss 0.05933
Validation:	Loss 0.05051	Accuracy 1.00000
Epoch [ 12]: Loss 0.04508
Validation:	Loss 0.03884	Accuracy 1.00000
Epoch [ 13]: Loss 0.03473
Validation:	Loss 0.03074	Accuracy 1.00000
Epoch [ 14]: Loss 0.02794
Validation:	Loss 0.02558	Accuracy 1.00000
Epoch [ 15]: Loss 0.02355
Validation:	Loss 0.02204	Accuracy 1.00000
Epoch [ 16]: Loss 0.02055
Validation:	Loss 0.01939	Accuracy 1.00000
Epoch [ 17]: Loss 0.01814
Validation:	Loss 0.01732	Accuracy 1.00000
Epoch [ 18]: Loss 0.01625
Validation:	Loss 0.01565	Accuracy 1.00000
Epoch [ 19]: Loss 0.01473
Validation:	Loss 0.01429	Accuracy 1.00000
Epoch [ 20]: Loss 0.01348
Validation:	Loss 0.01314	Accuracy 1.00000
Epoch [ 21]: Loss 0.01241
Validation:	Loss 0.01217	Accuracy 1.00000
Epoch [ 22]: Loss 0.01141
Validation:	Loss 0.01133	Accuracy 1.00000
Epoch [ 23]: Loss 0.01076
Validation:	Loss 0.01060	Accuracy 1.00000
Epoch [ 24]: Loss 0.01004
Validation:	Loss 0.00996	Accuracy 1.00000
Epoch [ 25]: Loss 0.00941
Validation:	Loss 0.00939	Accuracy 1.00000

We can also train the compact model with the exact same code!

ps_trained2, st_trained2 = main(SpiralClassifierCompact)
Epoch [  1]: Loss 0.56069
Validation:	Loss 0.50002	Accuracy 0.50781
Epoch [  2]: Loss 0.45935
Validation:	Loss 0.40520	Accuracy 0.50781
Epoch [  3]: Loss 0.39022
Validation:	Loss 0.34996	Accuracy 1.00000
Epoch [  4]: Loss 0.32765
Validation:	Loss 0.28566	Accuracy 1.00000
Epoch [  5]: Loss 0.26878
Validation:	Loss 0.22812	Accuracy 1.00000
Epoch [  6]: Loss 0.20919
Validation:	Loss 0.17706	Accuracy 1.00000
Epoch [  7]: Loss 0.15947
Validation:	Loss 0.13204	Accuracy 1.00000
Epoch [  8]: Loss 0.11734
Validation:	Loss 0.09225	Accuracy 1.00000
Epoch [  9]: Loss 0.08175
Validation:	Loss 0.06725	Accuracy 1.00000
Epoch [ 10]: Loss 0.06103
Validation:	Loss 0.05031	Accuracy 1.00000
Epoch [ 11]: Loss 0.04458
Validation:	Loss 0.03722	Accuracy 1.00000
Epoch [ 12]: Loss 0.03371
Validation:	Loss 0.02848	Accuracy 1.00000
Epoch [ 13]: Loss 0.02671
Validation:	Loss 0.02289	Accuracy 1.00000
Epoch [ 14]: Loss 0.02178
Validation:	Loss 0.01941	Accuracy 1.00000
Epoch [ 15]: Loss 0.01880
Validation:	Loss 0.01707	Accuracy 1.00000
Epoch [ 16]: Loss 0.01657
Validation:	Loss 0.01535	Accuracy 1.00000
Epoch [ 17]: Loss 0.01505
Validation:	Loss 0.01396	Accuracy 1.00000
Epoch [ 18]: Loss 0.01353
Validation:	Loss 0.01277	Accuracy 1.00000
Epoch [ 19]: Loss 0.01258
Validation:	Loss 0.01171	Accuracy 1.00000
Epoch [ 20]: Loss 0.01151
Validation:	Loss 0.01066	Accuracy 1.00000
Epoch [ 21]: Loss 0.01035
Validation:	Loss 0.00954	Accuracy 1.00000
Epoch [ 22]: Loss 0.00910
Validation:	Loss 0.00826	Accuracy 1.00000
Epoch [ 23]: Loss 0.00791
Validation:	Loss 0.00705	Accuracy 1.00000
Epoch [ 24]: Loss 0.00676
Validation:	Loss 0.00615	Accuracy 1.00000
Epoch [ 25]: Loss 0.00591
Validation:	Loss 0.00544	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that main already transfers the parameters and states to the CPU before returning them, which is what we want before saving. We also recommend saving only the parameters and states, not the model struct.

@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model.

@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
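JLD2 is not the only option. As a sketch that needs only Julia's standard library, the Serialization module can round-trip the same kind of NamedTuple of parameters and states; the ps and st values below are small hypothetical stand-ins for ps_trained and st_trained. Julia's documentation cautions that serialized data generally cannot be read by a different Julia version (or a different system image), so a dedicated format such as JLD2 is usually the better choice for long-term storage.

```julia
using Serialization

# Hypothetical stand-ins for trained parameters and states (NamedTuples of arrays).
ps = (classifier = (weight = rand(Float32, 1, 8), bias = zeros(Float32, 1)),)
st = (classifier = NamedTuple(),)

# Save only the parameters and states, not the model struct.
path = joinpath(mktempdir(), "trained_model.jls")
serialize(path, (ps = ps, st = st))

# Load them back; the result has the same NamedTuple structure.
loaded = deserialize(path)
println(loaded.ps.classifier.weight == ps.classifier.weight)  # true
```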


