MNIST Classification using Neural ODEs

To understand Neural ODEs, users should look up these lecture notes. We recommend using DiffEqFlux.jl directly, instead of implementing Neural ODEs from scratch.
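
In brief, a Neural ODE replaces a stack of discrete layers with a continuous transformation: the hidden state u(t) follows an ODE whose vector field is a neural network,

du/dt = f_θ(u, t),  u(0) = x,

and the layer output is the terminal state u(T) returned by a numerical ODE solver. The sensealg chosen later in this tutorial controls how gradients with respect to θ are propagated through the solve.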

Package Imports

julia
using Lux,
    ComponentArrays,
    SciMLSensitivity,
    LuxCUDA,
    Optimisers,
    OrdinaryDiffEqTsit5,
    Random,
    Statistics,
    Zygote,
    OneHotArrays,
    InteractiveUtils,
    Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs

CUDA.allowscalar(false)
Precompiling SciMLSensitivity...
    391.3 ms  ✓ MuladdMacro
    441.8 ms  ✓ PositiveFactorizations
    357.4 ms  ✓ ExprTools
    927.6 ms  ✓ Cassette
    292.0 ms  ✓ CommonSolve
    308.2 ms  ✓ FastPower
    378.2 ms  ✓ PoissonRandom
    353.2 ms  ✓ FunctionWrappersWrappers
   1018.7 ms  ✓ DifferentiationInterface
   1941.2 ms  ✓ ExproniconLite
    371.6 ms  ✓ SciMLStructures
    498.0 ms  ✓ TruncatedStacktraces
   1882.1 ms  ✓ SciMLOperators
    593.0 ms  ✓ ResettableStacks
    724.4 ms  ✓ PreallocationTools
   1583.3 ms  ✓ Tracker → TrackerPDMatsExt
   6164.3 ms  ✓ Krylov
    760.1 ms  ✓ FastBroadcast
    410.4 ms  ✓ RuntimeGeneratedFunctions
    499.6 ms  ✓ FunctionProperties
   1146.3 ms  ✓ FastPower → FastPowerTrackerExt
   2757.7 ms  ✓ TimerOutputs
    657.5 ms  ✓ FastPower → FastPowerForwardDiffExt
  11977.7 ms  ✓ ArrayLayouts
   3450.5 ms  ✓ FastPower → FastPowerReverseDiffExt
    670.7 ms  ✓ DifferentiationInterface → DifferentiationInterfaceStaticArraysExt
    508.7 ms  ✓ DifferentiationInterface → DifferentiationInterfaceFiniteDiffExt
    465.7 ms  ✓ DifferentiationInterface → DifferentiationInterfaceChainRulesCoreExt
   1128.3 ms  ✓ DifferentiationInterface → DifferentiationInterfaceTrackerExt
   3516.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceReverseDiffExt
    861.2 ms  ✓ DifferentiationInterface → DifferentiationInterfaceForwardDiffExt
    696.5 ms  ✓ DifferentiationInterface → DifferentiationInterfaceSparseArraysExt
    547.2 ms  ✓ SciMLOperators → SciMLOperatorsStaticArraysCoreExt
   1654.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceZygoteExt
   1769.5 ms  ✓ Jieko
    810.0 ms  ✓ SciMLOperators → SciMLOperatorsSparseArraysExt
   1623.8 ms  ✓ SymbolicIndexingInterface
    865.7 ms  ✓ ArrayLayouts → ArrayLayoutsSparseArraysExt
   3661.1 ms  ✓ PreallocationTools → PreallocationToolsReverseDiffExt
   1081.0 ms  ✓ NLSolversBase
   2239.8 ms  ✓ RecursiveArrayTools
   2612.4 ms  ✓ LazyArrays
   1962.3 ms  ✓ LineSearches
    847.3 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsFastBroadcastExt
    689.7 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsStructArraysExt
   9553.7 ms  ✓ Moshi
    899.6 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsSparseArraysExt
   1212.9 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsTrackerExt
    768.3 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsForwardDiffExt
   1347.2 ms  ✓ LazyArrays → LazyArraysStaticArraysExt
   3376.1 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsZygoteExt
   3396.6 ms  ✓ Optim
   5540.4 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsReverseDiffExt
  10384.1 ms  ✓ SciMLBase
  27358.0 ms  ✓ GPUCompiler
   1047.0 ms  ✓ SciMLBase → SciMLBaseChainRulesCoreExt
   2655.7 ms  ✓ SciMLJacobianOperators
   2818.4 ms  ✓ DiffEqBase
   3545.0 ms  ✓ SciMLBase → SciMLBaseZygoteExt
   1509.0 ms  ✓ DiffEqBase → DiffEqBaseChainRulesCoreExt
   2352.4 ms  ✓ DiffEqBase → DiffEqBaseTrackerExt
   4851.4 ms  ✓ DiffEqBase → DiffEqBaseReverseDiffExt
   1621.6 ms  ✓ DiffEqBase → DiffEqBaseForwardDiffExt
   2149.1 ms  ✓ DiffEqBase → DiffEqBaseDistributionsExt
  16482.6 ms  ✓ LinearSolve
   1554.7 ms  ✓ DiffEqBase → DiffEqBaseSparseArraysExt
   3602.4 ms  ✓ LinearSolve → LinearSolveKernelAbstractionsExt
   4549.4 ms  ✓ DiffEqCallbacks
   1739.6 ms  ✓ LinearSolve → LinearSolveEnzymeExt
   4523.7 ms  ✓ LinearSolve → LinearSolveSparseArraysExt
   3932.5 ms  ✓ DiffEqNoiseProcess
   5019.8 ms  ✓ DiffEqNoiseProcess → DiffEqNoiseProcessReverseDiffExt
 220297.3 ms  ✓ Enzyme
   6137.8 ms  ✓ Enzyme → EnzymeSpecialFunctionsExt
  10759.9 ms  ✓ Enzyme → EnzymeStaticArraysExt
  11462.5 ms  ✓ Enzyme → EnzymeChainRulesCoreExt
   6187.7 ms  ✓ Enzyme → EnzymeLogExpFunctionsExt
   5818.4 ms  ✓ Enzyme → EnzymeGPUArraysCoreExt
   6044.5 ms  ✓ QuadGK → QuadGKEnzymeExt
   5833.2 ms  ✓ FastPower → FastPowerEnzymeExt
   6041.8 ms  ✓ DifferentiationInterface → DifferentiationInterfaceEnzymeExt
  11030.1 ms  ✓ DiffEqBase → DiffEqBaseEnzymeExt
  21584.2 ms  ✓ SciMLSensitivity
  83 dependencies successfully precompiled in 310 seconds. 192 already precompiled.
Precompiling MLDataDevicesRecursiveArrayToolsExt...
    608.9 ms  ✓ MLDataDevices → MLDataDevicesRecursiveArrayToolsExt
  1 dependency successfully precompiled in 1 seconds. 47 already precompiled.
Precompiling ComponentArraysRecursiveArrayToolsExt...
    682.2 ms  ✓ ComponentArrays → ComponentArraysRecursiveArrayToolsExt
  1 dependency successfully precompiled in 1 seconds. 53 already precompiled.
Precompiling ComponentArraysSciMLBaseExt...
   1095.4 ms  ✓ ComponentArrays → ComponentArraysSciMLBaseExt
  1 dependency successfully precompiled in 1 seconds. 73 already precompiled.
Precompiling LuxEnzymeExt...
   6978.8 ms  ✓ Lux → LuxEnzymeExt
  1 dependency successfully precompiled in 7 seconds. 148 already precompiled.
Precompiling LuxCUDA...
    369.2 ms  ✓ LaTeXStrings
  19522.5 ms  ✓ PrettyTables
  45258.2 ms  ✓ DataFrames
  47272.2 ms  ✓ CUDA
   4990.2 ms  ✓ Atomix → AtomixCUDAExt
   8391.0 ms  ✓ cuDNN
   5395.3 ms  ✓ LuxCUDA
  7 dependencies successfully precompiled in 131 seconds. 95 already precompiled.
Precompiling EnzymeBFloat16sExt...
   5874.5 ms  ✓ Enzyme → EnzymeBFloat16sExt
  1 dependency successfully precompiled in 6 seconds. 46 already precompiled.
Precompiling ArrayInterfaceCUDAExt...
   4818.4 ms  ✓ ArrayInterface → ArrayInterfaceCUDAExt
  1 dependency successfully precompiled in 5 seconds. 103 already precompiled.
Precompiling NNlibCUDAExt...
   4871.1 ms  ✓ CUDA → ChainRulesCoreExt
   5514.2 ms  ✓ NNlib → NNlibCUDAExt
  2 dependencies successfully precompiled in 6 seconds. 104 already precompiled.
Precompiling MLDataDevicesCUDAExt...
   5016.2 ms  ✓ MLDataDevices → MLDataDevicesCUDAExt
  1 dependency successfully precompiled in 5 seconds. 106 already precompiled.
Precompiling LuxLibCUDAExt...
   5126.5 ms  ✓ CUDA → SpecialFunctionsExt
   5144.3 ms  ✓ CUDA → EnzymeCoreExt
   5642.4 ms  ✓ LuxLib → LuxLibCUDAExt
  3 dependencies successfully precompiled in 6 seconds. 168 already precompiled.
Precompiling DiffEqBaseCUDAExt...
   5493.0 ms  ✓ DiffEqBase → DiffEqBaseCUDAExt
  1 dependency successfully precompiled in 6 seconds. 168 already precompiled.
Precompiling LinearSolveCUDAExt...
   6150.6 ms  ✓ LinearSolve → LinearSolveCUDAExt
  1 dependency successfully precompiled in 7 seconds. 160 already precompiled.
Precompiling WeightInitializersCUDAExt...
   5094.6 ms  ✓ WeightInitializers → WeightInitializersCUDAExt
  1 dependency successfully precompiled in 5 seconds. 111 already precompiled.
Precompiling NNlibCUDACUDNNExt...
   5516.0 ms  ✓ NNlib → NNlibCUDACUDNNExt
  1 dependency successfully precompiled in 6 seconds. 108 already precompiled.
Precompiling MLDataDevicescuDNNExt...
   5200.5 ms  ✓ MLDataDevices → MLDataDevicescuDNNExt
  1 dependency successfully precompiled in 6 seconds. 109 already precompiled.
Precompiling LuxLibcuDNNExt...
   5924.5 ms  ✓ LuxLib → LuxLibcuDNNExt
  1 dependency successfully precompiled in 6 seconds. 175 already precompiled.
Precompiling OrdinaryDiffEqTsit5...
    358.8 ms  ✓ SimpleUnPack
   3978.3 ms  ✓ OrdinaryDiffEqCore
   1239.8 ms  ✓ OrdinaryDiffEqCore → OrdinaryDiffEqCoreEnzymeCoreExt
   7221.9 ms  ✓ OrdinaryDiffEqTsit5
  4 dependencies successfully precompiled in 13 seconds. 96 already precompiled.
Precompiling BangBangDataFramesExt...
   1692.0 ms  ✓ BangBang → BangBangDataFramesExt
  1 dependency successfully precompiled in 2 seconds. 45 already precompiled.
Precompiling SciMLBaseMLStyleExt...
   1229.1 ms  ✓ SciMLBase → SciMLBaseMLStyleExt
  1 dependency successfully precompiled in 2 seconds. 61 already precompiled.
Precompiling TransducersDataFramesExt...
   1410.9 ms  ✓ Transducers → TransducersDataFramesExt
  1 dependency successfully precompiled in 2 seconds. 61 already precompiled.
Precompiling TransducersLazyArraysExt...
   1569.9 ms  ✓ Transducers → TransducersLazyArraysExt
  1 dependency successfully precompiled in 2 seconds. 48 already precompiled.

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 samples on CI, for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false),
    )
end
loadmnist (generic function with 1 method)
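
For reference, here is how this loader is invoked later in the tutorial (batch size 128, 90/10 train/test split). The first batch has the shapes noted in the comments:

julia
train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x)  # (28, 28, 1, 128) -- (H, W, C, batch)
size(y)  # (10, 128)        -- one-hot encoded labels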

Define the Neural ODE Layer

First, we will use the @compact macro to define the Neural ODE layer.

julia
function NeuralODECompact(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here: @compact nests the model's parameters under the `model` key
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)

We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from before @compact was part of the stable API. It also helps users understand how the Lux layer interface works.

The NeuralODE is a ContainerLayer that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model (see the short sketch after the definition below).

julia
struct NeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE
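
A minimal sketch of what the wrapper buys us (a toy Dense model for illustration): since NeuralODE subtypes AbstractLuxWrapperLayer{:model}, setting it up yields exactly the inner model's parameters and states.

julia
rng = Random.default_rng()
node = NeuralODE(Dense(2 => 2))
ps, st = Lux.setup(rng, node)  # identical to Lux.setup(rng, Dense(2 => 2))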

OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
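
A minimal sketch showing that this flatten/restore round trip is lossless:

julia
x = rand(Float32, 28, 28, 1, 3)
u = vec(x)                # flatten to a Vector for the solver
reshape(u, size(x)) == x  # true: the original shape is recovered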

julia
function (n::NeuralODE)(x, ps, st)
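    # `st` is rebound inside this closure, so Julia boxes it; this costs type
    # stability (see the Type Stability section below).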
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

# Extract the terminal state of the solve and reshape it into a (features, batch) matrix.
@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
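
Base.Fix1 partially applies the first argument, so the Chain below can treat diffeqsol_to_array as a one-argument layer. A small sketch using the AbstractMatrix method:

julia
f = Base.Fix1(diffeqsol_to_array, 20)  # f(x) == diffeqsol_to_array(20, x)
f(randn(Float32, 20, 6))               # 20×1 matrix: the last column, reshaped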

Create and Initialize the Neural ODE Layer

julia
function create_model(
    model_fn=NeuralODE;
    dev=gpu_device(),
    use_named_tuple::Bool=false,
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
)
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false,
            reltol=1.0f-3,
            abstol=1.0f-3,
            save_start=false,
            sensealg,
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10),
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = dev(use_named_tuple ? ps : ComponentArray(ps))
    st = dev(st)

    return model, ps, st
end
create_model (generic function with 2 methods)
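
As a quick smoke test (a sketch, run on the CPU), a forward pass on a dummy batch produces the expected (classes, batch) logits:

julia
model, ps, st = create_model(NeuralODE; dev=cpu_device())
x_dummy = randn(Float32, 28, 28, 1, 4)
y_dummy, _ = model(x_dummy, ps, st)
size(y_dummy)  # (10, 4)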

Define Utility Functions

julia
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
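
To illustrate the utilities defined above (a minimal sketch): onecold maps one-hot columns back to 1-based class indices, and logitcrossentropy consumes raw logits directly.

julia
y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot targets
onecold(y)                       # [1, 4, 10]: 1-based class indices
ŷ = randn(Float32, 10, 3)        # raw (unnormalized) logits
logitcrossentropy(ŷ, y)          # scalar loss; softmax is applied internally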

Training

julia
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return nothing
end

train(NeuralODECompact)
[1/9]	Time 144.5534s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.8440s	Training Accuracy: 58.22222%	Test Accuracy: 57.33333%
[3/9]	Time 0.7138s	Training Accuracy: 67.85185%	Test Accuracy: 70.66667%
[4/9]	Time 0.8484s	Training Accuracy: 74.29630%	Test Accuracy: 74.66667%
[5/9]	Time 0.6967s	Training Accuracy: 76.29630%	Test Accuracy: 76.00000%
[6/9]	Time 0.8703s	Training Accuracy: 78.74074%	Test Accuracy: 80.00000%
[7/9]	Time 0.6810s	Training Accuracy: 82.22222%	Test Accuracy: 81.33333%
[8/9]	Time 0.9166s	Training Accuracy: 83.62963%	Test Accuracy: 83.33333%
[9/9]	Time 0.6995s	Training Accuracy: 85.18519%	Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9]	Time 32.0445s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6083s	Training Accuracy: 57.18519%	Test Accuracy: 57.33333%
[3/9]	Time 0.7673s	Training Accuracy: 68.37037%	Test Accuracy: 68.00000%
[4/9]	Time 0.6132s	Training Accuracy: 73.77778%	Test Accuracy: 75.33333%
[5/9]	Time 0.8293s	Training Accuracy: 76.14815%	Test Accuracy: 77.33333%
[6/9]	Time 0.6174s	Training Accuracy: 79.48148%	Test Accuracy: 80.66667%
[7/9]	Time 0.6410s	Training Accuracy: 81.25926%	Test Accuracy: 80.66667%
[8/9]	Time 0.6229s	Training Accuracy: 83.40741%	Test Accuracy: 82.66667%
[9/9]	Time 0.6283s	Training Accuracy: 84.81481%	Test Accuracy: 82.00000%

We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).

julia
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9]	Time 40.7734s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6428s	Training Accuracy: 58.44444%	Test Accuracy: 58.00000%
[3/9]	Time 0.8302s	Training Accuracy: 66.96296%	Test Accuracy: 68.00000%
[4/9]	Time 0.6017s	Training Accuracy: 72.44444%	Test Accuracy: 73.33333%
[5/9]	Time 0.8427s	Training Accuracy: 76.37037%	Test Accuracy: 76.00000%
[6/9]	Time 0.6259s	Training Accuracy: 78.81481%	Test Accuracy: 79.33333%
[7/9]	Time 0.6380s	Training Accuracy: 80.51852%	Test Accuracy: 81.33333%
[8/9]	Time 0.8885s	Training Accuracy: 82.74074%	Test Accuracy: 83.33333%
[9/9]	Time 0.8287s	Training Accuracy: 85.25926%	Test Accuracy: 82.66667%

But remember that some AD backends, like ReverseDiff, are not GPU compatible. For a model of this size, you will notice that training is significantly faster on the CPU than on the GPU.

julia
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
[1/9]	Time 41.4192s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.4049s	Training Accuracy: 58.74074%	Test Accuracy: 56.66667%
[3/9]	Time 0.3715s	Training Accuracy: 69.92593%	Test Accuracy: 71.33333%
[4/9]	Time 0.3737s	Training Accuracy: 72.81481%	Test Accuracy: 74.00000%
[5/9]	Time 0.3694s	Training Accuracy: 76.37037%	Test Accuracy: 78.66667%
[6/9]	Time 0.3691s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 0.3689s	Training Accuracy: 81.62963%	Test Accuracy: 80.66667%
[8/9]	Time 0.3662s	Training Accuracy: 83.33333%	Test Accuracy: 80.00000%
[9/9]	Time 0.3851s	Training Accuracy: 85.40741%	Test Accuracy: 82.00000%

For completeness, let's also test out discrete sensitivities!

julia
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9]	Time 37.9044s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 10.6361s	Training Accuracy: 58.66667%	Test Accuracy: 57.33333%
[3/9]	Time 11.1547s	Training Accuracy: 69.70370%	Test Accuracy: 71.33333%
[4/9]	Time 11.1869s	Training Accuracy: 72.74074%	Test Accuracy: 74.00000%
[5/9]	Time 10.5031s	Training Accuracy: 76.14815%	Test Accuracy: 78.66667%
[6/9]	Time 10.7118s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 10.3844s	Training Accuracy: 81.55556%	Test Accuracy: 80.66667%
[8/9]	Time 10.3411s	Training Accuracy: 83.40741%	Test Accuracy: 80.00000%
[9/9]	Time 11.2435s	Training Accuracy: 85.25926%	Test Accuracy: 81.33333%

Alternate Implementation using Stateful Layer

Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.
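
Here is a minimal standalone sketch of the idea (a toy Dense layer for illustration): the wrapper carries ps and st, so callers no longer thread the state by hand.

julia
d = Dense(2 => 2)
ps, st = Lux.setup(Random.default_rng(), d)
sd = StatefulLuxLayer{true}(d, ps, st)
y = sd(randn(Float32, 2, 3), ps)  # the updated state is kept in sd.st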

julia
struct StatefulNeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
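    # `st_model` is never rebound, so nothing gets boxed in the closure below.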
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end

Train the new Stateful Neural ODE

julia
train(StatefulNeuralODE)
[1/9]	Time 36.7686s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6270s	Training Accuracy: 58.22222%	Test Accuracy: 55.33333%
[3/9]	Time 0.8992s	Training Accuracy: 68.29630%	Test Accuracy: 68.66667%
[4/9]	Time 0.6302s	Training Accuracy: 73.11111%	Test Accuracy: 76.00000%
[5/9]	Time 0.6249s	Training Accuracy: 75.92593%	Test Accuracy: 76.66667%
[6/9]	Time 0.6767s	Training Accuracy: 78.96296%	Test Accuracy: 80.66667%
[7/9]	Time 0.6419s	Training Accuracy: 80.81481%	Test Accuracy: 81.33333%
[8/9]	Time 0.6378s	Training Accuracy: 83.25926%	Test Accuracy: 82.66667%
[9/9]	Time 0.6355s	Training Accuracy: 84.59259%	Test Accuracy: 82.00000%

We might not see a significant difference in training time, but let us investigate the type stability of the layers.

Type Stability

julia
model, ps, st = create_model(NeuralODE)

model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)

x = gpu_device()(ones(Float32, 28, 28, 1, 3));

NeuralODE is not type stable due to the boxing of st.

julia
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└──      return %3

We avoid the problem entirely by using StatefulNeuralODE.

julia
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Note that we still recommend using this layer internally, and not exposing it as the default API to users.

Finally, let's check the compact model.

julia
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

CUDA runtime 12.8, artifact installation
CUDA driver 12.8
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.8.4
- CURAND: 10.3.9
- CUFFT: 11.3.3
- CUSOLVER: 11.7.3
- CUSPARSE: 12.5.8
- CUPTI: 2025.1.1 (API 26.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.7.1
- CUDA_Driver_jll: 0.12.1+1
- CUDA_Runtime_jll: 0.16.1+0

Toolchain:
- Julia: 1.11.4
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.982 GiB / 4.750 GiB available)

This page was generated using Literate.jl.