Training a Simple LSTM
In this tutorial we will use a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
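As a quick shape check, we can pull one batch from the training loader. A minimal sketch (the exact values will vary with shuffling):

train_loader, val_loader = get_dataloaders()
x, y = first(train_loader)
size(x)  # (2, 50, 128): features × sequence length × batch size
size(y)  # (128,): one binary label per sequence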
Creating a Classifier
We will be extending the Lux.AbstractExplicitContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head. We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
We can use default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following, but let's still write it out for the sake of demonstration.
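For reference, a minimal sketch of that built-in alternative (assuming Recurrence's default of returning only the final time step's output; its output would still need a vec before our loss below):

model_alt = Chain(
    Recurrence(LSTMCell(2 => 8)),  # folds the cell over the sequence dimension
    Dense(8 => 1, sigmoid))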
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell.
    # The first call to LSTM Cell will create the initial hidden state.
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
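To see the forward pass in action, we can push a dummy batch through the model. A minimal sketch, reusing the shapes from our dataloader:

model = SpiralClassifier(2, 8, 1)
ps, st = Lux.setup(Xoshiro(0), model)
x_dummy = rand(Float32, 2, 50, 4)  # 2 features × 50 time steps × batch of 4
y_dummy, st_ = model(x_dummy, ps, st)
size(y_dummy)  # (4,): one probability per sequence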
Defining Accuracy, Loss and Optimiser
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end
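For reference, the logit variant mentioned above works on raw (pre-sigmoid) outputs and avoids the eps correction entirely. A sketch (not the library implementation) using a numerically stable log-sigmoid; to use it, the final Dense layer would drop its sigmoid activation:

# Stable log σ(x): -log1p(exp(-x)) for x ≥ 0, and x - log1p(exp(x)) otherwise
stable_logσ(x) = ifelse(x ≥ 0, -log1p(exp(-x)), x - log1p(exp(x)))
# Expanding -y * logσ(x) - (1 - y) * logσ(-x) gives (1 - y) * x - logσ(x)
logitbinarycrossentropy(ŷ, y_true) = mean(@. (1 - y_true) * ŷ - stable_logσ(ŷ))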
function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
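These helpers are easy to sanity-check on toy vectors:

y_pred = [0.9f0, 0.2f0, 0.6f0]
y_true = [1.0f0, 0.0f0, 0.0f0]
binarycrossentropy(y_pred, y_true)  # ≈ 0.415
accuracy(y_pred, y_true)            # ≈ 0.667 (the 0.6 prediction crosses the 0.5 threshold)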
Training the Model
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()
    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)
    dev = gpu_device()

    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev

            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main()
Epoch [ 1]: Loss 0.56235
Epoch [ 1]: Loss 0.51151
Epoch [ 1]: Loss 0.47914
Epoch [ 1]: Loss 0.45635
Epoch [ 1]: Loss 0.42339
Epoch [ 1]: Loss 0.41773
Epoch [ 1]: Loss 0.41768
Validation: Loss 0.37153 Accuracy 1.00000
Validation: Loss 0.36875 Accuracy 1.00000
Epoch [ 2]: Loss 0.37936
Epoch [ 2]: Loss 0.34991
Epoch [ 2]: Loss 0.34274
Epoch [ 2]: Loss 0.31910
Epoch [ 2]: Loss 0.30310
Epoch [ 2]: Loss 0.29096
Epoch [ 2]: Loss 0.28280
Validation: Loss 0.26254 Accuracy 1.00000
Validation: Loss 0.26106 Accuracy 1.00000
Epoch [ 3]: Loss 0.26050
Epoch [ 3]: Loss 0.25222
Epoch [ 3]: Loss 0.23802
Epoch [ 3]: Loss 0.22420
Epoch [ 3]: Loss 0.21580
Epoch [ 3]: Loss 0.20403
Epoch [ 3]: Loss 0.19697
Validation: Loss 0.18550 Accuracy 1.00000
Validation: Loss 0.18480 Accuracy 1.00000
Epoch [ 4]: Loss 0.18422
Epoch [ 4]: Loss 0.17635
Epoch [ 4]: Loss 0.17018
Epoch [ 4]: Loss 0.16104
Epoch [ 4]: Loss 0.15308
Epoch [ 4]: Loss 0.14689
Epoch [ 4]: Loss 0.14509
Validation: Loss 0.13456 Accuracy 1.00000
Validation: Loss 0.13406 Accuracy 1.00000
Epoch [ 5]: Loss 0.13665
Epoch [ 5]: Loss 0.12755
Epoch [ 5]: Loss 0.12420
Epoch [ 5]: Loss 0.11720
Epoch [ 5]: Loss 0.11334
Epoch [ 5]: Loss 0.10736
Epoch [ 5]: Loss 0.10463
Validation: Loss 0.09955 Accuracy 1.00000
Validation: Loss 0.09906 Accuracy 1.00000
Epoch [ 6]: Loss 0.09854
Epoch [ 6]: Loss 0.09558
Epoch [ 6]: Loss 0.09159
Epoch [ 6]: Loss 0.08865
Epoch [ 6]: Loss 0.08344
Epoch [ 6]: Loss 0.08169
Epoch [ 6]: Loss 0.07743
Validation: Loss 0.07425 Accuracy 1.00000
Validation: Loss 0.07366 Accuracy 1.00000
Epoch [ 7]: Loss 0.07425
Epoch [ 7]: Loss 0.07175
Epoch [ 7]: Loss 0.06786
Epoch [ 7]: Loss 0.06666
Epoch [ 7]: Loss 0.06266
Epoch [ 7]: Loss 0.05979
Epoch [ 7]: Loss 0.06205
Validation: Loss 0.05576 Accuracy 1.00000
Validation: Loss 0.05512 Accuracy 1.00000
Epoch [ 8]: Loss 0.05682
Epoch [ 8]: Loss 0.05305
Epoch [ 8]: Loss 0.05003
Epoch [ 8]: Loss 0.05069
Epoch [ 8]: Loss 0.04792
Epoch [ 8]: Loss 0.04655
Epoch [ 8]: Loss 0.04070
Validation: Loss 0.04199 Accuracy 1.00000
Validation: Loss 0.04138 Accuracy 1.00000
Epoch [ 9]: Loss 0.04330
Epoch [ 9]: Loss 0.04062
Epoch [ 9]: Loss 0.03796
Epoch [ 9]: Loss 0.03726
Epoch [ 9]: Loss 0.03676
Epoch [ 9]: Loss 0.03459
Epoch [ 9]: Loss 0.03161
Validation: Loss 0.03193 Accuracy 1.00000
Validation: Loss 0.03137 Accuracy 1.00000
Epoch [ 10]: Loss 0.03298
Epoch [ 10]: Loss 0.03151
Epoch [ 10]: Loss 0.02842
Epoch [ 10]: Loss 0.02961
Epoch [ 10]: Loss 0.02920
Epoch [ 10]: Loss 0.02529
Epoch [ 10]: Loss 0.02528
Validation: Loss 0.02498 Accuracy 1.00000
Validation: Loss 0.02450 Accuracy 1.00000
Epoch [ 11]: Loss 0.02410
Epoch [ 11]: Loss 0.02415
Epoch [ 11]: Loss 0.02397
Epoch [ 11]: Loss 0.02241
Epoch [ 11]: Loss 0.02284
Epoch [ 11]: Loss 0.02217
Epoch [ 11]: Loss 0.02287
Validation: Loss 0.02024 Accuracy 1.00000
Validation: Loss 0.01983 Accuracy 1.00000
Epoch [ 12]: Loss 0.02098
Epoch [ 12]: Loss 0.01865
Epoch [ 12]: Loss 0.01899
Epoch [ 12]: Loss 0.01814
Epoch [ 12]: Loss 0.01906
Epoch [ 12]: Loss 0.01883
Epoch [ 12]: Loss 0.01847
Validation: Loss 0.01691 Accuracy 1.00000
Validation: Loss 0.01657 Accuracy 1.00000
Epoch [ 13]: Loss 0.01665
Epoch [ 13]: Loss 0.01855
Epoch [ 13]: Loss 0.01664
Epoch [ 13]: Loss 0.01553
Epoch [ 13]: Loss 0.01510
Epoch [ 13]: Loss 0.01501
Epoch [ 13]: Loss 0.01356
Validation: Loss 0.01451 Accuracy 1.00000
Validation: Loss 0.01421 Accuracy 1.00000
Epoch [ 14]: Loss 0.01486
Epoch [ 14]: Loss 0.01377
Epoch [ 14]: Loss 0.01505
Epoch [ 14]: Loss 0.01415
Epoch [ 14]: Loss 0.01321
Epoch [ 14]: Loss 0.01317
Epoch [ 14]: Loss 0.01225
Validation: Loss 0.01273 Accuracy 1.00000
Validation: Loss 0.01247 Accuracy 1.00000
Epoch [ 15]: Loss 0.01264
Epoch [ 15]: Loss 0.01310
Epoch [ 15]: Loss 0.01270
Epoch [ 15]: Loss 0.01257
Epoch [ 15]: Loss 0.01237
Epoch [ 15]: Loss 0.01100
Epoch [ 15]: Loss 0.01082
Validation: Loss 0.01135 Accuracy 1.00000
Validation: Loss 0.01112 Accuracy 1.00000
Epoch [ 16]: Loss 0.01180
Epoch [ 16]: Loss 0.01108
Epoch [ 16]: Loss 0.01129
Epoch [ 16]: Loss 0.01075
Epoch [ 16]: Loss 0.01080
Epoch [ 16]: Loss 0.01062
Epoch [ 16]: Loss 0.01077
Validation: Loss 0.01024 Accuracy 1.00000
Validation: Loss 0.01003 Accuracy 1.00000
Epoch [ 17]: Loss 0.01040
Epoch [ 17]: Loss 0.01017
Epoch [ 17]: Loss 0.01020
Epoch [ 17]: Loss 0.00910
Epoch [ 17]: Loss 0.00994
Epoch [ 17]: Loss 0.00990
Epoch [ 17]: Loss 0.01107
Validation: Loss 0.00932 Accuracy 1.00000
Validation: Loss 0.00912 Accuracy 1.00000
Epoch [ 18]: Loss 0.00994
Epoch [ 18]: Loss 0.00918
Epoch [ 18]: Loss 0.00968
Epoch [ 18]: Loss 0.00866
Epoch [ 18]: Loss 0.00829
Epoch [ 18]: Loss 0.00889
Epoch [ 18]: Loss 0.00955
Validation: Loss 0.00853 Accuracy 1.00000
Validation: Loss 0.00836 Accuracy 1.00000
Epoch [ 19]: Loss 0.00852
Epoch [ 19]: Loss 0.00859
Epoch [ 19]: Loss 0.00828
Epoch [ 19]: Loss 0.00842
Epoch [ 19]: Loss 0.00822
Epoch [ 19]: Loss 0.00831
Epoch [ 19]: Loss 0.00805
Validation: Loss 0.00787 Accuracy 1.00000
Validation: Loss 0.00770 Accuracy 1.00000
Epoch [ 20]: Loss 0.00774
Epoch [ 20]: Loss 0.00779
Epoch [ 20]: Loss 0.00795
Epoch [ 20]: Loss 0.00773
Epoch [ 20]: Loss 0.00719
Epoch [ 20]: Loss 0.00795
Epoch [ 20]: Loss 0.00788
Validation: Loss 0.00728 Accuracy 1.00000
Validation: Loss 0.00713 Accuracy 1.00000
Epoch [ 21]: Loss 0.00777
Epoch [ 21]: Loss 0.00665
Epoch [ 21]: Loss 0.00737
Epoch [ 21]: Loss 0.00700
Epoch [ 21]: Loss 0.00680
Epoch [ 21]: Loss 0.00756
Epoch [ 21]: Loss 0.00670
Validation: Loss 0.00677 Accuracy 1.00000
Validation: Loss 0.00663 Accuracy 1.00000
Epoch [ 22]: Loss 0.00723
Epoch [ 22]: Loss 0.00626
Epoch [ 22]: Loss 0.00680
Epoch [ 22]: Loss 0.00719
Epoch [ 22]: Loss 0.00595
Epoch [ 22]: Loss 0.00717
Epoch [ 22]: Loss 0.00471
Validation: Loss 0.00632 Accuracy 1.00000
Validation: Loss 0.00619 Accuracy 1.00000
Epoch [ 23]: Loss 0.00638
Epoch [ 23]: Loss 0.00623
Epoch [ 23]: Loss 0.00629
Epoch [ 23]: Loss 0.00600
Epoch [ 23]: Loss 0.00635
Epoch [ 23]: Loss 0.00619
Epoch [ 23]: Loss 0.00643
Validation: Loss 0.00593 Accuracy 1.00000
Validation: Loss 0.00580 Accuracy 1.00000
Epoch [ 24]: Loss 0.00583
Epoch [ 24]: Loss 0.00629
Epoch [ 24]: Loss 0.00521
Epoch [ 24]: Loss 0.00583
Epoch [ 24]: Loss 0.00642
Epoch [ 24]: Loss 0.00582
Epoch [ 24]: Loss 0.00508
Validation: Loss 0.00557 Accuracy 1.00000
Validation: Loss 0.00545 Accuracy 1.00000
Epoch [ 25]: Loss 0.00581
Epoch [ 25]: Loss 0.00549
Epoch [ 25]: Loss 0.00570
Epoch [ 25]: Loss 0.00548
Epoch [ 25]: Loss 0.00502
Epoch [ 25]: Loss 0.00563
Epoch [ 25]: Loss 0.00552
Validation: Loss 0.00525 Accuracy 1.00000
Validation: Loss 0.00514 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and instead save only the parameters and states.
@save "trained_model.jld2" {compress = true} ps_trained st_trained
Let's try loading the model:
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
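The reloaded parameters and states can now be used for inference on the CPU. A minimal sketch, assuming SpiralClassifier is still defined as above:

model = SpiralClassifier(2, 8, 1)
st_inf = Lux.testmode(st_trained)
x_sample = rand(Float32, 2, 50, 1)  # one dummy sequence
y_prob, _ = model(x_sample, ps_trained, st_inf)
y_prob[1] > 0.5f0  # true ⇒ anticlockwise, false ⇒ clockwise (per our labels)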
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/docs
JULIA_AMDGPU_LOGGING_ENABLED = true
JULIA_DEBUG = Literate
JULIA_CPU_THREADS = 2
JULIA_NUM_THREADS = 48
JULIA_LOAD_PATH = @:@v#.#:@stdlib
JULIA_CUDA_HARD_MEMORY_LIMIT = 25%
CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15
CUDA libraries:
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15
Julia packages:
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0
Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.328 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19
This page was generated using Literate.jl.