Building an LSTM Encoder-Decoder model using Lux.jl

This example is based on LSTM_encoder_decoder by Laura Kulowski.

julia
using Lux, Reactant, Random, Optimisers, Statistics, Enzyme, Printf, CairoMakie, MLUtils

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Generate synthetic data

julia
function synthetic_data(Nt=2000, tf=80 * Float32(π))
    t = range(0.0f0, tf; length=Nt)
    y = sin.(2.0f0 * t) .+ 0.5f0 * cos.(t) .+ randn(Float32, Nt) * 0.2f0
    return t, y
end

function train_test_split(t, y, split=0.8)
    indx_split = ceil(Int, length(t) * split)
    indx_train = 1:indx_split
    indx_test = (indx_split + 1):length(t)

    t_train = t[indx_train]
    y_train = reshape(y[indx_train], 1, :)

    t_test = t[indx_test]
    y_test = reshape(y[indx_test], 1, :)

    return t_train, y_train, t_test, y_test
end

function windowed_dataset(y; input_window=5, output_window=1, stride=1, num_features=1)
    L = size(y, ndims(y))
    num_samples = (L - input_window - output_window) ÷ stride + 1

    X = zeros(Float32, num_features, input_window, num_samples)
    Y = zeros(Float32, num_features, output_window, num_samples)

    for ii in 1:num_samples, ff in 1:num_features
        start_x = stride * (ii - 1) + 1
        end_x = start_x + input_window - 1
        X[ff, :, ii] .= y[ff, start_x:end_x]  # select feature row `ff` explicitly

        start_y = stride * (ii - 1) + input_window + 1
        end_y = start_y + output_window - 1
        Y[ff, :, ii] .= y[ff, start_y:end_y]
    end

    return X, Y
end

t, y = synthetic_data()

begin
    fig = Figure(; size=(1000, 400))
    ax = Axis(fig[1, 1]; title="Synthetic Time Series", xlabel="t", ylabel="y")

    lines!(ax, t, y; label="y", color=:black, linewidth=2)

    fig
end
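Because `synthetic_data` adds Gaussian noise with standard deviation 0.2 to the clean signal `sin(2t) + 0.5cos(t)`, and the noise on future points is independent of the past, no forecaster can be expected to push the validation MSE much below 0.2² = 0.04. The sketch below checks this by regenerating a series with the same formula, using an explicit seeded RNG. The helper names `clean_signal` and `noisy_signal` and the `_demo` variables are illustrative only and are not part of the tutorial; they are named so they do not overwrite the tutorial's `t` and `y`.

```julia
using Random, Statistics

# Noise-free part of the tutorial's synthetic series (illustrative helper)
clean_signal(t) = sin(2.0f0 * t) + 0.5f0 * cos(t)

# Same construction as `synthetic_data`, but with an explicit RNG for reproducibility
function noisy_signal(rng, Nt=2000, tf=80 * Float32(π))
    t = range(0.0f0, tf; length=Nt)
    clean = clean_signal.(t)
    y = clean .+ randn(rng, Float32, Nt) .* 0.2f0
    return t, clean, y
end

t_demo, clean_demo, y_demo = noisy_signal(Xoshiro(0))

# MSE of a "perfect" predictor that outputs the clean signal exactly:
# this approximates the noise floor of the validation loss (about 0.2^2 = 0.04)
noise_floor = mean(abs2, y_demo .- clean_demo)
println("Estimated MSE floor: ", noise_floor)
```

The printed value should be close to 0.04, which is a useful reference point when reading the validation losses reported during training.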

julia
t_train, y_train, t_test, y_test = train_test_split(t, y)

begin
    fig = Figure(; size=(1000, 400))
    ax = Axis(
        fig[1, 1];
        title="Time Series Split into Train and Test Sets",
        xlabel="t",
        ylabel="y",
    )

    lines!(ax, t_train, y_train[1, :]; label="Train", color=:black, linewidth=2)
    lines!(ax, t_test, y_test[1, :]; label="Test", color=:red, linewidth=2)

    fig[1, 2] = Legend(fig, ax)

    fig
end
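`train_test_split` cuts the series at a single index instead of sampling points at random, so every test point lies after every training point and the model is scored on values it has not seen. The hypothetical helper `split_indices` below repeats only the index arithmetic of that function, showing that the defaults `Nt = 2000` and `split = 0.8` leave 1600 points for training and 400 for testing.

```julia
# Mirrors the index logic of `train_test_split` without copying any data
# (hypothetical helper, for illustration only)
function split_indices(n::Integer, split::Real=0.8)
    indx_split = ceil(Int, n * split)
    return 1:indx_split, (indx_split + 1):n
end

train_idx, test_idx = split_indices(2000)
println("train: ", length(train_idx), " points, test: ", length(test_idx), " points")

# The split is chronological: the last training index directly precedes the first test index
@assert last(train_idx) + 1 == first(test_idx)
```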

julia
X_train, Y_train = windowed_dataset(y_train; input_window=80, output_window=20, stride=5)
X_test, Y_test = windowed_dataset(y_test; input_window=80, output_window=20, stride=5)

begin
    fig = Figure(; size=(1000, 400))
    ax = Axis(fig[1, 1]; title="Example of Windowed Training Data", xlabel="t", ylabel="y")

    linestyles = [:solid, :dash, :dot, :dashdot, :dashdotdot]

    for b in 1:4:16
        lines!(
            ax,
            0:79,
            X_train[1, :, b];
            label="Input",
            color=:black,
            linewidth=2,
            linestyle=linestyles[mod1(b, 5)],
        )
        lines!(
            ax,
            79:99,
            vcat(X_train[1, end, b], Y_train[1, :, b]);
            label="Target",
            color=:red,
            linewidth=2,
            linestyle=linestyles[mod1(b, 5)],
        )
    end

    fig
end
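`windowed_dataset` slides a window of `input_window + output_window` consecutive points along the series in steps of `stride`; the first `input_window` points of each window become an input sample and the remaining `output_window` points its target. The number of windows is `(L - input_window - output_window) ÷ stride + 1`. The toy sketch below applies the same indexing to the vector `1:10` (the helper `window_pairs` is illustrative and works on a plain vector rather than a feature matrix), then evaluates the count formula for the 1600-point training and 400-point test splits used above.

```julia
# Same index arithmetic as `windowed_dataset`, written for a plain vector
function window_pairs(y::AbstractVector; input_window, output_window, stride=1)
    num_samples = (length(y) - input_window - output_window) ÷ stride + 1
    windows = Tuple{Vector{eltype(y)},Vector{eltype(y)}}[]
    for ii in 1:num_samples
        start_x = stride * (ii - 1) + 1
        start_y = start_x + input_window  # the target starts right after the input
        x = y[start_x:(start_x + input_window - 1)]
        tgt = y[start_y:(start_y + output_window - 1)]
        push!(windows, (x, tgt))
    end
    return windows
end

demo_windows = window_pairs(collect(1:10); input_window=3, output_window=2, stride=2)
for (x, tgt) in demo_windows
    println("input: ", x, "  target: ", tgt)
end
# input: [1, 2, 3]  target: [4, 5]
# input: [3, 4, 5]  target: [6, 7]
# input: [5, 6, 7]  target: [8, 9]

# Window counts for the settings used in this tutorial (80 in, 20 out, stride 5)
count_windows(L) = (L - 80 - 20) ÷ 5 + 1
println(count_windows(1600), " training windows, ", count_windows(400), " test windows")
```

Consecutive windows overlap whenever `stride` is smaller than the window length, which is why 1600 points yield 301 training samples rather than 16.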

Define the model

julia
struct RNNEncoder{C} <: AbstractLuxWrapperLayer{:cell}
    cell::C
end

function (rnn::RNNEncoder)(x::AbstractArray{T,3}, ps, st) where {T}
    (y, carry), st = Lux.apply(rnn.cell, x[:, 1, :], ps, st)
    @trace for i in 2:size(x, 2)
        (y, carry), st = Lux.apply(rnn.cell, (x[:, i, :], carry), ps, st)
    end
    return (y, carry), st
end

struct RNNDecoder{C,L} <: AbstractLuxContainerLayer{(:cell, :linear)}
    cell::C
    linear::L
    training_mode::Symbol
    teacher_forcing_ratio::Float32

    function RNNDecoder(
        cell::C,
        linear::L;
        training_mode::Symbol=:recursive,
        teacher_forcing_ratio::Float32=0.5f0,
    ) where {C,L}
        @assert training_mode in (:recursive, :teacher_forcing, :mixed_teacher_forcing)
        return new{C,L}(cell, linear, training_mode, Float32(teacher_forcing_ratio))
    end
end

function LuxCore.initialstates(rng::AbstractRNG, d::RNNDecoder)
    return (;
        cell=LuxCore.initialstates(rng, d.cell),
        linear=LuxCore.initialstates(rng, d.linear),
        training=Val(true),
        rng,
    )
end

function _teacher_forcing_condition(::Val{false}, x, mode, rng, ratio, target_len)
    # Test mode: never feed the ground truth, always feed back the model's own predictions
    res = similar(x, Bool, target_len)
    fill!(res, false)
    return res
end
function _teacher_forcing_condition(::Val{true}, x, mode, rng, ratio, target_len)
    mode === :recursive &&
        return _teacher_forcing_condition(Val(false), x, mode, rng, ratio, target_len)
    # A single coin flip decides teacher forcing for the whole sequence
    mode === :teacher_forcing && return fill(rand(rng, Float32) < ratio, target_len)
    return rand(rng, Float32, target_len) .< ratio
end

function (rnn::RNNDecoder)((decoder_input, carry, target_len, target), ps, st)
    @assert ndims(decoder_input) == 2
    rng = Lux.replicate(st.rng)

    if target === nothing
        ### This will be optimized out by Reactant
        target = similar(
            decoder_input, size(decoder_input, 1), target_len, size(decoder_input, 2)
        )
        fill!(target, 0)
    else
        @assert size(target, 2) ≥ target_len
    end

    (y_latent, carry), st_cell = Lux.apply(
        rnn.cell, (decoder_input, carry), ps.cell, st.cell
    )
    y_pred, st_linear = Lux.apply(rnn.linear, y_latent, ps.linear, st.linear)

    y_full = similar(y_pred, size(y_pred, 1), target_len, size(y_pred, 2))
    y_full[:, 1, :] = y_pred

    conditions = _teacher_forcing_condition(
        st.training,
        decoder_input,
        rnn.training_mode,
        rng,
        rnn.teacher_forcing_ratio,
        target_len,
    )
    decoder_input = ifelse.(@allowscalar(conditions[1]), target[:, 1, :], y_pred)

    @trace for i in 2:target_len
        (y_latent, carry), st_cell = Lux.apply(
            rnn.cell, (decoder_input, carry), ps.cell, st_cell
        )

        y_pred, st_linear = Lux.apply(rnn.linear, y_latent, ps.linear, st_linear)
        y_full[:, i, :] = y_pred

        decoder_input = ifelse.(@allowscalar(conditions[i]), target[:, i, :], y_pred)
    end

    return y_full, merge(st, (; cell=st_cell, linear=st_linear, rng))
end

struct RNNEncoderDecoder{C<:RNNEncoder,L<:RNNDecoder} <:
       AbstractLuxContainerLayer{(:encoder, :decoder)}
    encoder::C
    decoder::L
end

function (rnn::RNNEncoderDecoder)((x, target_len, target), ps, st)
    (y, carry), st_encoder = Lux.apply(rnn.encoder, x, ps.encoder, st.encoder)
    pred, st_decoder = Lux.apply(
        rnn.decoder, (x[:, end, :], carry, target_len, target), ps.decoder, st.decoder
    )
    return pred, (; encoder=st_encoder, decoder=st_decoder)
end
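The decoder produces its forecast autoregressively: after predicting step `i`, the input for the next step is either the ground truth `target[:, i, :]` (teacher forcing) or the model's own prediction. The plain-Julia sketch below reproduces only this control flow, under stated assumptions: `step` is a hypothetical stand-in for the LSTM cell plus `Dense` head, and `forcing_mask` is one reading of the three `training_mode` options (no forcing in test mode or in `:recursive` mode, one coin flip for the whole sequence in `:teacher_forcing` mode, one flip per step in `:mixed_teacher_forcing` mode). It is not the Lux or Reactant implementation.

```julia
using Random

# Bool mask: `true` at step i means "next input is the ground truth",
# `false` means "next input is the prediction just made". Assumed semantics.
function forcing_mask(training::Bool, mode::Symbol, rng, ratio, target_len)
    (!training || mode === :recursive) && return fill(false, target_len)
    mode === :teacher_forcing && return fill(rand(rng, Float32) < ratio, target_len)
    return rand(rng, Float32, target_len) .< ratio  # :mixed_teacher_forcing
end

# Autoregressive decoding loop; `step(input, carry)` is a hypothetical stand-in
# for the recurrent cell + output layer and returns (prediction, new_carry).
function decode(step, first_input, carry, target, mask)
    target_len = length(mask)
    preds = Vector{Float64}(undef, target_len)
    input = first_input
    for i in 1:target_len
        preds[i], carry = step(input, carry)
        # Choose the next input: ground truth (teacher forcing) or own prediction
        input = mask[i] ? target[i] : preds[i]
    end
    return preds
end

toy_step(x, c) = (x + c, c)  # hypothetical "cell": adds a constant carry to its input
toy_target = [10.0, 20.0, 30.0]

recursive = forcing_mask(false, :mixed_teacher_forcing, Xoshiro(0), 0.5f0, 3)
println(decode(toy_step, 0.0, 1.0, toy_target, recursive))      # [1.0, 2.0, 3.0]
println(decode(toy_step, 0.0, 1.0, toy_target, fill(true, 3)))  # [1.0, 11.0, 21.0]
```

With full teacher forcing, each step starts from the true previous value, so early errors cannot compound; in recursive mode every prediction feeds the next, which matches how the model is used at inference time when no target is available.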

Training

julia
function train(
    train_dataset,
    validation_dataset;
    nepochs=50,
    batchsize=32,
    hidden_dims=32,
    training_mode=:mixed_teacher_forcing,
    teacher_forcing_ratio=0.5f0,
    learning_rate=1e-3,
)
    (X_train, Y_train), (X_test, Y_test) = train_dataset, validation_dataset
    in_dims = size(X_train, 1)
    @assert size(Y_train, 2) == size(Y_test, 2)
    target_len = size(Y_train, 2)

    train_dataloader =
        DataLoader(
            (X_train, Y_train);
            batchsize=min(batchsize, size(X_train, 3)),  # samples are along dimension 3
            shuffle=true,
            partial=false,
        ) |> xdev
    X_test, Y_test = (X_test, Y_test) |> xdev

    model = RNNEncoderDecoder(
        RNNEncoder(LSTMCell(in_dims => hidden_dims)),
        RNNDecoder(
            LSTMCell(in_dims => hidden_dims),
            Dense(hidden_dims => in_dims);
            training_mode,
            teacher_forcing_ratio,
        ),
    )
    ps, st = Lux.setup(Random.default_rng(), model) |> xdev

    train_state = Training.TrainState(model, ps, st, Optimisers.Adam(learning_rate))

    stime = time()
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((X_test, target_len, nothing), ps, Lux.testmode(st))
    end
    ttime = time() - stime
    @printf "Compilation time: %.4f seconds\n\n" ttime

    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                MSELoss(),
                ((x, target_len, y), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        y_pred, _ = model_compiled(
            (X_test, target_len, nothing),
            train_state.parameters,
            Lux.testmode(train_state.states),
        )
        pred_loss = Float32(@jit(MSELoss()(y_pred, Y_test)))

        @printf(
            "[%3d/%3d]\tTime per epoch: %3.5fs\tValidation Loss: %.4f\n",
            epoch,
            nepochs,
            ttime,
            pred_loss,
        )
    end

    # Switch to test mode so that inference never uses teacher forcing
    return StatefulLuxLayer(
        model, train_state.parameters |> cdev, Lux.testmode(train_state.states) |> cdev
    )
end

trained_model = train(
    (X_train, Y_train),
    (X_test, Y_test);
    nepochs=50,
    batchsize=4,
    hidden_dims=32,
    training_mode=:mixed_teacher_forcing,
    teacher_forcing_ratio=0.5f0,
    learning_rate=3e-4,
)
StatefulLuxLayer{Val{true}()}(
    RNNEncoderDecoder(
        encoder = RNNEncoder(
            cell = LSTMCell(1 => 32),             # 4_480 parameters, plus 1 non-trainable
        ),
        decoder = RNNDecoder(
            cell = LSTMCell(1 => 32),             # 4_480 parameters, plus 1 non-trainable
            linear = Dense(32 => 1),              # 33 parameters
        ),
    ),
)         # Total: 8_993 parameters,
          #        plus 2 states.
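Each epoch calls `Training.single_train_step!` once per mini-batch, and both the training objective and the reported validation loss are the mean squared error between predicted and target windows. With a batch size of 4 and `partial=false`, which drops the incomplete last batch, the 301 training windows give `301 ÷ 4 = 75` full mini-batches per epoch. The plain-Julia sketch below spells out both quantities; the helpers `mse` and `num_batches` are hypothetical, and Lux's `MSELoss()` computes the same mean of squared differences with its default mean aggregation.

```julia
using Statistics

# Mean squared error over all elements (features × time steps × batch)
mse(ŷ, y) = mean(abs2, ŷ .- y)

# Number of full mini-batches when the incomplete last batch is dropped
num_batches(num_samples, batchsize) = num_samples ÷ batchsize

ŷ_demo = zeros(Float32, 1, 20, 4)  # an all-zero forecast for 4 windows of 20 steps
y_demo = fill(0.5f0, 1, 20, 4)     # constant targets
println("MSE of an all-zero forecast: ", mse(ŷ_demo, y_demo))  # 0.25
println("batches per epoch: ", num_batches(301, 4))            # 75
```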

Making Predictions

julia
Y_pred = trained_model((X_test, 20, nothing))

begin
    fig = Figure(; size=(1200, 800))

    for i in 1:4, j in 1:2
        b = i + j * 4
        ax = Axis(fig[i, j]; xlabel="t", ylabel="y")
        i != 4 && hidexdecorations!(ax; grid=false)
        j != 1 && hideydecorations!(ax; grid=false)

        lines!(ax, 0:79, X_test[1, :, b]; label="Input", color=:black, linewidth=2)
        lines!(
            ax,
            79:99,
            vcat(X_test[1, end, b], Y_test[1, :, b]);
            label="Ground Truth\n(Noisy)",
            color=:red,
            linewidth=2,
        )
        lines!(
            ax,
            79:99,
            vcat(X_test[1, end, b], Y_pred[1, :, b]);
            label="Prediction",
            color=:blue,
            linewidth=2,
        )

        i == 4 && j == 2 && axislegend(ax; position=:lb)
    end

    fig[0, :] = Label(fig, "Predictions from Trained Model"; fontsize=20)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.