Building an LSTM Encoder-Decoder model using Lux.jl
This example is based on the LSTM_encoder_decoder repository by Laura Kulowski.
julia
using Lux, Reactant, Random, Optimisers, Statistics, Enzyme, Printf, CairoMakie, MLUtils
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
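The `xdev` functor returned by `reactant_device` moves arrays (and nested parameter/state structures) onto the Reactant backend, while `cdev` brings them back to the CPU; both are used below for the dataloader, the parameters, and the trained model. A minimal illustration (the `x_host` array and its name exist only for this sketch):
julia
x_host = rand(Float32, 4)
x_dev = xdev(x_host)     # now backed by Reactant
cdev(x_dev) ≈ x_host     # moving it back recovers a plain CPU array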
Generate synthetic data
julia
function synthetic_data(Nt=2000, tf=80 * Float32(π))
t = range(0.0f0, tf; length=Nt)
y = sin.(2.0f0 * t) .+ 0.5f0 * cos.(t) .+ randn(Float32, Nt) * 0.2f0
return t, y
end
function train_test_split(t, y, split=0.8)
indx_split = ceil(Int, length(t) * split)
indx_train = 1:indx_split
indx_test = (indx_split + 1):length(t)
t_train = t[indx_train]
y_train = reshape(y[indx_train], 1, :)
t_test = t[indx_test]
y_test = reshape(y[indx_test], 1, :)
return t_train, y_train, t_test, y_test
end
function windowed_dataset(y; input_window=5, output_window=1, stride=1, num_features=1)
L = size(y, ndims(y))
num_samples = (L - input_window - output_window) ÷ stride + 1
X = zeros(Float32, num_features, input_window, num_samples)
Y = zeros(Float32, num_features, output_window, num_samples)
for ii in 1:num_samples, ff in 1:num_features
start_x = stride * (ii - 1) + 1
end_x = start_x + input_window - 1
X[ff, :, ii] .= y[start_x:end_x]
start_y = stride * (ii - 1) + input_window + 1
end_y = start_y + output_window - 1
Y[ff, :, ii] .= y[start_y:end_y]
end
return X, Y
end
t, y = synthetic_data()
begin
fig = Figure(; size=(1000, 400))
ax = Axis(fig[1, 1]; title="Synthetic Time Series", xlabel="t", ylabel="y")
lines!(ax, t, y; label="y", color=:black, linewidth=2)
fig
end
t_train, y_train, t_test, y_test = train_test_split(t, y)
begin
fig = Figure(; size=(1000, 400))
ax = Axis(
fig[1, 1];
title="Time Series Split into Train and Test Sets",
xlabel="t",
ylabel="y",
)
lines!(ax, t_train, y_train[1, :]; label="Train", color=:black, linewidth=2)
lines!(ax, t_test, y_test[1, :]; label="Test", color=:red, linewidth=2)
fig[1, 2] = Legend(fig, ax)
fig
end
X_train, Y_train = windowed_dataset(y_train; input_window=80, output_window=20, stride=5)
X_test, Y_test = windowed_dataset(y_test; input_window=80, output_window=20, stride=5)
begin
fig = Figure(; size=(1000, 400))
ax = Axis(fig[1, 1]; title="Example of Windowed Training Data", xlabel="t", ylabel="y")
linestyles = [:solid, :dash, :dot, :dashdot, :dashdotdot]
for b in 1:4:16
lines!(
ax,
0:79,
X_train[1, :, b];
label="Input",
color=:black,
linewidth=2,
linestyle=linestyles[mod1(b, 5)],
)
lines!(
ax,
79:99,
vcat(X_train[1, end, b], Y_train[1, :, b]);
label="Target",
color=:red,
linewidth=2,
linestyle=linestyles[mod1(b, 5)],
)
end
fig
end
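Before defining the model, it is worth sanity-checking the shapes produced by `windowed_dataset`: arrays are laid out as `(features, time, samples)`, and the number of windows follows from `(L - input_window - output_window) ÷ stride + 1`. The expected sizes in the comments below are worked out from the 1600/400-point train/test split above:
julia
size(X_train), size(Y_train)   # ((1, 80, 301), (1, 20, 301))
size(X_test), size(Y_test)     # ((1, 80, 61), (1, 20, 61))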
Define the model
julia
struct RNNEncoder{C} <: AbstractLuxWrapperLayer{:cell}
cell::C
end
function (rnn::RNNEncoder)(x::AbstractArray{T,3}, ps, st) where {T}
(y, carry), st = Lux.apply(rnn.cell, x[:, 1, :], ps, st)
@trace for i in 2:size(x, 2)
(y, carry), st = Lux.apply(rnn.cell, (x[:, i, :], carry), ps, st)
end
return (y, carry), st
end
struct RNNDecoder{C,L} <: AbstractLuxContainerLayer{(:cell, :linear)}
cell::C
linear::L
training_mode::Symbol
teacher_forcing_ratio::Float32
function RNNDecoder(
cell::C,
linear::L;
training_mode::Symbol=:recursive,
teacher_forcing_ratio::Float32=0.5f0,
) where {C,L}
@assert training_mode in (:recursive, :teacher_forcing, :mixed_teacher_forcing)
return new{C,L}(cell, linear, training_mode, Float32(teacher_forcing_ratio))
end
end
function LuxCore.initialstates(rng::AbstractRNG, d::RNNDecoder)
return (;
cell=LuxCore.initialstates(rng, d.cell),
linear=LuxCore.initialstates(rng, d.linear),
training=Val(true),
rng,
)
end
function _teacher_forcing_condition(::Val{false}, x, mode, rng, ratio, target_len)
res = similar(x, Bool, target_len)
fill!(res, false)
return res
end
function _teacher_forcing_condition(::Val{true}, x, mode, rng, ratio, target_len)
mode === :recursive &&
return _teacher_forcing_condition(Val(false), x, mode, rng, ratio, target_len)
mode === :teacher_forcing && return fill(rand(rng, Float32) < ratio, target_len)
return rand(rng, Float32, target_len) .< ratio
end
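# A quick, illustrative check of the three modes (the exact Bool pattern of the mixed case
# depends on the RNG draws): :recursive never uses the target, :teacher_forcing makes a
# single draw shared across the whole window, and :mixed_teacher_forcing draws independently
# for every step. The dummy `x` and the window length of 4 are arbitrary values for this sketch.
let rng = Random.default_rng(), x = zeros(Float32, 1, 1)
    @assert !any(_teacher_forcing_condition(Val(true), x, :recursive, rng, 0.5f0, 4))
    @assert allequal(_teacher_forcing_condition(Val(true), x, :teacher_forcing, rng, 0.5f0, 4))
    _teacher_forcing_condition(Val(true), x, :mixed_teacher_forcing, rng, 0.5f0, 4)
end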
function (rnn::RNNDecoder)((decoder_input, carry, target_len, target), ps, st)
@assert ndims(decoder_input) == 2
rng = Lux.replicate(st.rng)
if target === nothing
### This will be optimized out by Reactant
target = similar(
decoder_input, size(decoder_input, 1), target_len, size(decoder_input, 2)
)
fill!(target, 0)
else
@assert size(target, 2) ≥ target_len
end
(y_latent, carry), st_cell = Lux.apply(
rnn.cell, (decoder_input, carry), ps.cell, st.cell
)
y_pred, st_linear = Lux.apply(rnn.linear, y_latent, ps.linear, st.linear)
y_full = similar(y_pred, size(y_pred, 1), target_len, size(y_pred, 2))
y_full[:, 1, :] = y_pred
conditions = _teacher_forcing_condition(
st.training,
decoder_input,
rnn.training_mode,
rng,
rnn.teacher_forcing_ratio,
target_len,
)
decoder_input = ifelse.(@allowscalar(conditions[1]), target[:, 1, :], y_pred)
@trace for i in 2:target_len
(y_latent, carry), st_cell = Lux.apply(
rnn.cell, (decoder_input, carry), ps.cell, st_cell
)
y_pred, st_linear = Lux.apply(rnn.linear, y_latent, ps.linear, st_linear)
y_full[:, i, :] = y_pred
decoder_input = ifelse.(@allowscalar(conditions[i]), target[:, i, :], y_pred)
end
return y_full, merge(st, (; cell=st_cell, linear=st_linear, rng))
end
struct RNNEncoderDecoder{C<:RNNEncoder,L<:RNNDecoder} <:
AbstractLuxContainerLayer{(:encoder, :decoder)}
encoder::C
decoder::L
end
function (rnn::RNNEncoderDecoder)((x, target_len, target), ps, st)
(y, carry), st_encoder = Lux.apply(rnn.encoder, x, ps.encoder, st.encoder)
pred, st_decoder = Lux.apply(
rnn.decoder, (x[:, end, :], carry, target_len, target), ps.decoder, st.decoder
)
return pred, (; encoder=st_encoder, decoder=st_decoder)
end
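As a quick smoke test before training, the encoder-decoder can be run on dummy data in test mode: no target is supplied, so the decoder feeds its own predictions back in. The hidden size of 8, batch of 16, and window lengths below are arbitrary values chosen only for this sketch:
julia
demo = RNNEncoderDecoder(
    RNNEncoder(LSTMCell(1 => 8)),
    RNNDecoder(LSTMCell(1 => 8), Dense(8 => 1)),
)
demo_ps, demo_st = Lux.setup(Random.default_rng(), demo)
x_demo = randn(Float32, 1, 80, 16)   # (features, input_window, batch)
y_demo, _ = demo((x_demo, 20, nothing), demo_ps, Lux.testmode(demo_st))
size(y_demo)                         # (1, 20, 16): (features, target_len, batch)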
Training
julia
function train(
train_dataset,
validation_dataset;
nepochs=50,
batchsize=32,
hidden_dims=32,
training_mode=:mixed_teacher_forcing,
teacher_forcing_ratio=0.5f0,
learning_rate=1e-3,
)
(X_train, Y_train), (X_test, Y_test) = train_dataset, validation_dataset
in_dims = size(X_train, 1)
@assert size(Y_train, 2) == size(Y_test, 2)
target_len = size(Y_train, 2)
train_dataloader =
DataLoader(
(X_train, Y_train);
batchsize=min(batchsize, size(X_train, ndims(X_train))),
shuffle=true,
partial=false,
) |> xdev
X_test, Y_test = (X_test, Y_test) |> xdev
model = RNNEncoderDecoder(
RNNEncoder(LSTMCell(in_dims => hidden_dims)),
RNNDecoder(
LSTMCell(in_dims => hidden_dims),
Dense(hidden_dims => in_dims);
training_mode,
teacher_forcing_ratio,
),
)
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
train_state = Training.TrainState(model, ps, st, Optimisers.Adam(learning_rate))
stime = time()
model_compiled = @compile model((X_test, target_len, nothing), ps, Lux.testmode(st))
ttime = time() - stime
@printf "Compilation time: %.4f seconds\n\n" ttime
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
(_, _, _, train_state) = Training.single_train_step!(
AutoEnzyme(),
MSELoss(),
((x, target_len, y), y),
train_state;
return_gradients=Val(false),
)
end
ttime = time() - stime
y_pred, _ = model_compiled(
(X_test, target_len, nothing),
train_state.parameters,
Lux.testmode(train_state.states),
)
pred_loss = Float32(@jit(MSELoss()(y_pred, Y_test)))
@printf(
"[%3d/%3d]\tTime per epoch: %3.5fs\tValidation Loss: %.4f\n",
epoch,
nepochs,
ttime,
pred_loss,
)
end
return StatefulLuxLayer{true}(
model, train_state.parameters |> cdev, train_state.states |> cdev
)
end
trained_model = train(
(X_train, Y_train),
(X_test, Y_test);
nepochs=50,
batchsize=4,
hidden_dims=32,
training_mode=:mixed_teacher_forcing,
teacher_forcing_ratio=0.5f0,
learning_rate=3e-4,
)
StatefulLuxLayer{true}(
RNNEncoderDecoder(
encoder = RNNEncoder(
cell = LSTMCell(1 => 32), # 4_480 parameters, plus 1
),
decoder = RNNDecoder(
cell = LSTMCell(1 => 32), # 4_480 parameters, plus 1
linear = Dense(32 => 1), # 33 parameters
),
),
) # Total: 8_993 parameters,
# plus 2 states.
Making Predictions
julia
Y_pred = trained_model((X_test, 20, nothing))
begin
fig = Figure(; size=(1200, 800))
for i in 1:4, j in 1:2
b = i + j * 4
ax = Axis(fig[i, j]; xlabel="t", ylabel="y")
i != 4 && hidexdecorations!(ax; grid=false)
j != 1 && hideydecorations!(ax; grid=false)
lines!(ax, 0:79, X_test[1, :, b]; label="Input", color=:black, linewidth=2)
lines!(
ax,
79:99,
vcat(X_test[1, end, b], Y_test[1, :, b]);
label="Ground Truth\n(Noisy)",
color=:red,
linewidth=2,
)
lines!(
ax,
79:99,
vcat(X_test[1, end, b], Y_pred[1, :, b]);
label="Prediction",
color=:blue,
linewidth=2,
)
i == 4 && j == 2 && axislegend(ax; position=:lb)
end
fig[0, :] = Label(fig, "Predictions from Trained Model"; fontsize=20)
fig
end
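To put a number on the recursive forecasts shown above, one simple option is to compare them against the (noisy) test targets with the same mean-squared-error criterion used during training; `mean` comes from the `Statistics` import at the top. A small sketch:
julia
test_mse = mean(abs2, Y_pred .- Y_test)   # the quantity MSELoss() computes with its default mean aggregation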
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.