MNIST Classification using Neural ODEs
To understand Neural ODEs, users should look up these lecture notes. We recommend that users directly use DiffEqFlux.jl instead of implementing Neural ODEs from scratch.
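For reference, the DiffEqFlux.jl route looks roughly like the following. This is a hedged sketch assuming the current DiffEqFlux API, where NeuralODE wraps a Lux layer together with a time span and a solver; consult the DiffEqFlux documentation for authoritative usage.
using DiffEqFlux, Lux, OrdinaryDiffEqTsit5, Random

# Toy 2-dimensional dynamics model (sizes are illustrative, not from this tutorial)
dudt = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
node = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(); save_everystep=false, save_start=false)
ps, st = Lux.setup(Random.default_rng(), node)
sol, st = node(ones(Float32, 2, 1), ps, st)  # sol is an ODESolution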
Package Imports
using Lux,
    ComponentArrays,
    SciMLSensitivity,
    LuxCUDA,
    Optimisers,
    OrdinaryDiffEqTsit5,
    Random,
    Statistics,
    Zygote,
    OneHotArrays,
    InteractiveUtils,
    Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
CUDA.allowscalar(false)  # error on scalar indexing of GPU arrays instead of silently falling back to slow loops
Precompiling SciMLSensitivity...
391.3 ms ✓ MuladdMacro
441.8 ms ✓ PositiveFactorizations
357.4 ms ✓ ExprTools
927.6 ms ✓ Cassette
292.0 ms ✓ CommonSolve
308.2 ms ✓ FastPower
378.2 ms ✓ PoissonRandom
353.2 ms ✓ FunctionWrappersWrappers
1018.7 ms ✓ DifferentiationInterface
1941.2 ms ✓ ExproniconLite
371.6 ms ✓ SciMLStructures
498.0 ms ✓ TruncatedStacktraces
1882.1 ms ✓ SciMLOperators
593.0 ms ✓ ResettableStacks
724.4 ms ✓ PreallocationTools
1583.3 ms ✓ Tracker → TrackerPDMatsExt
6164.3 ms ✓ Krylov
760.1 ms ✓ FastBroadcast
410.4 ms ✓ RuntimeGeneratedFunctions
499.6 ms ✓ FunctionProperties
1146.3 ms ✓ FastPower → FastPowerTrackerExt
2757.7 ms ✓ TimerOutputs
657.5 ms ✓ FastPower → FastPowerForwardDiffExt
11977.7 ms ✓ ArrayLayouts
3450.5 ms ✓ FastPower → FastPowerReverseDiffExt
670.7 ms ✓ DifferentiationInterface → DifferentiationInterfaceStaticArraysExt
508.7 ms ✓ DifferentiationInterface → DifferentiationInterfaceFiniteDiffExt
465.7 ms ✓ DifferentiationInterface → DifferentiationInterfaceChainRulesCoreExt
1128.3 ms ✓ DifferentiationInterface → DifferentiationInterfaceTrackerExt
3516.1 ms ✓ DifferentiationInterface → DifferentiationInterfaceReverseDiffExt
861.2 ms ✓ DifferentiationInterface → DifferentiationInterfaceForwardDiffExt
696.5 ms ✓ DifferentiationInterface → DifferentiationInterfaceSparseArraysExt
547.2 ms ✓ SciMLOperators → SciMLOperatorsStaticArraysCoreExt
1654.1 ms ✓ DifferentiationInterface → DifferentiationInterfaceZygoteExt
1769.5 ms ✓ Jieko
810.0 ms ✓ SciMLOperators → SciMLOperatorsSparseArraysExt
1623.8 ms ✓ SymbolicIndexingInterface
865.7 ms ✓ ArrayLayouts → ArrayLayoutsSparseArraysExt
3661.1 ms ✓ PreallocationTools → PreallocationToolsReverseDiffExt
1081.0 ms ✓ NLSolversBase
2239.8 ms ✓ RecursiveArrayTools
2612.4 ms ✓ LazyArrays
1962.3 ms ✓ LineSearches
847.3 ms ✓ RecursiveArrayTools → RecursiveArrayToolsFastBroadcastExt
689.7 ms ✓ RecursiveArrayTools → RecursiveArrayToolsStructArraysExt
9553.7 ms ✓ Moshi
899.6 ms ✓ RecursiveArrayTools → RecursiveArrayToolsSparseArraysExt
1212.9 ms ✓ RecursiveArrayTools → RecursiveArrayToolsTrackerExt
768.3 ms ✓ RecursiveArrayTools → RecursiveArrayToolsForwardDiffExt
1347.2 ms ✓ LazyArrays → LazyArraysStaticArraysExt
3376.1 ms ✓ RecursiveArrayTools → RecursiveArrayToolsZygoteExt
3396.6 ms ✓ Optim
5540.4 ms ✓ RecursiveArrayTools → RecursiveArrayToolsReverseDiffExt
10384.1 ms ✓ SciMLBase
27358.0 ms ✓ GPUCompiler
1047.0 ms ✓ SciMLBase → SciMLBaseChainRulesCoreExt
2655.7 ms ✓ SciMLJacobianOperators
2818.4 ms ✓ DiffEqBase
3545.0 ms ✓ SciMLBase → SciMLBaseZygoteExt
1509.0 ms ✓ DiffEqBase → DiffEqBaseChainRulesCoreExt
2352.4 ms ✓ DiffEqBase → DiffEqBaseTrackerExt
4851.4 ms ✓ DiffEqBase → DiffEqBaseReverseDiffExt
1621.6 ms ✓ DiffEqBase → DiffEqBaseForwardDiffExt
2149.1 ms ✓ DiffEqBase → DiffEqBaseDistributionsExt
16482.6 ms ✓ LinearSolve
1554.7 ms ✓ DiffEqBase → DiffEqBaseSparseArraysExt
3602.4 ms ✓ LinearSolve → LinearSolveKernelAbstractionsExt
4549.4 ms ✓ DiffEqCallbacks
1739.6 ms ✓ LinearSolve → LinearSolveEnzymeExt
4523.7 ms ✓ LinearSolve → LinearSolveSparseArraysExt
3932.5 ms ✓ DiffEqNoiseProcess
5019.8 ms ✓ DiffEqNoiseProcess → DiffEqNoiseProcessReverseDiffExt
220297.3 ms ✓ Enzyme
6137.8 ms ✓ Enzyme → EnzymeSpecialFunctionsExt
10759.9 ms ✓ Enzyme → EnzymeStaticArraysExt
11462.5 ms ✓ Enzyme → EnzymeChainRulesCoreExt
6187.7 ms ✓ Enzyme → EnzymeLogExpFunctionsExt
5818.4 ms ✓ Enzyme → EnzymeGPUArraysCoreExt
6044.5 ms ✓ QuadGK → QuadGKEnzymeExt
5833.2 ms ✓ FastPower → FastPowerEnzymeExt
6041.8 ms ✓ DifferentiationInterface → DifferentiationInterfaceEnzymeExt
11030.1 ms ✓ DiffEqBase → DiffEqBaseEnzymeExt
21584.2 ms ✓ SciMLSensitivity
83 dependencies successfully precompiled in 310 seconds. 192 already precompiled.
Precompiling MLDataDevicesRecursiveArrayToolsExt...
608.9 ms ✓ MLDataDevices → MLDataDevicesRecursiveArrayToolsExt
1 dependency successfully precompiled in 1 seconds. 47 already precompiled.
Precompiling ComponentArraysRecursiveArrayToolsExt...
682.2 ms ✓ ComponentArrays → ComponentArraysRecursiveArrayToolsExt
1 dependency successfully precompiled in 1 seconds. 53 already precompiled.
Precompiling ComponentArraysSciMLBaseExt...
1095.4 ms ✓ ComponentArrays → ComponentArraysSciMLBaseExt
1 dependency successfully precompiled in 1 seconds. 73 already precompiled.
Precompiling LuxEnzymeExt...
6978.8 ms ✓ Lux → LuxEnzymeExt
1 dependency successfully precompiled in 7 seconds. 148 already precompiled.
Precompiling LuxCUDA...
369.2 ms ✓ LaTeXStrings
19522.5 ms ✓ PrettyTables
45258.2 ms ✓ DataFrames
47272.2 ms ✓ CUDA
4990.2 ms ✓ Atomix → AtomixCUDAExt
8391.0 ms ✓ cuDNN
5395.3 ms ✓ LuxCUDA
7 dependencies successfully precompiled in 131 seconds. 95 already precompiled.
Precompiling EnzymeBFloat16sExt...
5874.5 ms ✓ Enzyme → EnzymeBFloat16sExt
1 dependency successfully precompiled in 6 seconds. 46 already precompiled.
Precompiling ArrayInterfaceCUDAExt...
4818.4 ms ✓ ArrayInterface → ArrayInterfaceCUDAExt
1 dependency successfully precompiled in 5 seconds. 103 already precompiled.
Precompiling NNlibCUDAExt...
4871.1 ms ✓ CUDA → ChainRulesCoreExt
5514.2 ms ✓ NNlib → NNlibCUDAExt
2 dependencies successfully precompiled in 6 seconds. 104 already precompiled.
Precompiling MLDataDevicesCUDAExt...
5016.2 ms ✓ MLDataDevices → MLDataDevicesCUDAExt
1 dependency successfully precompiled in 5 seconds. 106 already precompiled.
Precompiling LuxLibCUDAExt...
5126.5 ms ✓ CUDA → SpecialFunctionsExt
5144.3 ms ✓ CUDA → EnzymeCoreExt
5642.4 ms ✓ LuxLib → LuxLibCUDAExt
3 dependencies successfully precompiled in 6 seconds. 168 already precompiled.
Precompiling DiffEqBaseCUDAExt...
5493.0 ms ✓ DiffEqBase → DiffEqBaseCUDAExt
1 dependency successfully precompiled in 6 seconds. 168 already precompiled.
Precompiling LinearSolveCUDAExt...
6150.6 ms ✓ LinearSolve → LinearSolveCUDAExt
1 dependency successfully precompiled in 7 seconds. 160 already precompiled.
Precompiling WeightInitializersCUDAExt...
5094.6 ms ✓ WeightInitializers → WeightInitializersCUDAExt
1 dependency successfully precompiled in 5 seconds. 111 already precompiled.
Precompiling NNlibCUDACUDNNExt...
5516.0 ms ✓ NNlib → NNlibCUDACUDNNExt
1 dependency successfully precompiled in 6 seconds. 108 already precompiled.
Precompiling MLDataDevicescuDNNExt...
5200.5 ms ✓ MLDataDevices → MLDataDevicescuDNNExt
1 dependency successfully precompiled in 6 seconds. 109 already precompiled.
Precompiling LuxLibcuDNNExt...
5924.5 ms ✓ LuxLib → LuxLibcuDNNExt
1 dependency successfully precompiled in 6 seconds. 175 already precompiled.
Precompiling OrdinaryDiffEqTsit5...
358.8 ms ✓ SimpleUnPack
3978.3 ms ✓ OrdinaryDiffEqCore
1239.8 ms ✓ OrdinaryDiffEqCore → OrdinaryDiffEqCoreEnzymeCoreExt
7221.9 ms ✓ OrdinaryDiffEqTsit5
4 dependencies successfully precompiled in 13 seconds. 96 already precompiled.
Precompiling BangBangDataFramesExt...
1692.0 ms ✓ BangBang → BangBangDataFramesExt
1 dependency successfully precompiled in 2 seconds. 45 already precompiled.
Precompiling SciMLBaseMLStyleExt...
1229.1 ms ✓ SciMLBase → SciMLBaseMLStyleExt
1 dependency successfully precompiled in 2 seconds. 61 already precompiled.
Precompiling TransducersDataFramesExt...
1410.9 ms ✓ Transducers → TransducersDataFramesExt
1 dependency successfully precompiled in 2 seconds. 61 already precompiled.
Precompiling TransducersLazyArraysExt...
1569.9 ms ✓ Transducers → TransducersLazyArraysExt
1 dependency successfully precompiled in 2 seconds. 48 already precompiled.
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 images on CI for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false),
    )
end
loadmnist (generic function with 1 method)
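As a quick sanity check (not part of the original tutorial), we can pull a single minibatch and confirm the shapes the downstream model expects:
train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x)  # (28, 28, 1, 128) -- (H, W, C, BS)
size(y)  # (10, 128)        -- one-hot targets for the digits 0:9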
Define the Neural ODE Layer
First we will use the @compact macro to define the Neural ODE layer.
function NeuralODECompact(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
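Before wiring this layer into a classifier, here is a minimal sketch on a toy two-dimensional dynamics model (the sizes are illustrative and not part of the tutorial pipeline). The layer returns an ODESolution whose final state can be reshaped back to the (features, batch) layout:
rng = Random.default_rng()
node = NeuralODECompact(Dense(2 => 2, tanh); save_everystep=false, save_start=false)
ps, st = Lux.setup(rng, node)
sol, _ = node(ones(Float32, 2, 4), ComponentArray(ps), st)
size(reshape(last(sol.u), 2, 4))  # (2, 4) -- back to (features, batch)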
We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from before @compact was part of the stable API; it also helps illustrate how the layer interface of Lux works.
The NeuralODE is a container layer that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
struct NeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE
OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
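To illustrate with hypothetical sizes: the ODESolution method keeps last(x.u), while the AbstractMatrix method keeps the final column; both restore the (features, batch) layout:
u = rand(Float32, 20 * 128, 7)   # 7 saved timesteps of a flattened 20×128 state
size(diffeqsol_to_array(20, u))  # (20, 128)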
Create and Initialize the Neural ODE Layer
function create_model(
    model_fn=NeuralODE;
    dev=gpu_device(),
    use_named_tuple::Bool=false,
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
)
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false,
            reltol=1.0f-3,
            abstol=1.0f-3,
            save_start=false,
            sensealg,
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10),
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = dev((use_named_tuple ? ps : ComponentArray(ps)))
    st = dev(st)

    return model, ps, st
end
create_model (generic function with 2 methods)
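As a hedged sanity check (not in the original tutorial), we can build the model on the CPU and push a dummy batch through the full pipeline; each image yields one logit per digit class:
model, ps, st = create_model(NeuralODE; dev=cpu_device())
x_dummy = randn(Random.default_rng(), Float32, 28, 28, 1, 3)
y_dummy, _ = model(x_dummy, ps, st)
size(y_dummy)  # (10, 3)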
Define Utility Functions
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
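For example (a hedged sketch, not part of the tutorial), an untrained model should score near chance level, roughly 10% across the ten balanced classes:
model, ps, st = create_model(NeuralODE; dev=cpu_device())
_, test_dataloader = loadmnist(128, 0.9)
accuracy(model, ps, st, test_dataloader)  # ≈ 0.1 before training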
Training
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return nothing
end
train(NeuralODECompact)
[1/9] Time 144.5534s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.8440s Training Accuracy: 58.22222% Test Accuracy: 57.33333%
[3/9] Time 0.7138s Training Accuracy: 67.85185% Test Accuracy: 70.66667%
[4/9] Time 0.8484s Training Accuracy: 74.29630% Test Accuracy: 74.66667%
[5/9] Time 0.6967s Training Accuracy: 76.29630% Test Accuracy: 76.00000%
[6/9] Time 0.8703s Training Accuracy: 78.74074% Test Accuracy: 80.00000%
[7/9] Time 0.6810s Training Accuracy: 82.22222% Test Accuracy: 81.33333%
[8/9] Time 0.9166s Training Accuracy: 83.62963% Test Accuracy: 83.33333%
[9/9] Time 0.6995s Training Accuracy: 85.18519% Test Accuracy: 82.66667%
train(NeuralODE)
[1/9] Time 32.0445s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6083s Training Accuracy: 57.18519% Test Accuracy: 57.33333%
[3/9] Time 0.7673s Training Accuracy: 68.37037% Test Accuracy: 68.00000%
[4/9] Time 0.6132s Training Accuracy: 73.77778% Test Accuracy: 75.33333%
[5/9] Time 0.8293s Training Accuracy: 76.14815% Test Accuracy: 77.33333%
[6/9] Time 0.6174s Training Accuracy: 79.48148% Test Accuracy: 80.66667%
[7/9] Time 0.6410s Training Accuracy: 81.25926% Test Accuracy: 80.66667%
[8/9] Time 0.6229s Training Accuracy: 83.40741% Test Accuracy: 82.66667%
[9/9] Time 0.6283s Training Accuracy: 84.81481% Test Accuracy: 82.00000%
We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9] Time 40.7734s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6428s Training Accuracy: 58.44444% Test Accuracy: 58.00000%
[3/9] Time 0.8302s Training Accuracy: 66.96296% Test Accuracy: 68.00000%
[4/9] Time 0.6017s Training Accuracy: 72.44444% Test Accuracy: 73.33333%
[5/9] Time 0.8427s Training Accuracy: 76.37037% Test Accuracy: 76.00000%
[6/9] Time 0.6259s Training Accuracy: 78.81481% Test Accuracy: 79.33333%
[7/9] Time 0.6380s Training Accuracy: 80.51852% Test Accuracy: 81.33333%
[8/9] Time 0.8885s Training Accuracy: 82.74074% Test Accuracy: 83.33333%
[9/9] Time 0.8287s Training Accuracy: 85.25926% Test Accuracy: 82.66667%
But remember, some AD backends like ReverseDiff are not GPU compatible. For a model of this size, you will notice that training is significantly faster on the CPU than on the GPU.
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
[1/9] Time 41.4192s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.4049s Training Accuracy: 58.74074% Test Accuracy: 56.66667%
[3/9] Time 0.3715s Training Accuracy: 69.92593% Test Accuracy: 71.33333%
[4/9] Time 0.3737s Training Accuracy: 72.81481% Test Accuracy: 74.00000%
[5/9] Time 0.3694s Training Accuracy: 76.37037% Test Accuracy: 78.66667%
[6/9] Time 0.3691s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 0.3689s Training Accuracy: 81.62963% Test Accuracy: 80.66667%
[8/9] Time 0.3662s Training Accuracy: 83.33333% Test Accuracy: 80.00000%
[9/9] Time 0.3851s Training Accuracy: 85.40741% Test Accuracy: 82.00000%
For completeness, let's also test out discrete sensitivities!
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9] Time 37.9044s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 10.6361s Training Accuracy: 58.66667% Test Accuracy: 57.33333%
[3/9] Time 11.1547s Training Accuracy: 69.70370% Test Accuracy: 71.33333%
[4/9] Time 11.1869s Training Accuracy: 72.74074% Test Accuracy: 74.00000%
[5/9] Time 10.5031s Training Accuracy: 76.14815% Test Accuracy: 78.66667%
[6/9] Time 10.7118s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 10.3844s Training Accuracy: 81.55556% Test Accuracy: 80.66667%
[8/9] Time 10.3411s Training Accuracy: 83.40741% Test Accuracy: 80.00000%
[9/9] Time 11.2435s Training Accuracy: 85.25926% Test Accuracy: 81.33333%
Alternate Implementation using Stateful Layer
Starting with v0.5.5, Lux provides a StatefulLuxLayer which can be used to avoid the Boxing of st. Using the @compact API avoids this problem entirely.
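Here is a minimal sketch of the wrapper itself on a toy Dense layer (sizes are illustrative): because ps and st live inside the wrapper, the ODE right-hand side below only needs a (u, p) closure, so nothing gets boxed.
d = Dense(2 => 2, tanh)
ps, st = Lux.setup(Random.default_rng(), d)
sd = StatefulLuxLayer{true}(d, ps, st)
sd(ones(Float32, 2, 1))       # uses the parameters stored in the wrapper
sd(ones(Float32, 2, 1), ps)   # or pass parameters explicitly, as dudt does below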
struct StatefulNeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end
Train the new Stateful Neural ODE
train(StatefulNeuralODE)
[1/9] Time 36.7686s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6270s Training Accuracy: 58.22222% Test Accuracy: 55.33333%
[3/9] Time 0.8992s Training Accuracy: 68.29630% Test Accuracy: 68.66667%
[4/9] Time 0.6302s Training Accuracy: 73.11111% Test Accuracy: 76.00000%
[5/9] Time 0.6249s Training Accuracy: 75.92593% Test Accuracy: 76.66667%
[6/9] Time 0.6767s Training Accuracy: 78.96296% Test Accuracy: 80.66667%
[7/9] Time 0.6419s Training Accuracy: 80.81481% Test Accuracy: 81.33333%
[8/9] Time 0.6378s Training Accuracy: 83.25926% Test Accuracy: 82.66667%
[9/9] Time 0.6355s Training Accuracy: 84.59259% Test Accuracy: 82.00000%
We might not see a significant difference in training time, but let us investigate the type stability of the layers.
Type Stability
model, ps, st = create_model(NeuralODE)
model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)
x = gpu_device()(ones(Float32, 28, 28, 1, 3));
NeuralODE is not type stable due to the boxing of st:
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└── return %3
We avoid the problem entirely by using StatefulNeuralODE:
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
Note that we still recommend using this layer internally rather than exposing it as the default API to users.
Finally, let's check the compact model:
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
CUDA runtime 12.8, artifact installation
CUDA driver 12.8
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.8.4
- CURAND: 10.3.9
- CUFFT: 11.3.3
- CUSOLVER: 11.7.3
- CUSPARSE: 12.5.8
- CUPTI: 2025.1.1 (API 26.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.7.1
- CUDA_Driver_jll: 0.12.1+1
- CUDA_Runtime_jll: 0.16.1+0
Toolchain:
- Julia: 1.11.4
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.982 GiB / 4.750 GiB available)
This page was generated using Literate.jl.