MNIST Classification using Neural ODEs

To understand Neural ODEs, users should look up these lecture notes. We recommend that users use DiffEqFlux.jl directly instead of implementing Neural ODEs from scratch.

Package Imports

julia
using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5,
    Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs

CUDA.allowscalar(false)
Precompiling SciMLSensitivity...
    618.8 ms  ✓ MuladdMacro
    527.2 ms  ✓ PositiveFactorizations
    325.6 ms  ✓ CommonSolve
    975.0 ms  ✓ Cassette
    404.2 ms  ✓ PoissonRandom
    349.8 ms  ✓ FastPower
    401.7 ms  ✓ RuntimeGeneratedFunctions
    443.1 ms  ✓ Parameters
    395.3 ms  ✓ FunctionWrappersWrappers
   1988.2 ms  ✓ ExproniconLite
    890.3 ms  ✓ DifferentiationInterface
    885.3 ms  ✓ PDMats
    401.7 ms  ✓ SciMLStructures
    503.5 ms  ✓ TruncatedStacktraces
   1892.1 ms  ✓ SciMLOperators
   1142.7 ms  ✓ Rmath_jll
    700.3 ms  ✓ oneTBB_jll
   6211.7 ms  ✓ Krylov
    651.8 ms  ✓ ResettableStacks
   1807.9 ms  ✓ RecipesBase
   1085.6 ms  ✓ QuadGK
   1152.6 ms  ✓ HypergeometricFunctions
   1303.6 ms  ✓ IntelOpenMP_jll
    870.9 ms  ✓ PreallocationTools
    866.1 ms  ✓ FastBroadcast
   1328.9 ms  ✓ NLSolversBase
    627.3 ms  ✓ FunctionProperties
  12117.6 ms  ✓ ArrayLayouts
   1371.7 ms  ✓ FastPower → FastPowerTrackerExt
    813.7 ms  ✓ FastPower → FastPowerForwardDiffExt
   1621.6 ms  ✓ SymbolicIndexingInterface
   3670.8 ms  ✓ FastPower → FastPowerReverseDiffExt
   2166.1 ms  ✓ Jieko
    682.8 ms  ✓ DifferentiationInterface → DifferentiationInterfaceStaticArraysExt
   6566.5 ms  ✓ FastPower → FastPowerEnzymeExt
    507.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceFiniteDiffExt
    453.0 ms  ✓ DifferentiationInterface → DifferentiationInterfaceChainRulesCoreExt
   1179.3 ms  ✓ DifferentiationInterface → DifferentiationInterfaceTrackerExt
    812.3 ms  ✓ DifferentiationInterface → DifferentiationInterfaceForwardDiffExt
   3864.0 ms  ✓ DifferentiationInterface → DifferentiationInterfaceReverseDiffExt
    650.3 ms  ✓ DifferentiationInterface → DifferentiationInterfaceSparseArraysExt
   1671.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceZygoteExt
    675.9 ms  ✓ FillArrays → FillArraysPDMatsExt
    563.2 ms  ✓ SciMLOperators → SciMLOperatorsStaticArraysCoreExt
   1443.2 ms  ✓ Tracker → TrackerPDMatsExt
    839.1 ms  ✓ SciMLOperators → SciMLOperatorsSparseArraysExt
    819.6 ms  ✓ Rmath
   1277.5 ms  ✓ MKL_jll
   5914.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceEnzymeExt
   1829.1 ms  ✓ LineSearches
   3495.7 ms  ✓ PreallocationTools → PreallocationToolsReverseDiffExt
    818.8 ms  ✓ ArrayLayouts → ArrayLayoutsSparseArraysExt
   5849.4 ms  ✓ QuadGK → QuadGKEnzymeExt
   1858.3 ms  ✓ StatsFuns
   2294.0 ms  ✓ RecursiveArrayTools
   2504.0 ms  ✓ LazyArrays
   3295.6 ms  ✓ Optim
    740.7 ms  ✓ StatsFuns → StatsFunsInverseFunctionsExt
   1625.5 ms  ✓ StatsFuns → StatsFunsChainRulesCoreExt
    824.4 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsFastBroadcastExt
    638.0 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsStructArraysExt
    921.0 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsSparseArraysExt
   9528.8 ms  ✓ Moshi
   5074.5 ms  ✓ Distributions
   1286.1 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsTrackerExt
    795.8 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsForwardDiffExt
   1373.0 ms  ✓ LazyArrays → LazyArraysStaticArraysExt
   3401.6 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsZygoteExt
   1700.9 ms  ✓ Distributions → DistributionsChainRulesCoreExt
   5851.8 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsReverseDiffExt
  10676.1 ms  ✓ SciMLBase
   1123.9 ms  ✓ SciMLBase → SciMLBaseChainRulesCoreExt
   2765.5 ms  ✓ SciMLJacobianOperators
   3634.3 ms  ✓ SciMLBase → SciMLBaseZygoteExt
   6005.1 ms  ✓ DiffEqBase
   1413.4 ms  ✓ DiffEqBase → DiffEqBaseChainRulesCoreExt
   2451.0 ms  ✓ DiffEqBase → DiffEqBaseTrackerExt
   1809.1 ms  ✓ DiffEqBase → DiffEqBaseForwardDiffExt
   2083.6 ms  ✓ DiffEqBase → DiffEqBaseDistributionsExt
   5251.5 ms  ✓ DiffEqBase → DiffEqBaseReverseDiffExt
   1585.8 ms  ✓ DiffEqBase → DiffEqBaseSparseArraysExt
   4528.9 ms  ✓ DiffEqCallbacks
  16784.5 ms  ✓ LinearSolve
   4002.5 ms  ✓ DiffEqNoiseProcess
   3667.5 ms  ✓ LinearSolve → LinearSolveKernelAbstractionsExt
   8613.1 ms  ✓ DiffEqBase → DiffEqBaseEnzymeExt
   1790.9 ms  ✓ LinearSolve → LinearSolveEnzymeExt
   4851.2 ms  ✓ LinearSolve → LinearSolveSparseArraysExt
   5038.3 ms  ✓ DiffEqNoiseProcess → DiffEqNoiseProcessReverseDiffExt
  21567.9 ms  ✓ SciMLSensitivity
  90 dependencies successfully precompiled in 103 seconds. 186 already precompiled.
Precompiling MLDataDevicesRecursiveArrayToolsExt...
    655.0 ms  ✓ MLDataDevices → MLDataDevicesRecursiveArrayToolsExt
  1 dependency successfully precompiled in 1 seconds. 47 already precompiled.
Precompiling ComponentArraysRecursiveArrayToolsExt...
    707.8 ms  ✓ ComponentArrays → ComponentArraysRecursiveArrayToolsExt
  1 dependency successfully precompiled in 1 seconds. 69 already precompiled.
Precompiling ComponentArraysSciMLBaseExt...
   1418.1 ms  ✓ ComponentArrays → ComponentArraysSciMLBaseExt
  1 dependency successfully precompiled in 2 seconds. 89 already precompiled.
Precompiling LuxCUDA...
   6873.9 ms  ✓ LuxCUDA
  1 dependency successfully precompiled in 8 seconds. 103 already precompiled.
Precompiling DiffEqBaseCUDAExt...
   6230.6 ms  ✓ DiffEqBase → DiffEqBaseCUDAExt
  1 dependency successfully precompiled in 7 seconds. 172 already precompiled.
Precompiling LinearSolveCUDAExt...
   7393.8 ms  ✓ LinearSolve → LinearSolveCUDAExt
  1 dependency successfully precompiled in 8 seconds. 164 already precompiled.
Precompiling OrdinaryDiffEqTsit5...
    361.4 ms  ✓ SimpleUnPack
   4205.4 ms  ✓ OrdinaryDiffEqCore
   1276.4 ms  ✓ OrdinaryDiffEqCore → OrdinaryDiffEqCoreEnzymeCoreExt
   7399.2 ms  ✓ OrdinaryDiffEqTsit5
  4 dependencies successfully precompiled in 13 seconds. 96 already precompiled.
Precompiling MLDatasets...
    407.7 ms  ✓ Glob
    448.7 ms  ✓ WorkerUtilities
    455.9 ms  ✓ TensorCore
    503.5 ms  ✓ BufferedStreams
    466.2 ms  ✓ PaddedViews
    623.2 ms  ✓ URIs
    377.8 ms  ✓ SimpleBufferStream
    412.6 ms  ✓ LazyModules
    349.4 ms  ✓ PackageExtensionCompat
    383.9 ms  ✓ BitFlags
    439.0 ms  ✓ MappedArrays
    434.0 ms  ✓ StackViews
    695.3 ms  ✓ GZip
    745.2 ms  ✓ ConcurrentUtilities
    630.3 ms  ✓ ZipFile
    864.0 ms  ✓ StructTypes
   1041.4 ms  ✓ MbedTLS
    623.7 ms  ✓ MPIPreferences
    392.0 ms  ✓ InternedStrings
    542.5 ms  ✓ ExceptionUnwrapping
    632.5 ms  ✓ Chemfiles_jll
    680.5 ms  ✓ libaec_jll
    519.7 ms  ✓ MicrosoftMPI_jll
   1064.2 ms  ✓ FilePathsBase
    590.6 ms  ✓ StringEncodings
    869.1 ms  ✓ WeakRefStrings
   4545.3 ms  ✓ FileIO
   2330.2 ms  ✓ ColorVectorSpace
    474.4 ms  ✓ StridedViews
    504.0 ms  ✓ MosaicViews
   2205.2 ms  ✓ OpenSSL
   1177.6 ms  ✓ OpenMPI_jll
   1546.2 ms  ✓ MPICH_jll
   1199.0 ms  ✓ MPItrampoline_jll
    583.9 ms  ✓ FilePathsBase → FilePathsBaseMmapExt
   1353.3 ms  ✓ FilePathsBase → FilePathsBaseTestExt
   2096.3 ms  ✓ NPZ
  10125.4 ms  ✓ JSON3
  21375.9 ms  ✓ Unitful
   3878.2 ms  ✓ ColorSchemes
   2565.0 ms  ✓ Pickle
  19352.0 ms  ✓ ImageCore
   2035.8 ms  ✓ HDF5_jll
  19589.5 ms  ✓ HTTP
    665.2 ms  ✓ Unitful → ConstructionBaseUnitfulExt
    801.8 ms  ✓ Unitful → InverseFunctionsUnitfulExt
   3036.2 ms  ✓ UnitfulAtomic
   2488.2 ms  ✓ PeriodicTable
    643.2 ms  ✓ Accessors → UnitfulExt
  35702.7 ms  ✓ JLD2
   2233.4 ms  ✓ ImageBase
   3213.2 ms  ✓ DataDeps
   1997.4 ms  ✓ FileIO → HTTPExt
   7697.4 ms  ✓ HDF5
  17494.4 ms  ✓ CSV
   2348.4 ms  ✓ AtomsBase
   1969.3 ms  ✓ ImageShow
   2462.5 ms  ✓ MAT
   2371.4 ms  ✓ Chemfiles
  10131.2 ms  ✓ MLDatasets
  60 dependencies successfully precompiled in 77 seconds. 140 already precompiled.
Precompiling DistributionsTestExt...
   1497.4 ms  ✓ Distributions → DistributionsTestExt
  1 dependency successfully precompiled in 2 seconds. 53 already precompiled.
Precompiling SciMLBaseMLStyleExt...
   1137.0 ms  ✓ SciMLBase → SciMLBaseMLStyleExt
  1 dependency successfully precompiled in 2 seconds. 61 already precompiled.
Precompiling TransducersLazyArraysExt...
   1270.1 ms  ✓ Transducers → TransducersLazyArraysExt
  1 dependency successfully precompiled in 1 seconds. 48 already precompiled.

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 samples (when running in CI) for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false),
    )
end
loadmnist (generic function with 1 method)
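
As a quick sanity check (a usage sketch, not part of the generated output), the loader above can be called directly to inspect a single minibatch. This assumes the MNIST data has already been downloaded by MLDatasets.

julia
# Hypothetical usage: build the loaders and look at one batch
train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x)  # (28, 28, 1, 128) -- images in (H, W, C, batch) layout
size(y)  # (10, 128)        -- one-hot encoded labels for the digits 0:9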

Define the Neural ODE Layer

First, we will use the @compact macro to define the Neural ODE layer.

julia
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)

We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from before @compact was part of the stable API. It also helps users understand how the Lux layer interface works.

The NeuralODE is a wrapper (container) layer that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.

julia
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE

OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.

julia
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
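
To see what diffeqsol_to_array does, here is a small sketch on synthetic data (not part of the tutorial run): the matrix method keeps only the last saved time point and reshapes the flat state back into a (features, batchsize) matrix.

julia
# Sketch with made-up sizes: 20 features, batch of 3, 5 saved time points
sol_matrix = rand(Float32, 20 * 3, 5)
out = diffeqsol_to_array(20, sol_matrix)  # takes the last column and reshapes
size(out)  # (20, 3)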

Create and Initialize the Neural ODE Layer

julia
function create_model(
        model_fn = NeuralODE; dev = gpu_device(), use_named_tuple::Bool = false,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    )
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep = false, reltol = 1.0f-3,
            abstol = 1.0f-3, save_start = false, sensealg
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10)
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev

    return model, ps, st
end
create_model (generic function with 2 methods)
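
As a hedged sketch (assuming a CPU-only setup), we can instantiate the model and run a single forward pass on dummy data before training:

julia
# Hypothetical forward-pass check on the CPU with dummy data
model, ps, st = create_model(NeuralODE; dev = cpu_device())
x_dummy = ones(Float32, 28, 28, 1, 4)
y_pred, _ = model(x_dummy, ps, st)
size(y_pred)  # (10, 4) -- logits for the 10 MNIST classes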

Define Utility Functions

julia
const logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
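
For illustration, with the model, ps, and st from the previous sketch in scope, the accuracy of an untrained model should sit near chance level (roughly 10%):

julia
# Hypothetical usage on the CPU test loader built earlier
_, test_dataloader = loadmnist(128, 0.9)
acc = accuracy(model, ps, st, test_dataloader) * 100  # ≈ 10% before training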

Training

julia
function train(model_function; cpu::Bool = false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return
end

train(NeuralODECompact)
[1/9]	Time 146.9311s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.8216s	Training Accuracy: 58.22222%	Test Accuracy: 57.33333%
[3/9]	Time 1.0646s	Training Accuracy: 67.85185%	Test Accuracy: 70.66667%
[4/9]	Time 0.8538s	Training Accuracy: 74.29630%	Test Accuracy: 74.66667%
[5/9]	Time 1.0689s	Training Accuracy: 76.29630%	Test Accuracy: 76.00000%
[6/9]	Time 0.8333s	Training Accuracy: 78.74074%	Test Accuracy: 80.00000%
[7/9]	Time 1.0549s	Training Accuracy: 82.22222%	Test Accuracy: 81.33333%
[8/9]	Time 0.8210s	Training Accuracy: 83.62963%	Test Accuracy: 83.33333%
[9/9]	Time 0.8249s	Training Accuracy: 85.18519%	Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9]	Time 35.9996s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6250s	Training Accuracy: 57.18519%	Test Accuracy: 57.33333%
[3/9]	Time 0.6530s	Training Accuracy: 68.37037%	Test Accuracy: 68.00000%
[4/9]	Time 0.6341s	Training Accuracy: 73.77778%	Test Accuracy: 75.33333%
[5/9]	Time 0.8622s	Training Accuracy: 76.14815%	Test Accuracy: 77.33333%
[6/9]	Time 0.6281s	Training Accuracy: 79.48148%	Test Accuracy: 80.66667%
[7/9]	Time 0.9201s	Training Accuracy: 81.25926%	Test Accuracy: 80.66667%
[8/9]	Time 0.6353s	Training Accuracy: 83.40741%	Test Accuracy: 82.66667%
[9/9]	Time 0.6291s	Training Accuracy: 84.81481%	Test Accuracy: 82.00000%

We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).

julia
train(NeuralODE; sensealg = GaussAdjoint(; autojacvec = ZygoteVJP()), use_named_tuple = true)
[1/9]	Time 41.6490s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6059s	Training Accuracy: 58.44444%	Test Accuracy: 58.00000%
[3/9]	Time 0.7753s	Training Accuracy: 66.96296%	Test Accuracy: 68.00000%
[4/9]	Time 0.6138s	Training Accuracy: 72.44444%	Test Accuracy: 73.33333%
[5/9]	Time 0.6501s	Training Accuracy: 76.37037%	Test Accuracy: 76.00000%
[6/9]	Time 0.6149s	Training Accuracy: 78.81481%	Test Accuracy: 79.33333%
[7/9]	Time 0.6000s	Training Accuracy: 80.51852%	Test Accuracy: 81.33333%
[8/9]	Time 0.8558s	Training Accuracy: 82.74074%	Test Accuracy: 83.33333%
[9/9]	Time 0.6038s	Training Accuracy: 85.25926%	Test Accuracy: 82.66667%

But remember that some AD backends, like ReverseDiff, are not GPU compatible. For a model this size, you will notice that training is significantly faster on the CPU than on the GPU.

julia
train(NeuralODE; sensealg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP()), cpu = true)
[1/9]	Time 39.7240s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.4356s	Training Accuracy: 58.74074%	Test Accuracy: 56.66667%
[3/9]	Time 0.3644s	Training Accuracy: 69.92593%	Test Accuracy: 71.33333%
[4/9]	Time 0.3735s	Training Accuracy: 72.81481%	Test Accuracy: 74.00000%
[5/9]	Time 0.3752s	Training Accuracy: 76.37037%	Test Accuracy: 78.66667%
[6/9]	Time 0.3699s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 0.3610s	Training Accuracy: 81.62963%	Test Accuracy: 80.66667%
[8/9]	Time 0.3638s	Training Accuracy: 83.33333%	Test Accuracy: 80.00000%
[9/9]	Time 0.3709s	Training Accuracy: 85.40741%	Test Accuracy: 82.00000%

For completeness, let's also test out discrete sensitivities!

julia
train(NeuralODE; sensealg = ReverseDiffAdjoint(), cpu = true)
[1/9]	Time 38.0484s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 11.8893s	Training Accuracy: 58.66667%	Test Accuracy: 57.33333%
[3/9]	Time 11.2469s	Training Accuracy: 69.70370%	Test Accuracy: 71.33333%
[4/9]	Time 11.0207s	Training Accuracy: 72.74074%	Test Accuracy: 74.00000%
[5/9]	Time 10.7469s	Training Accuracy: 76.14815%	Test Accuracy: 78.66667%
[6/9]	Time 10.7892s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 10.9633s	Training Accuracy: 81.55556%	Test Accuracy: 80.66667%
[8/9]	Time 10.5584s	Training Accuracy: 83.40741%	Test Accuracy: 80.00000%
[9/9]	Time 10.9576s	Training Accuracy: 85.25926%	Test Accuracy: 81.33333%

Alternate Implementation using Stateful Layer

Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.

julia
struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
    Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end

Train the new Stateful Neural ODE

julia
train(StatefulNeuralODE)
[1/9]	Time 36.7343s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6411s	Training Accuracy: 58.22222%	Test Accuracy: 55.33333%
[3/9]	Time 0.6665s	Training Accuracy: 68.29630%	Test Accuracy: 68.66667%
[4/9]	Time 0.9576s	Training Accuracy: 73.11111%	Test Accuracy: 76.00000%
[5/9]	Time 0.6310s	Training Accuracy: 75.92593%	Test Accuracy: 76.66667%
[6/9]	Time 0.6411s	Training Accuracy: 78.96296%	Test Accuracy: 80.66667%
[7/9]	Time 0.6729s	Training Accuracy: 80.81481%	Test Accuracy: 81.33333%
[8/9]	Time 1.0656s	Training Accuracy: 83.25926%	Test Accuracy: 82.66667%
[9/9]	Time 0.6349s	Training Accuracy: 84.59259%	Test Accuracy: 82.00000%

We might not see a significant difference in the training time, but let us investigate the type stability of the layers.

Type Stability

julia
model, ps, st = create_model(NeuralODE)

model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)

x = gpu_device()(ones(Float32, 28, 28, 1, 3));

NeuralODE is not type stable due to the boxing of st.

julia
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└──      return %3

We avoid the problem entirely by using StatefulNeuralODE.

julia
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Note that we still recommend using this layer internally and not exposing it as the default API to users.

Finally, let's check the compact model.

julia
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.8, artifact installation
CUDA driver 12.8
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.8.3
- CURAND: 10.3.9
- CUFFT: 11.3.3
- CUSOLVER: 11.7.2
- CUSPARSE: 12.5.7
- CUPTI: 2025.1.0 (API 26.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.7.0
- CUDA_Driver_jll: 0.12.0+0
- CUDA_Runtime_jll: 0.16.0+0

Toolchain:
- Julia: 1.11.4
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.920 GiB / 4.750 GiB available)

This page was generated using Literate.jl.