MNIST Classification using Neural ODEs
To understand Neural ODEs, users should look up these lecture notes. We recommend that users directly use DiffEqFlux.jl instead of implementing Neural ODEs from scratch.
Package Imports
using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5,
Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST: on CI, only the first 1500 images are used, for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end
    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false),
    )
end
loadmnist (generic function with 1 method)
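loadmnist relies on onehotbatch (OneHotArrays) and splitobs (MLUtils). The stdlib-only sketch below makes the resulting array shapes concrete on fake data. It assumes a few things that are not in the tutorial: onehot_sketch and split_last_dim are hypothetical stand-ins, and the floor-based split is our own rounding choice, which may differ from what splitobs does.

```julia
# Hypothetical stand-in for OneHotArrays.onehotbatch: one column per sample,
# with a single `true` in the row of that sample's class.
function onehot_sketch(labels::AbstractVector, classes)
    Y = falses(length(classes), length(labels))
    for (j, l) in enumerate(labels)
        Y[findfirst(==(l), classes), j] = true
    end
    return Y
end

# Hypothetical stand-in for MLUtils.splitobs(...; at): split along the last
# (observation) dimension. The floor-based rounding is our own choice.
function split_last_dim(X::AbstractArray, at::Real)
    n = size(X, ndims(X))
    ntrain = floor(Int, at * n)
    front = ntuple(_ -> Colon(), ndims(X) - 1)
    return X[front..., 1:ntrain], X[front..., (ntrain + 1):n]
end

imgs = rand(Float32, 28, 28, 10)   # fake (H, W, N) image stack
labels = collect(0:9)              # fake digit labels
x_data = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))  # (H, W, C, BS)
y_data = onehot_sketch(labels, 0:9)             # 10 × 10 BitMatrix
x_train, x_test = split_last_dim(x_data, 0.9)   # 9 and 1 observations
```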
Define the Neural ODE Layer
First, we will use the @compact macro to define the Neural ODE layer.
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
We recommend using the compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API. It also helps users understand how the Lux layer interface works.
NeuralODE is a wrapper layer (a subtype of Lux.AbstractLuxWrapperLayer{:model}) that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end
function NeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return NeuralODE(model, solver, tspan, kwargs)
end
OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities, such as ReverseDiffAdjoint, can't handle non-Vector inputs. Hence, we convert the input and output of the ODE solver to a Vector.
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        # Reassigning the captured `st` inside this closure causes it to be boxed
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end
@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
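The layer's forward pass hands dudt and the flattened vec(x) to solve. diffeqsol_to_array then reshapes the final state back to (20, batch). The stdlib-only sketch below mimics that pipeline under some assumptions of our own. The vector field is a hypothetical one-layer network, tanh.(W*u .+ b), and the integrator is a fixed-step classical RK4 (rk4_final_state, our own stand-in). The tutorial itself uses the adaptive Tsit5 method with reltol = abstol = 1.0f-3, so numbers would differ.

```julia
using Random

# Hypothetical fixed-step classical RK4 integrator, used only to illustrate what
# `solve` does with `dudt`; the tutorial itself uses the adaptive Tsit5 method.
function rk4_final_state(f, u0::AbstractVector, p, tspan; nsteps::Int = 20)
    t0, t1 = tspan
    h = (t1 - t0) / nsteps
    u, t = copy(u0), t0
    for _ in 1:nsteps
        k1 = f(u, p, t)
        k2 = f(u .+ (h / 2) .* k1, p, t + h / 2)
        k3 = f(u .+ (h / 2) .* k2, p, t + h / 2)
        k4 = f(u .+ h .* k3, p, t + h)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += h
    end
    return u  # only the final state, in the spirit of save_everystep = false
end

rng = Xoshiro(0)
x = randn(rng, Float32, 20, 3)  # (features, batch) as it reaches the ODE layer
p = (W = 0.1f0 .* randn(rng, Float32, 20, 20), b = zeros(Float32, 20))

# Mirrors the tutorial's dudt: the solver sees a flat vector, while the
# (hypothetical) one-layer network sees the original (20, 3) matrix.
dudt(u, p, t) = vec(tanh.(p.W * reshape(u, size(x)) .+ p.b))

u1 = rk4_final_state(dudt, vec(x), p, (0.0f0, 1.0f0))
y = reshape(u1, (20, :))  # like diffeqsol_to_array(20, sol) on the last state
```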
Create and Initialize the Neural ODE Layer
function create_model(
        model_fn = NeuralODE; dev = gpu_device(), use_named_tuple::Bool = false,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    )
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep = false, reltol = 1.0f-3,
            abstol = 1.0f-3, save_start = false, sensealg
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10)
    )
    rng = Random.default_rng()
    Random.seed!(rng, 0)
    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev
    return model, ps, st
end
create_model (generic function with 2 methods)
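As a sanity check of the architecture built by create_model, its parameter count can be worked out by hand. A Dense(in => out) layer with a bias holds in*out weights plus out biases. FlattenLayer and the wrapped diffeqsol_to_array function hold no parameters. The helper dense_params is hypothetical. The total agrees with the 16450-entry ComponentVector layout shown in the @code_warntype output of the Type Stability section.

```julia
# Parameters of a Dense(in => out) layer with bias: an out × in weight plus out biases.
dense_params(in, out) = in * out + out

encoder = dense_params(784, 20)                                               # 15700
ode_net = dense_params(20, 10) + dense_params(10, 10) + dense_params(10, 20)  # 540
head    = dense_params(20, 10)                                                # 210
total   = encoder + ode_net + head                                            # 16450
```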
Define Utility Functions
const logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
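CrossEntropyLoss(; logits = Val(true)) treats the model outputs as raw logits and applies a log-softmax before taking the cross-entropy. Below is a stdlib-only sketch of that computation. It makes two assumptions: the helper names are hypothetical, and it uses mean aggregation over the batch, which we believe is Lux's default but have not confirmed here.

```julia
# Hypothetical stand-ins; Lux's CrossEntropyLoss is the real implementation.
# Numerically stable log-softmax over each column (one column per sample).
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims = 1)
    return z .- m .- log.(sum(exp.(z .- m); dims = 1))
end

# Cross-entropy on logits, averaged over the batch (mean aggregation assumed).
logit_xent(ŷ, y) = -sum(y .* logsoftmax_cols(ŷ)) / size(y, 2)

ŷ = [2.0 0.0; 1.0 0.0; 0.0 0.0]  # logits: 3 classes × 2 samples
y = [1 0; 0 1; 0 0]              # one-hot targets (class 1, then class 2)
loss = logit_xent(ŷ, y)
```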
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
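accuracy compares the onecold class of each target column with the onecold class of each prediction column. The stdlib-only sketch below shows the same computation on a tiny hand-made batch, with one wrong prediction. onecold_sketch is a hypothetical stand-in that returns row indices rather than labels, which is sufficient here because only equality matters.

```julia
# Hypothetical stand-in for OneHotArrays.onecold: the row index of the
# largest entry in each column.
onecold_sketch(A::AbstractMatrix) = [argmax(view(A, :, j)) for j in axes(A, 2)]

logits  = [0.1 2.0 0.3; 1.5 0.2 0.1; 0.0 0.1 0.9]  # 3 classes × 3 samples
targets = Bool[0 1 1; 1 0 0; 0 0 0]                # true classes: 2, 1, 1
pred = onecold_sketch(logits)                      # predicted classes: 2, 1, 3
acc = sum(pred .== onecold_sketch(targets)) / size(targets, 2)  # 2 of 3 correct
```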
function train(model_function; cpu::Bool = false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)
    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))
    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime
        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
end
train(NeuralODECompact)
[1/9] Time 146.9311s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.8216s Training Accuracy: 58.22222% Test Accuracy: 57.33333%
[3/9] Time 1.0646s Training Accuracy: 67.85185% Test Accuracy: 70.66667%
[4/9] Time 0.8538s Training Accuracy: 74.29630% Test Accuracy: 74.66667%
[5/9] Time 1.0689s Training Accuracy: 76.29630% Test Accuracy: 76.00000%
[6/9] Time 0.8333s Training Accuracy: 78.74074% Test Accuracy: 80.00000%
[7/9] Time 1.0549s Training Accuracy: 82.22222% Test Accuracy: 81.33333%
[8/9] Time 0.8210s Training Accuracy: 83.62963% Test Accuracy: 83.33333%
[9/9] Time 0.8249s Training Accuracy: 85.18519% Test Accuracy: 82.66667%
train(NeuralODE)
[1/9] Time 35.9996s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6250s Training Accuracy: 57.18519% Test Accuracy: 57.33333%
[3/9] Time 0.6530s Training Accuracy: 68.37037% Test Accuracy: 68.00000%
[4/9] Time 0.6341s Training Accuracy: 73.77778% Test Accuracy: 75.33333%
[5/9] Time 0.8622s Training Accuracy: 76.14815% Test Accuracy: 77.33333%
[6/9] Time 0.6281s Training Accuracy: 79.48148% Test Accuracy: 80.66667%
[7/9] Time 0.9201s Training Accuracy: 81.25926% Test Accuracy: 80.66667%
[8/9] Time 0.6353s Training Accuracy: 83.40741% Test Accuracy: 82.66667%
[9/9] Time 0.6291s Training Accuracy: 84.81481% Test Accuracy: 82.00000%
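Each iteration of the training loop calls Training.single_train_step!. That call computes the gradient of the loss (here with AutoZygote) and then applies an Adam(0.001f0) update from Optimisers.jl. The stdlib-only sketch below shows the standard Adam update rule. AdamSketch and adam_step! are hypothetical names. β1 = 0.9, β2 = 0.999, and ϵ = 1e-8 are the usual Adam defaults, which we assume, but have not verified, match those of Optimisers.jl.

```julia
# Hypothetical minimal Adam optimiser (illustration, not the Optimisers.jl code).
struct AdamSketch
    η::Float64   # learning rate
    β1::Float64  # decay rate of the first-moment estimate
    β2::Float64  # decay rate of the second-moment estimate
    ϵ::Float64   # small constant to avoid division by zero
end
AdamSketch(η) = AdamSketch(η, 0.9, 0.999, 1e-8)

# One in-place Adam update of θ given gradient g at iteration t (1-based).
function adam_step!(θ, g, m, v, t, o::AdamSketch)
    β1, β2 = o.β1, o.β2
    m .= β1 .* m .+ (1 - β1) .* g
    v .= β2 .* v .+ (1 - β2) .* g .^ 2
    m̂ = m ./ (1 - β1^t)  # bias-corrected moments
    v̂ = v ./ (1 - β2^t)
    θ .-= o.η .* m̂ ./ (sqrt.(v̂) .+ o.ϵ)
    return θ
end

grad(θ) = 2 .* θ  # gradient of the toy loss sum(θ .^ 2)
θ = [1.0, -2.0]
m, v = zero(θ), zero(θ)
opt = AdamSketch(0.001)
adam_step!(θ, grad(θ), m, v, 1, opt)
θ_first = copy(θ)  # first step moves each entry by ≈ η against its gradient's sign
for t in 2:20_000
    adam_step!(θ, grad(θ), m, v, t, opt)
end
```

Because of the bias correction, the first step has magnitude of about η regardless of the gradient's scale. This is one reason a single learning rate such as 0.001 works across layers of different sizes.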
We can also change the sensealg and train the model! GaussAdjoint allows you to use any arbitrary parameter structure, not just a flat vector (ComponentArray).
train(NeuralODE; sensealg = GaussAdjoint(; autojacvec = ZygoteVJP()), use_named_tuple = true)
[1/9] Time 41.6490s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6059s Training Accuracy: 58.44444% Test Accuracy: 58.00000%
[3/9] Time 0.7753s Training Accuracy: 66.96296% Test Accuracy: 68.00000%
[4/9] Time 0.6138s Training Accuracy: 72.44444% Test Accuracy: 73.33333%
[5/9] Time 0.6501s Training Accuracy: 76.37037% Test Accuracy: 76.00000%
[6/9] Time 0.6149s Training Accuracy: 78.81481% Test Accuracy: 79.33333%
[7/9] Time 0.6000s Training Accuracy: 80.51852% Test Accuracy: 81.33333%
[8/9] Time 0.8558s Training Accuracy: 82.74074% Test Accuracy: 83.33333%
[9/9] Time 0.6038s Training Accuracy: 85.25926% Test Accuracy: 82.66667%
But remember that some AD backends, such as ReverseDiff, are not GPU compatible. For a model of this size, you will notice that training is significantly faster on the CPU than on the GPU.
train(NeuralODE; sensealg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP()), cpu = true)
[1/9] Time 39.7240s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.4356s Training Accuracy: 58.74074% Test Accuracy: 56.66667%
[3/9] Time 0.3644s Training Accuracy: 69.92593% Test Accuracy: 71.33333%
[4/9] Time 0.3735s Training Accuracy: 72.81481% Test Accuracy: 74.00000%
[5/9] Time 0.3752s Training Accuracy: 76.37037% Test Accuracy: 78.66667%
[6/9] Time 0.3699s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 0.3610s Training Accuracy: 81.62963% Test Accuracy: 80.66667%
[8/9] Time 0.3638s Training Accuracy: 83.33333% Test Accuracy: 80.00000%
[9/9] Time 0.3709s Training Accuracy: 85.40741% Test Accuracy: 82.00000%
For completeness, let's also test out discrete sensitivities!
train(NeuralODE; sensealg = ReverseDiffAdjoint(), cpu = true)
[1/9] Time 38.0484s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 11.8893s Training Accuracy: 58.66667% Test Accuracy: 57.33333%
[3/9] Time 11.2469s Training Accuracy: 69.70370% Test Accuracy: 71.33333%
[4/9] Time 11.0207s Training Accuracy: 72.74074% Test Accuracy: 74.00000%
[5/9] Time 10.7469s Training Accuracy: 76.14815% Test Accuracy: 78.66667%
[6/9] Time 10.7892s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 10.9633s Training Accuracy: 81.55556% Test Accuracy: 80.66667%
[8/9] Time 10.5584s Training Accuracy: 83.40741% Test Accuracy: 80.00000%
[9/9] Time 10.9576s Training Accuracy: 85.25926% Test Accuracy: 81.33333%
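The last two CPU runs differ in how gradients are obtained. InterpolatingAdjoint solves a continuous adjoint ODE backward in time. ReverseDiffAdjoint differentiates through the solver's individual steps (a discrete sensitivity), which is markedly slower here. The stdlib-only sketch below illustrates the discrete idea on the scalar ODE u' = p*u. It uses our own assumptions: fixed-step explicit Euler, 100 steps, and hypothetical function names. None of these reflect what SciMLSensitivity does internally. The hand-written backward pass is checked against a central finite difference.

```julia
# Forward pass: explicit Euler for u' = p*u on t in [0, 1]; keeps every state.
function euler_states(p, u0; nsteps = 100)
    h = 1.0 / nsteps
    us = zeros(nsteps + 1)
    us[1] = u0
    for k in 1:nsteps
        us[k + 1] = us[k] + h * p * us[k]
    end
    return us, h
end

# Discrete adjoint of the loss L = final state, obtained by walking the
# Euler steps in reverse and applying the chain rule to each one.
function dLdp_discrete(p, u0; nsteps = 100)
    us, h = euler_states(p, u0; nsteps)
    λ = 1.0  # ∂L/∂u at the final state
    g = 0.0  # accumulated ∂L/∂p
    for k in nsteps:-1:1
        g += λ * h * us[k]  # direct dependence of step k on p
        λ *= 1 + h * p      # propagate ∂L/∂u back through step k
    end
    return g
end

p, u0 = 0.5, 2.0
g_adj = dLdp_discrete(p, u0)
ε = 1e-6
g_fd = (euler_states(p + ε, u0)[1][end] - euler_states(p - ε, u0)[1][end]) / (2ε)
```

Because the loss is differentiated exactly as computed, the discrete adjoint matches the derivative of the discretized solution, u0 * N * h * (1 + h*p)^(N - 1), up to rounding.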
Alternate Implementation using Stateful Layer
Starting with v0.5.5, Lux provides a StatefulLuxLayer, which can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.
struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end
function StatefulNeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end
function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end
Train the new Stateful Neural ODE
train(StatefulNeuralODE)
[1/9] Time 36.7343s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6411s Training Accuracy: 58.22222% Test Accuracy: 55.33333%
[3/9] Time 0.6665s Training Accuracy: 68.29630% Test Accuracy: 68.66667%
[4/9] Time 0.9576s Training Accuracy: 73.11111% Test Accuracy: 76.00000%
[5/9] Time 0.6310s Training Accuracy: 75.92593% Test Accuracy: 76.66667%
[6/9] Time 0.6411s Training Accuracy: 78.96296% Test Accuracy: 80.66667%
[7/9] Time 0.6729s Training Accuracy: 80.81481% Test Accuracy: 81.33333%
[8/9] Time 1.0656s Training Accuracy: 83.25926% Test Accuracy: 82.66667%
[9/9] Time 0.6349s Training Accuracy: 84.59259% Test Accuracy: 82.00000%
We might not see a significant difference in training time, but let us investigate the type stability of the layers.
Type Stability
model, ps, st = create_model(NeuralODE)
model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)
x = gpu_device()(ones(Float32, 28, 28, 1, 3));
NeuralODE is not type stable due to the boxing of st, which the inner dudt closure reassigns.
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, 
Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/src/layers/containers.jl:513
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
└── return %3
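The box in question is Julia's Core.Box. Inside NeuralODE, dudt reassigns st, a variable captured from the enclosing function, via u_, st = n.model(...). When a closure does this, the compiler stores the variable in an untyped Core.Box, and code that reads it can no longer be inferred concretely. The stdlib-only illustration below uses hypothetical function names. It shows that keeping the mutable state inside a typed container such as a Ref avoids the box. That is similar in spirit to StatefulLuxLayer holding the state as a field rather than rebinding a captured variable.

```julia
# Reassigning a captured variable inside a closure forces it into a Core.Box.
function make_boxed_counter()
    st = 0
    step = x -> (st = st + x; st)  # `st` is reassigned inside the closure
    return step
end

# Keeping the mutable state in a typed container avoids the box.
function make_ref_counter()
    st = Ref(0)
    step = x -> (st[] += x; st[])  # `st` itself is never reassigned
    return step
end

boxed = make_boxed_counter()
typed = make_ref_counter()
fieldtype(typeof(boxed), :st)  # Core.Box
fieldtype(typeof(typed), :st)  # Base.RefValue{Int64}
```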
We avoid the problem entirely by using StatefulNeuralODE.
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, 
Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/src/layers/containers.jl:513
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
Note that we still recommend using this layer internally rather than exposing it as the default API to users.
Finally, let's check the compact model.
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 
ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/src/layers/containers.jl:513
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
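The `Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, ...}` line above shows that inference found a single concrete return type for the whole chain, including the ODE solve. The forward pass is therefore type-stable; otherwise `@code_warntype` would flag non-concrete types such as `Any` or a `Union`. Below is a minimal, self-contained sketch of turning that visual check into an assertion. It uses only Base and the `Test` standard library, and the two toy functions `stable` and `unstable` are hypothetical, invented for illustration.

```julia
using Test  # provides @inferred

# Type-stable: both branches return a Float32 for a Float32 input.
stable(x::Float32) = x > 0 ? x : zero(x)

# Type-unstable: one branch returns a Float32, the other an Int literal.
unstable(x::Float32) = x > 0 ? x : 0

# Base.return_types asks the compiler what it infers for the given argument
# types; it returns a vector with one entry per matching method.
rt_stable = Base.return_types(stable, (Float32,))[1]
rt_unstable = Base.return_types(unstable, (Float32,))[1]
println("stable   -> ", rt_stable)
println("unstable -> ", rt_unstable)

# @inferred returns the call's value when its runtime type equals the inferred
# return type, and throws an error otherwise.
y = @inferred stable(-1.0f0)

unstable_detected = try
    @inferred unstable(-1.0f0)
    false
catch
    true
end
println("unstable call rejected by @inferred: ", unstable_detected)
```

To apply the same check to a Lux model, you could wrap the forward call, for example `@inferred model(x, ps, st)`. Here `model`, `x`, `ps`, and `st` are placeholder names for the tutorial's model, input batch, parameters, and state.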
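using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    # Print toolkit details only for GPU backends that are both loaded and functional
    @isdefined(CUDA) && MLDataDevices.functional(CUDADevice) && (println(); CUDA.versioninfo())
    @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice) && (println(); AMDGPU.versioninfo())
end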
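The appendix prints the environment report to standard output. To store the same information next to training results instead, one option is to capture it as a `String`. This is a minimal sketch under that assumption and is not part of the original tutorial. The names `report`, `has_cuda`, and `header` are illustrative.

```julia
using InteractiveUtils

# sprint(f) calls f(io) with an in-memory IO and returns what was written,
# so this captures the same text that versioninfo() prints to stdout.
report = sprint(InteractiveUtils.versioninfo)

# isdefined(Main, :CUDA) is the function counterpart of the top-level
# @isdefined(CUDA) check: it is true only if a global named CUDA exists in Main
# (as it does in this tutorial once LuxCUDA is loaded).
has_cuda = isdefined(Main, :CUDA)

# Keep only the first line (the Julia version) for a compact log entry.
header = first(split(report, '\n'))
println(header)
println("CUDA loaded: ", has_cuda)
```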
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_DEBUG = Literate
CUDA runtime 12.8, artifact installation
CUDA driver 12.8
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.8.3
- CURAND: 10.3.9
- CUFFT: 11.3.3
- CUSOLVER: 11.7.2
- CUSPARSE: 12.5.7
- CUPTI: 2025.1.0 (API 26.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.7.0
- CUDA_Driver_jll: 0.12.0+0
- CUDA_Runtime_jll: 0.16.0+0
- Julia: 1.11.4
- LLVM: 16.0.6
1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.920 GiB / 4.750 GiB available)
This page was generated using Literate.jl.