Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl and Zygote.jl.
Package Imports
using ADTypes, Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random,
      Statistics
Dataset
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether each sequence is clockwise or anticlockwise.
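To make the 2 × seq_len × batch_size layout concrete, here is a minimal stdlib-only sketch of the stacking and splitting that get_dataloaders performs. It assumes random 2 × seq_len matrices as stand-ins for the MLUtils spirals, and replaces splitobs with a hand-rolled shuffled 80/20 split; the variable names are illustrative only.

```julia
using Random

# Toy stand-ins for spirals: each sequence is a 2 × seq_len matrix of (x, y) points.
# These are random placeholders, NOT MLUtils.Datasets.make_spiral output.
seq_len, n_seqs = 50, 10
rng = Xoshiro(0)
seqs = [randn(rng, Float32, 2, seq_len) for _ in 1:n_seqs]

# Give each sequence a trailing singleton batch dimension and concatenate along it,
# mirroring `reshape(..., :, sequence_length, 1)` and `cat(...; dims=3)`.
x_data = cat([reshape(s, :, seq_len, 1) for s in seqs]...; dims=3)

# First half labelled 0 (clockwise), second half labelled 1 (anticlockwise).
labels = vcat(fill(0.0f0, n_seqs ÷ 2), fill(1.0f0, n_seqs ÷ 2))

# Hand-rolled shuffled 80/20 split along the last (observation) dimension,
# standing in for `splitobs((x_data, labels); at=0.8, shuffle=true)`.
perm = randperm(rng, n_seqs)
n_train = round(Int, 0.8 * n_seqs)
train_idx, val_idx = perm[1:n_train], perm[(n_train + 1):end]
x_train, y_train = x_data[:, :, train_idx], labels[train_idx]
x_val, y_val = x_data[:, :, val_idx], labels[val_idx]
```

Here size(x_data) is (2, 50, 10): features first, time second, observations last, which is the layout the DataLoader batches along.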
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_dataloaders (generic function with 1 method)
Creating a Classifier
We will extend the Lux.AbstractExplicitContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are populated automatically, and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.
To learn more about container layers, see Container Layer.
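The idea behind listing field names in the container type can be illustrated with a toy, stdlib-only sketch. Everything here (ToyContainerLayer, ToyDense, ToyClassifier, toy_initparams) is hypothetical and is not how Lux is implemented internally; it only shows how a tuple of field names stored as a type parameter lets a generic function build one nested NamedTuple entry per child layer.

```julia
using Random

# Hypothetical toy leaf layer (not a Lux type): it only records its sizes.
struct ToyDense
    in_dims::Int
    out_dims::Int
end

# Leaf layers create their own parameters.
toy_initparams(rng::AbstractRNG, l::ToyDense) =
    (weight = randn(rng, Float32, l.out_dims, l.in_dims), bias = zeros(Float32, l.out_dims))

# Hypothetical toy container: the field names to recurse into are a type parameter,
# loosely mirroring AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}.
abstract type ToyContainerLayer{names} end

function toy_initparams(rng::AbstractRNG, c::ToyContainerLayer{names}) where {names}
    # One entry per listed field, each filled by recursing into that child layer
    return NamedTuple{names}(map(n -> toy_initparams(rng, getfield(c, n)), names))
end

# The `lstm_cell` slot holds a toy dense layer here just to keep the sketch short.
struct ToyClassifier{L, C} <: ToyContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

ps = toy_initparams(Xoshiro(0), ToyClassifier(ToyDense(2, 8), ToyDense(8, 1)))
```

The resulting ps has keys (:lstm_cell, :classifier), which is why the forward pass can access ps.lstm_cell and ps.classifier without any extra bookkeeping.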
struct SpiralClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
end
Main.var"##225".SpiralClassifier
We could use the built-in Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the forward pass ourselves, but we will still write it out by hand to show how the recurrence works.
Now we need to define the behavior of the Classifier when it is invoked.
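The forward pass unrolls the recurrence manually: Iterators.peel splits off the first time step (which has no carry yet), and the loop then feeds each remaining step together with the carry. The stdlib-only sketch below isolates that pattern. ToyCell and run_sequence are hypothetical, the cell is a plain tanh RNN rather than an LSTM, and Base's eachslice stands in for Lux._eachslice.

```julia
using Random

# Hypothetical toy recurrent cell (a plain tanh RNN, not an LSTM):
# h_new = tanh.(W * x .+ U * h)
struct ToyCell
    W::Matrix{Float32}   # hidden_dims × in_dims
    U::Matrix{Float32}   # hidden_dims × hidden_dims
end

# First call: no carry yet, so start from a zero hidden state.
(c::ToyCell)(x::AbstractMatrix) = c((x, zeros(Float32, size(c.U, 1), size(x, 2))))

# Later calls: take (input, carry) and return (output, new carry).
function (c::ToyCell)(xh::Tuple)
    x, h = xh
    h_new = tanh.(c.W * x .+ c.U * h)
    return h_new, h_new
end

function run_sequence(cell, x::AbstractArray{T, 3}) where {T}
    # Each slice along dims=2 is the in_dims × batch matrix for one time step.
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    y, carry = cell(x_init)
    for x_t in x_rest
        y, carry = cell((x_t, carry))
    end
    return y   # output at the last time step: hidden_dims × batch
end

cell = ToyCell(randn(Xoshiro(0), Float32, 8, 2), randn(Xoshiro(1), Float32, 8, 8))
x = randn(Xoshiro(2), Float32, 2, 50, 4)   # 2 features × 50 steps × 4 sequences
y = run_sequence(cell, x)
```

Only the last output is kept, which matches how the classifier head consumes the final LSTM output.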
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # First we will have to run the sequence through the LSTM Cell
    # The first call to LSTM Cell will create the initial hidden state
    # See that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `Lux._eachslice` to get the elements in the sequence without
    # copying, and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. logitbinarycrossentropy is usually recommended because it is more numerically stable (it works on the raw logits, so the classifier head would not apply sigmoid), but for simplicity we will use binarycrossentropy on the sigmoid outputs.
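The stability argument can be seen directly in a small stdlib-only sketch. naive_sigmoid, naive_bce and logit_bce are hypothetical names (chosen to avoid clashing with Lux's sigmoid); logit_bce uses the standard rewrite max(z, 0) - z * y + log1p(exp(-|z|)) of the cross-entropy computed from a logit z.

```julia
using Statistics: mean

naive_sigmoid(z) = 1 / (1 + exp(-z))

# Squash first, then take logs: breaks down once the sigmoid rounds to exactly 0 or 1.
naive_bce(z, y) = mean(@. -(y * log(naive_sigmoid(z)) + (1 - y) * log(1 - naive_sigmoid(z))))

# Same loss written in terms of the logit z; never evaluates log(0).
logit_bce(z, y) = mean(@. max(z, 0) - z * y + log1p(exp(-abs(z))))

# A confident but wrong prediction: in Float32, naive_sigmoid(20f0) rounds to 1.0f0.
z_big, y0 = Float32[20], Float32[0]
naive_big, stable_big = naive_bce(z_big, y0), logit_bce(z_big, y0)

# For moderate logits the two formulations agree.
z, y = Float32[-1, 0.5, 2], Float32[0, 1, 1]
naive_mid, stable_mid = naive_bce(z, y), logit_bce(z, y)
```

Here naive_big is Inf while stable_big is about 20. Switching the tutorial to the logit form would mean removing sigmoid from the Dense layer and thresholding predictions at 0 instead of 0.5.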
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end
function binarycrossentropy(y_pred, y_true)
    # Clamp predictions away from 0 and 1 so that neither `log(y_pred)` nor
    # `log(1 - y_pred)` is ever evaluated at 0 or at a negative number
    ϵ = eps(eltype(y_pred))
    y_pred = clamp.(y_pred, ϵ, 1 - ϵ)
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end
function compute_loss(model, ps, st, (x, y))
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
Training the Model
function main()
    # Get the dataloaders
    (train_loader, val_loader) = get_dataloaders()
    # Create the model
    model = SpiralClassifier(2, 8, 1)
    rng = Xoshiro(0)
    dev = gpu_device()
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(0.01f0); transform_variables=dev)
    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            x = x |> dev
            y = y |> dev
            gs, loss, _, train_state = Lux.Experimental.compute_gradients(
                AutoZygote(), compute_loss, (x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end
        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            x = x |> dev
            y = y |> dev
            loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y))
            acc = accuracy(ret.y_pred, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end
    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main()
Epoch [ 1]: Loss 0.56178
Epoch [ 1]: Loss 0.51017
Epoch [ 1]: Loss 0.47634
Epoch [ 1]: Loss 0.44813
Epoch [ 1]: Loss 0.42954
Epoch [ 1]: Loss 0.40370
Epoch [ 1]: Loss 0.36643
Validation: Loss 0.36494 Accuracy 1.00000
Validation: Loss 0.36217 Accuracy 1.00000
Epoch [ 2]: Loss 0.37246
Epoch [ 2]: Loss 0.35711
Epoch [ 2]: Loss 0.33796
Epoch [ 2]: Loss 0.30543
Epoch [ 2]: Loss 0.29671
Epoch [ 2]: Loss 0.28651
Epoch [ 2]: Loss 0.27739
Validation: Loss 0.25594 Accuracy 1.00000
Validation: Loss 0.25439 Accuracy 1.00000
Epoch [ 3]: Loss 0.25597
Epoch [ 3]: Loss 0.24539
Epoch [ 3]: Loss 0.22981
Epoch [ 3]: Loss 0.22039
Epoch [ 3]: Loss 0.20723
Epoch [ 3]: Loss 0.20337
Epoch [ 3]: Loss 0.19897
Validation: Loss 0.17872 Accuracy 1.00000
Validation: Loss 0.17805 Accuracy 1.00000
Epoch [ 4]: Loss 0.18087
Epoch [ 4]: Loss 0.17154
Epoch [ 4]: Loss 0.16526
Epoch [ 4]: Loss 0.15488
Epoch [ 4]: Loss 0.14729
Epoch [ 4]: Loss 0.14107
Epoch [ 4]: Loss 0.13093
Validation: Loss 0.12766 Accuracy 1.00000
Validation: Loss 0.12741 Accuracy 1.00000
Epoch [ 5]: Loss 0.12735
Epoch [ 5]: Loss 0.12332
Epoch [ 5]: Loss 0.11649
Epoch [ 5]: Loss 0.11297
Epoch [ 5]: Loss 0.10769
Epoch [ 5]: Loss 0.10170
Epoch [ 5]: Loss 0.09774
Validation: Loss 0.09279 Accuracy 1.00000
Validation: Loss 0.09243 Accuracy 1.00000
Epoch [ 6]: Loss 0.09282
Epoch [ 6]: Loss 0.08867
Epoch [ 6]: Loss 0.08627
Epoch [ 6]: Loss 0.08247
Epoch [ 6]: Loss 0.07937
Epoch [ 6]: Loss 0.07490
Epoch [ 6]: Loss 0.07309
Validation: Loss 0.06842 Accuracy 1.00000
Validation: Loss 0.06795 Accuracy 1.00000
Epoch [ 7]: Loss 0.06945
Epoch [ 7]: Loss 0.06710
Epoch [ 7]: Loss 0.06493
Epoch [ 7]: Loss 0.06089
Epoch [ 7]: Loss 0.05727
Epoch [ 7]: Loss 0.05496
Epoch [ 7]: Loss 0.05386
Validation: Loss 0.05104 Accuracy 1.00000
Validation: Loss 0.05054 Accuracy 1.00000
Epoch [ 8]: Loss 0.05187
Epoch [ 8]: Loss 0.04979
Epoch [ 8]: Loss 0.04745
Epoch [ 8]: Loss 0.04606
Epoch [ 8]: Loss 0.04462
Epoch [ 8]: Loss 0.04123
Epoch [ 8]: Loss 0.03969
Validation: Loss 0.03840 Accuracy 1.00000
Validation: Loss 0.03789 Accuracy 1.00000
Epoch [ 9]: Loss 0.03889
Epoch [ 9]: Loss 0.03676
Epoch [ 9]: Loss 0.03615
Epoch [ 9]: Loss 0.03439
Epoch [ 9]: Loss 0.03411
Epoch [ 9]: Loss 0.03188
Epoch [ 9]: Loss 0.03283
Validation: Loss 0.02934 Accuracy 1.00000
Validation: Loss 0.02884 Accuracy 1.00000
Epoch [ 10]: Loss 0.02973
Epoch [ 10]: Loss 0.02819
Epoch [ 10]: Loss 0.02990
Epoch [ 10]: Loss 0.02658
Epoch [ 10]: Loss 0.02443
Epoch [ 10]: Loss 0.02507
Epoch [ 10]: Loss 0.02589
Validation: Loss 0.02307 Accuracy 1.00000
Validation: Loss 0.02261 Accuracy 1.00000
Epoch [ 11]: Loss 0.02310
Epoch [ 11]: Loss 0.02350
Epoch [ 11]: Loss 0.02149
Epoch [ 11]: Loss 0.02201
Epoch [ 11]: Loss 0.02027
Epoch [ 11]: Loss 0.02080
Epoch [ 11]: Loss 0.01853
Validation: Loss 0.01877 Accuracy 1.00000
Validation: Loss 0.01838 Accuracy 1.00000
Epoch [ 12]: Loss 0.01829
Epoch [ 12]: Loss 0.01854
Epoch [ 12]: Loss 0.01798
Epoch [ 12]: Loss 0.01856
Epoch [ 12]: Loss 0.01769
Epoch [ 12]: Loss 0.01694
Epoch [ 12]: Loss 0.01593
Validation: Loss 0.01579 Accuracy 1.00000
Validation: Loss 0.01545 Accuracy 1.00000
Epoch [ 13]: Loss 0.01627
Epoch [ 13]: Loss 0.01667
Epoch [ 13]: Loss 0.01634
Epoch [ 13]: Loss 0.01412
Epoch [ 13]: Loss 0.01476
Epoch [ 13]: Loss 0.01361
Epoch [ 13]: Loss 0.01447
Validation: Loss 0.01365 Accuracy 1.00000
Validation: Loss 0.01334 Accuracy 1.00000
Epoch [ 14]: Loss 0.01428
Epoch [ 14]: Loss 0.01430
Epoch [ 14]: Loss 0.01343
Epoch [ 14]: Loss 0.01234
Epoch [ 14]: Loss 0.01289
Epoch [ 14]: Loss 0.01239
Epoch [ 14]: Loss 0.01374
Validation: Loss 0.01204 Accuracy 1.00000
Validation: Loss 0.01177 Accuracy 1.00000
Epoch [ 15]: Loss 0.01287
Epoch [ 15]: Loss 0.01245
Epoch [ 15]: Loss 0.01182
Epoch [ 15]: Loss 0.01154
Epoch [ 15]: Loss 0.01066
Epoch [ 15]: Loss 0.01143
Epoch [ 15]: Loss 0.01176
Validation: Loss 0.01077 Accuracy 1.00000
Validation: Loss 0.01053 Accuracy 1.00000
Epoch [ 16]: Loss 0.01140
Epoch [ 16]: Loss 0.01095
Epoch [ 16]: Loss 0.01149
Epoch [ 16]: Loss 0.01025
Epoch [ 16]: Loss 0.00990
Epoch [ 16]: Loss 0.00973
Epoch [ 16]: Loss 0.01035
Validation: Loss 0.00975 Accuracy 1.00000
Validation: Loss 0.00953 Accuracy 1.00000
Epoch [ 17]: Loss 0.01003
Epoch [ 17]: Loss 0.01000
Epoch [ 17]: Loss 0.00990
Epoch [ 17]: Loss 0.00933
Epoch [ 17]: Loss 0.00901
Epoch [ 17]: Loss 0.00962
Epoch [ 17]: Loss 0.00914
Validation: Loss 0.00890 Accuracy 1.00000
Validation: Loss 0.00869 Accuracy 1.00000
Epoch [ 18]: Loss 0.00904
Epoch [ 18]: Loss 0.00871
Epoch [ 18]: Loss 0.00966
Epoch [ 18]: Loss 0.00881
Epoch [ 18]: Loss 0.00813
Epoch [ 18]: Loss 0.00869
Epoch [ 18]: Loss 0.00821
Validation: Loss 0.00818 Accuracy 1.00000
Validation: Loss 0.00799 Accuracy 1.00000
Epoch [ 19]: Loss 0.00855
Epoch [ 19]: Loss 0.00854
Epoch [ 19]: Loss 0.00859
Epoch [ 19]: Loss 0.00797
Epoch [ 19]: Loss 0.00775
Epoch [ 19]: Loss 0.00736
Epoch [ 19]: Loss 0.00805
Validation: Loss 0.00756 Accuracy 1.00000
Validation: Loss 0.00738 Accuracy 1.00000
Epoch [ 20]: Loss 0.00837
Epoch [ 20]: Loss 0.00729
Epoch [ 20]: Loss 0.00769
Epoch [ 20]: Loss 0.00698
Epoch [ 20]: Loss 0.00764
Epoch [ 20]: Loss 0.00718
Epoch [ 20]: Loss 0.00742
Validation: Loss 0.00702 Accuracy 1.00000
Validation: Loss 0.00685 Accuracy 1.00000
Epoch [ 21]: Loss 0.00708
Epoch [ 21]: Loss 0.00722
Epoch [ 21]: Loss 0.00694
Epoch [ 21]: Loss 0.00741
Epoch [ 21]: Loss 0.00644
Epoch [ 21]: Loss 0.00700
Epoch [ 21]: Loss 0.00647
Validation: Loss 0.00654 Accuracy 1.00000
Validation: Loss 0.00638 Accuracy 1.00000
Epoch [ 22]: Loss 0.00665
Epoch [ 22]: Loss 0.00669
Epoch [ 22]: Loss 0.00662
Epoch [ 22]: Loss 0.00647
Epoch [ 22]: Loss 0.00646
Epoch [ 22]: Loss 0.00635
Epoch [ 22]: Loss 0.00628
Validation: Loss 0.00612 Accuracy 1.00000
Validation: Loss 0.00597 Accuracy 1.00000
Epoch [ 23]: Loss 0.00604
Epoch [ 23]: Loss 0.00628
Epoch [ 23]: Loss 0.00616
Epoch [ 23]: Loss 0.00615
Epoch [ 23]: Loss 0.00592
Epoch [ 23]: Loss 0.00595
Epoch [ 23]: Loss 0.00695
Validation: Loss 0.00574 Accuracy 1.00000
Validation: Loss 0.00560 Accuracy 1.00000
Epoch [ 24]: Loss 0.00579
Epoch [ 24]: Loss 0.00602
Epoch [ 24]: Loss 0.00590
Epoch [ 24]: Loss 0.00582
Epoch [ 24]: Loss 0.00557
Epoch [ 24]: Loss 0.00554
Epoch [ 24]: Loss 0.00519
Validation: Loss 0.00540 Accuracy 1.00000
Validation: Loss 0.00527 Accuracy 1.00000
Epoch [ 25]: Loss 0.00525
Epoch [ 25]: Loss 0.00528
Epoch [ 25]: Loss 0.00538
Epoch [ 25]: Loss 0.00546
Epoch [ 25]: Loss 0.00562
Epoch [ 25]: Loss 0.00552
Epoch [ 25]: Loss 0.00524
Validation: Loss 0.00510 Accuracy 1.00000
Validation: Loss 0.00497 Accuracy 1.00000
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct itself, and save only the parameters and states.
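As an example of "any other serialization library", the sketch below uses Julia's Serialization standard library instead of JLD2. The ps and st NamedTuples are made-up stand-ins for ps_trained and st_trained (the field names and shapes are illustrative only); the point is that parameters and states are plain nested NamedTuples of CPU arrays, so they round-trip without the model struct.

```julia
using Serialization

# Stand-ins for ps_trained / st_trained: NamedTuples of plain CPU arrays.
# Field names and shapes are invented for illustration.
ps = (classifier = (weight = rand(Float32, 1, 8), bias = zeros(Float32, 1)),)
st = (classifier = NamedTuple(),)

path = joinpath(mktempdir(), "trained_model.jls")
serialize(path, (ps = ps, st = st))   # write both into one file
loaded = deserialize(path)            # read them back as a NamedTuple
```

One trade-off: Serialization's format can change between Julia releases, so it suits short-lived caches better than long-term checkpoints, where JLD2 is the safer choice.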
@save "trained_model.jld2" {compress = true} ps_trained st_trained
Let's try loading the model.
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 6 default, 0 interactive, 3 GC (on 2 virtual cores)
Environment:
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-3/julialang/lux-dot-jl/docs
JULIA_AMDGPU_LOGGING_ENABLED = true
JULIA_DEBUG = Literate
JULIA_CPU_THREADS = 2
JULIA_NUM_THREADS = 6
JULIA_LOAD_PATH = @:@v#.#:@stdlib
JULIA_CUDA_HARD_MEMORY_LIMIT = 25%
CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15
CUDA libraries:
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15
Julia packages:
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0
Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.974 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19
This page was generated using Literate.jl.