
MNIST Classification using Neural ODEs

To learn more about Neural ODEs, see these lecture notes. We recommend using DiffEqFlux.jl directly instead of implementing Neural ODEs from scratch.
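As a short reminder of what the layers below compute: a Neural ODE treats the output of a neural network f_θ as the time derivative of a hidden state u. It returns that state at the end of the integration interval tspan = (t0, t1). In this tutorial, f_θ is a small Chain of Dense layers and tspan = (0, 1). The dynamics do not use t explicitly, and the ODE solver (Tsit5) approximates the integral numerically.

```latex
\frac{du}{dt} = f_\theta\big(u(t), t\big), \qquad u(t_0) = x, \qquad
\mathrm{NeuralODE}_\theta(x) = u(t_1) = x + \int_{t_0}^{t_1} f_\theta\big(u(t), t\big)\,dt
```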

Package Imports

julia
using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5,
    Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs

CUDA.allowscalar(false)

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST: only the first 1500 samples when running on CI (for demonstration purposes)
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false),
    )
end
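You can try the preprocessing in loadmnist without downloading MNIST. The sketch below replaces the dataset with a hypothetical batch of 20 blank 28x28 images. The names imgs and labels_raw are placeholders, not real data. It then applies the same reshape, onehotbatch, splitobs and DataLoader steps, assuming current MLUtils.jl and OneHotArrays.jl behavior.

```julia
using MLUtils: DataLoader, splitobs
using OneHotArrays: onehotbatch, onecold

# Hypothetical stand-in for MNIST: 20 blank 28x28 images, labels 0..9 repeated twice
imgs = zeros(Float32, 28, 28, 20)
labels_raw = repeat(0:9, 2)

# Same processing as loadmnist: insert a singleton channel dimension -> (H, W, C, N)
x_data = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))
y_data = onehotbatch(labels_raw, 0:9)   # 10 x 20 one-hot matrix

# Split along the last (observation) dimension: 90% train, 10% test
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = 0.9)

train_loader = DataLoader(collect.((x_train, y_train)); batchsize = 4, shuffle = true)
test_loader = DataLoader(collect.((x_test, y_test)); batchsize = 4, shuffle = false)

xb, yb = first(train_loader)
println("image batch: ", size(xb), ", label batch: ", size(yb))
```

The split leaves 18 training samples. With batchsize = 4, the last training batch is therefore partial (2 samples), because DataLoader keeps incomplete batches by default.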
loadmnist (generic function with 1 method)

Define the Neural ODE Layer

First we will use the @compact macro to define the Neural ODE Layer.

julia
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)

We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not yet part of the stable API. It also helps users understand how the Lux layer interface works.

NeuralODE is a wrapper layer (it subtypes Lux.AbstractLuxWrapperLayer{:model}) that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.

julia
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE
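Here is what "the parameters and states are the same as those of the underlying model" means in practice. The minimal sketch below defines two hypothetical layers, MyWrapper and MyContainer (illustrative names, not part of Lux). MyWrapper subtypes Lux.AbstractLuxWrapperLayer{:model}, like NeuralODE does. MyContainer subtypes Lux.AbstractLuxContainerLayer{(:model,)} for comparison.

```julia
using Lux, Random

# Hypothetical wrapper: parameters/states are exactly those of the wrapped `model`
struct MyWrapper{M <: Lux.AbstractLuxLayer} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
end
(w::MyWrapper)(x, ps, st) = w.model(x, ps, st)

# Hypothetical container: parameters/states are nested under the field name `model`
struct MyContainer{M <: Lux.AbstractLuxLayer} <: Lux.AbstractLuxContainerLayer{(:model,)}
    model::M
end
function (c::MyContainer)(x, ps, st)
    y, st_model = c.model(x, ps.model, st.model)
    return y, (; model = st_model)
end

inner = Dense(3 => 2)
# Fresh RNGs with the same seed, so all three setups draw identical weights
ps_inner, _ = Lux.setup(Xoshiro(0), inner)
ps_wrap, st_wrap = Lux.setup(Xoshiro(0), MyWrapper(inner))
ps_cont, _ = Lux.setup(Xoshiro(0), MyContainer(inner))

println(keys(ps_wrap))   # same keys as the inner Dense layer
println(keys(ps_cont))   # one extra level, keyed by :model

y, _ = MyWrapper(inner)(ones(Float32, 3, 5), ps_wrap, st_wrap)
```

The ComponentArray axes in the @code_warntype output at the end of this tutorial show the same behavior. The parameters of layer_3 (the NeuralODE) are the inner Chain's layer_1 to layer_3, with no extra model level.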

OrdinaryDiffEq.jl can handle non-Vector inputs. However, certain discrete sensitivities, such as ReverseDiffAdjoint, can't handle non-Vector inputs. Hence, we flatten the initial state and the output of dudt into Vectors, and reshape them back to the original array shape where needed.

julia
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
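You can check the flatten/unflatten round trip used by NeuralODE with plain arrays. This sketch needs no packages. It mimics three steps: vec(x) for the initial state, reshape(u, size(x)) inside dudt, and the AbstractMatrix method of diffeqsol_to_array. Here traj is a hypothetical stand-in for a saved solution, with one column per saved time point.

```julia
# A batch of 3 samples with 4 features each (one column per sample)
x = reshape(Float32.(1:12), 4, 3)

# Flat state vector, as passed to ODEProblem via vec(x)
u0 = vec(x)

# Inside dudt, the flat state is reshaped back before calling the model
x_again = reshape(u0, size(x))

# Hypothetical saved trajectory: column 1 = some earlier state, column 2 = final state
traj = hcat(zeros(Float32, length(u0)), u0)

# Same logic as diffeqsol_to_array(4, traj): take the last column, restore (features, batch)
final_state = reshape(traj[:, end], (4, :))
println(size(final_state))
```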

Create and Initialize the Neural ODE Layer

julia
function create_model(
        model_fn = NeuralODE; dev = gpu_device(), use_named_tuple::Bool = false,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    )
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep = false, reltol = 1.0f-3,
            abstol = 1.0f-3, save_start = false, sensealg
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10)
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev

    return model, ps, st
end
create_model (generic function with 2 methods)
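As a sanity check on the architecture built in create_model, you can work out the parameter count by hand. A Dense(in => out) layer with a bias has out * in + out parameters. FlattenLayer and the diffeqsol_to_array wrapper have none. The sketch below (plain Julia; dense_params is a hypothetical helper) adds these up.

```julia
# Parameters of Dense(in => out) with bias: weight (out x in) plus bias (out)
dense_params(in, out) = out * in + out

layers = [
    (784, 20),  # Dense(784 => 20, tanh) after FlattenLayer (28 * 28 = 784)
    (20, 10),   # Neural ODE dynamics: Dense(20 => 10, tanh)
    (10, 10),   #                      Dense(10 => 10, tanh)
    (10, 20),   #                      Dense(10 => 20, tanh)
    (20, 10),   # classifier head: Dense(20 => 10)
]

counts = [dense_params(i, o) for (i, o) in layers]
total = sum(counts)
println(counts, "  total = ", total)
```

The total of 16450 matches the last index of the ComponentArray axes (layer_5 = ViewAxis(16241:16450, ...)) in the @code_warntype output later in this tutorial.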

Define Utility Functions

julia
const logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
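The accuracy helper compares argmax indices via onecold. logitcrossentropy applies a softmax internally because logits = Val(true). Below is a small check on hypothetical logits for 3 classes and 3 samples. It compares a hand-written log-softmax cross-entropy with Lux's CrossEntropyLoss, assuming the loss uses its default mean aggregation over samples.

```julia
using Lux, OneHotArrays, Statistics

# Hypothetical logits: 3 classes (rows) x 3 samples (columns)
logits = Float32[2 0 0;
                 1 3 0;
                 0 1 4]
y = onehotbatch([1, 2, 2], 1:3)   # true classes of the 3 samples

# Accuracy as in `accuracy`: compare predicted and target class indices per column
acc = sum(onecold(y) .== onecold(logits)) / size(y, 2)

# Cross-entropy on logits by hand: column-wise log-softmax, then mean over samples
logsoftmax_cols = logits .- log.(sum(exp.(logits); dims = 1))
manual = mean(-sum(Float32.(y) .* logsoftmax_cols; dims = 1))

# The same quantity through Lux's loss object, as defined for logitcrossentropy
lux_loss = CrossEntropyLoss(; logits = Val(true))(logits, Float32.(y))
println("accuracy = ", acc, ", manual = ", manual, ", Lux = ", lux_loss)
```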

Training

julia
function train(model_function; cpu::Bool = false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return
end

train(NeuralODECompact)
[1/9]	Time 145.3108s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.7453s	Training Accuracy: 58.22222%	Test Accuracy: 57.33333%
[3/9]	Time 0.8750s	Training Accuracy: 67.85185%	Test Accuracy: 70.66667%
[4/9]	Time 0.7282s	Training Accuracy: 74.29630%	Test Accuracy: 74.66667%
[5/9]	Time 0.8601s	Training Accuracy: 76.29630%	Test Accuracy: 76.00000%
[6/9]	Time 0.7159s	Training Accuracy: 78.74074%	Test Accuracy: 80.00000%
[7/9]	Time 0.9108s	Training Accuracy: 82.22222%	Test Accuracy: 81.33333%
[8/9]	Time 0.7170s	Training Accuracy: 83.62963%	Test Accuracy: 83.33333%
[9/9]	Time 1.0314s	Training Accuracy: 85.18519%	Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9]	Time 32.7108s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.7612s	Training Accuracy: 57.18519%	Test Accuracy: 57.33333%
[3/9]	Time 0.6098s	Training Accuracy: 68.37037%	Test Accuracy: 68.00000%
[4/9]	Time 0.8233s	Training Accuracy: 73.77778%	Test Accuracy: 75.33333%
[5/9]	Time 0.6099s	Training Accuracy: 76.14815%	Test Accuracy: 77.33333%
[6/9]	Time 0.6070s	Training Accuracy: 79.48148%	Test Accuracy: 80.66667%
[7/9]	Time 0.8186s	Training Accuracy: 81.25926%	Test Accuracy: 80.66667%
[8/9]	Time 0.6042s	Training Accuracy: 83.40741%	Test Accuracy: 82.66667%
[9/9]	Time 0.5945s	Training Accuracy: 84.81481%	Test Accuracy: 82.00000%
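The training loop relies on Lux's Training API. Training.TrainState bundles the model, parameters, states and optimizer. Training.single_train_step! computes gradients with the chosen AD backend and applies one optimizer update. It returns the gradients, the loss, extra statistics and the new train state. The sketch below uses a hypothetical toy regression problem; fit_toy, the data and the step count are illustrative choices. Unlike train above, which discards the loss, it records the loss at each step.

```julia
using Lux, Optimisers, Zygote, Random

function fit_toy(; nsteps = 500)
    # Hypothetical toy task: learn y = x1 + x2 with a single Dense layer
    rng = Xoshiro(0)
    model = Dense(2 => 1)
    ps, st = Lux.setup(rng, model)
    x = randn(rng, Float32, 2, 64)
    y = sum(x; dims = 1)

    tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
    losses = Float32[]
    for _ in 1:nsteps
        # Returns (gradients, loss, stats, updated train state)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), MSELoss(), (x, y), tstate
        )
        push!(losses, loss)
    end
    return losses
end

losses = fit_toy()
println("first loss = ", first(losses), ", last loss = ", last(losses))
```

The loop lives inside a function, as in train, so that tstate is a local variable. At the top level of a script, reassigning a global inside a for loop would run into Julia's soft-scope rules.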

We can also change the sensealg and train the model. GaussAdjoint works with arbitrary parameter structures (here a NamedTuple, via use_named_tuple = true), not just a flat vector such as a ComponentArray.

julia
train(NeuralODE; sensealg = GaussAdjoint(; autojacvec = ZygoteVJP()), use_named_tuple = true)
[1/9]	Time 41.1292s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5687s	Training Accuracy: 58.44444%	Test Accuracy: 58.00000%
[3/9]	Time 0.7158s	Training Accuracy: 66.96296%	Test Accuracy: 68.00000%
[4/9]	Time 0.5678s	Training Accuracy: 72.44444%	Test Accuracy: 73.33333%
[5/9]	Time 0.5838s	Training Accuracy: 76.37037%	Test Accuracy: 76.00000%
[6/9]	Time 0.7886s	Training Accuracy: 78.81481%	Test Accuracy: 79.33333%
[7/9]	Time 0.5685s	Training Accuracy: 80.51852%	Test Accuracy: 81.33333%
[8/9]	Time 0.5817s	Training Accuracy: 82.74074%	Test Accuracy: 83.33333%
[9/9]	Time 0.7789s	Training Accuracy: 85.25926%	Test Accuracy: 82.66667%
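By default, create_model converts the NamedTuple from Lux.setup into a ComponentArray. A ComponentArray stores all parameters in one contiguous vector while keeping named access. Passing use_named_tuple = true skips that conversion, which is what the GaussAdjoint run above relies on. Here is a small sketch with hypothetical Dense-like parameters, assuming ComponentArrays.jl:

```julia
using ComponentArrays

# Hypothetical parameters of a Dense(2 => 2) layer, as a NamedTuple
ps = (weight = Float32[1 2; 3 4], bias = Float32[5, 6])

# One flat Float32 vector with named, shaped views into it
flat = ComponentArray(ps)

println(length(flat))                          # total number of parameters
println(flat.weight)                           # still a 2 x 2 matrix when accessed by name
println([flat[i] for i in eachindex(flat)])    # underlying storage (weight is column-major)
```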

Remember, however, that some AD backends, such as ReverseDiff, are not GPU compatible. For a model of this size, you will also notice that training on the CPU is faster than on the GPU.

julia
train(NeuralODE; sensealg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP()), cpu = true)
[1/9]	Time 40.3517s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.4141s	Training Accuracy: 58.74074%	Test Accuracy: 56.66667%
[3/9]	Time 0.3689s	Training Accuracy: 69.92593%	Test Accuracy: 71.33333%
[4/9]	Time 0.3665s	Training Accuracy: 72.81481%	Test Accuracy: 74.00000%
[5/9]	Time 0.3616s	Training Accuracy: 76.37037%	Test Accuracy: 78.66667%
[6/9]	Time 0.3568s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 0.3605s	Training Accuracy: 81.62963%	Test Accuracy: 80.66667%
[8/9]	Time 0.3617s	Training Accuracy: 83.33333%	Test Accuracy: 80.00000%
[9/9]	Time 0.3760s	Training Accuracy: 85.40741%	Test Accuracy: 82.00000%

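For completeness, let's also test discrete sensitivities. As the timings below show, ReverseDiffAdjoint is much slower per epoch than the continuous adjoints used above.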

julia
train(NeuralODE; sensealg = ReverseDiffAdjoint(), cpu = true)
[1/9]	Time 38.4994s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 11.8771s	Training Accuracy: 58.66667%	Test Accuracy: 57.33333%
[3/9]	Time 10.8679s	Training Accuracy: 69.70370%	Test Accuracy: 71.33333%
[4/9]	Time 10.8193s	Training Accuracy: 72.74074%	Test Accuracy: 74.00000%
[5/9]	Time 10.7418s	Training Accuracy: 76.14815%	Test Accuracy: 78.66667%
[6/9]	Time 10.8277s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 10.6072s	Training Accuracy: 81.55556%	Test Accuracy: 80.66667%
[8/9]	Time 10.7403s	Training Accuracy: 83.40741%	Test Accuracy: 80.00000%
[9/9]	Time 10.6355s	Training Accuracy: 85.25926%	Test Accuracy: 81.33333%
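To put numbers on the CPU-versus-GPU remark, we can average the per-epoch times printed above. The sketch drops epoch 1, which presumably includes compilation, and uses only values copied from the logs. The GPU runs use ZygoteVJP, while the CPU runs use ReverseDiffVJP or ReverseDiffAdjoint. So this is not a pure device comparison, and timings will vary by machine.

```julia
using Statistics: mean

# Per-epoch times in seconds for epochs 2-9, copied from the training logs above
runs = [
    "GPU, InterpolatingAdjoint + ZygoteVJP" =>
        [0.7612, 0.6098, 0.8233, 0.6099, 0.6070, 0.8186, 0.6042, 0.5945],
    "GPU, GaussAdjoint + ZygoteVJP" =>
        [0.5687, 0.7158, 0.5678, 0.5838, 0.7886, 0.5685, 0.5817, 0.7789],
    "CPU, InterpolatingAdjoint + ReverseDiffVJP" =>
        [0.4141, 0.3689, 0.3665, 0.3616, 0.3568, 0.3605, 0.3617, 0.3760],
    "CPU, ReverseDiffAdjoint" =>
        [11.8771, 10.8679, 10.8193, 10.7418, 10.8277, 10.6072, 10.7403, 10.6355],
]

# Print runs from fastest to slowest mean epoch time
for (name, t) in sort(runs; by = p -> mean(last(p)))
    println(rpad(name, 45), round(mean(t); digits = 3), " s/epoch")
end
```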

Alternate Implementation using Stateful Layer

Starting with v0.5.5, Lux provides a StatefulLuxLayer, which can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.

julia
struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
    Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end
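StatefulNeuralODE relies on StatefulLuxLayer{true}(model, ps, st). This bundles a layer with its parameters and state, so it can be called on the input alone, or on the input plus explicit parameters, as dudt does. Here is a minimal sketch with a plain Dense layer; the layer and input sizes are arbitrary choices.

```julia
using Lux, Random

model = Dense(4 => 4, tanh)
ps, st = Lux.setup(Xoshiro(0), model)
x = ones(Float32, 4, 2)

# Bundle model, parameters and state once
smodel = StatefulLuxLayer{true}(model, ps, st)

y1 = smodel(x)          # uses the stored parameters
y2 = smodel(x, ps)      # explicit parameters, like `dudt(u, p, t) = st_model(u, p)`
y_ref, _ = model(x, ps, st)

println(size(y1), " ", smodel.st)   # the current state is available as smodel.st
```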

Train the new Stateful Neural ODE

julia
train(StatefulNeuralODE)
[1/9]	Time 37.0941s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5524s	Training Accuracy: 58.22222%	Test Accuracy: 55.33333%
[3/9]	Time 0.7739s	Training Accuracy: 68.29630%	Test Accuracy: 68.66667%
[4/9]	Time 0.5580s	Training Accuracy: 73.11111%	Test Accuracy: 76.00000%
[5/9]	Time 0.5594s	Training Accuracy: 75.92593%	Test Accuracy: 76.66667%
[6/9]	Time 0.5750s	Training Accuracy: 78.96296%	Test Accuracy: 80.66667%
[7/9]	Time 0.8628s	Training Accuracy: 80.81481%	Test Accuracy: 81.33333%
[8/9]	Time 0.5453s	Training Accuracy: 83.25926%	Test Accuracy: 82.66667%
[9/9]	Time 0.5475s	Training Accuracy: 84.59259%	Test Accuracy: 82.00000%

We might not see a significant difference in training time, but let us investigate the type stability of the layers.

Type Stability

julia
model, ps, st = create_model(NeuralODE)

model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)

x = gpu_device()(ones(Float32, 28, 28, 1, 3));

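NeuralODE is not type stable due to the boxing of st. The closure dudt reassigns st, so Julia stores it in a Core.Box and cannot infer the state returned by layer_3 (note the ANY in the inferred return type below).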

julia
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└──      return %3
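You can reproduce the boxing without Lux. In the sketch below, make_boxed and make_unboxed are hypothetical functions. The first closure reassigns a captured argument, the same pattern as in NeuralODE's dudt (u_, st = n.model(...)), so Julia stores that variable in a Core.Box whose contents it cannot infer. The second closure only mutates the contents of a Ref, so the captured variable keeps a concrete type. This illustrates the mechanism only, not how StatefulLuxLayer is implemented.

```julia
# Rebinding a captured variable inside a closure forces Julia to box it
function make_boxed(st)
    f(u) = (st = st + u; st)      # `st` is reassigned, so it lives in a Core.Box
    return f
end

# Mutating the contents of a Ref never rebinds the captured variable itself
function make_unboxed(st)
    ref = Ref(st)
    g(u) = (ref[] += u; ref[])    # `ref` is never reassigned, so no Box is needed
    return g
end

fb = make_boxed(0)
fu = make_unboxed(0)

# Closures are structs whose fields are the captured variables
println(fieldtypes(typeof(fb)))   # a Core.Box field: contents are not inferable
println(fieldtypes(typeof(fu)))   # a concretely typed Ref field
```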

We avoid the problem entirely by using StatefulNeuralODE.

julia
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

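In the output above, the Body:: line shows a fully inferred, concrete return type, and st is even inferred as a compile-time constant (Core.Const). Note that we still recommend using this layer internally rather than exposing it as the default API to users.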
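For comparison, it can help to see what a type instability looks like. The following is a minimal, Base-only sketch; the functions `unstable` and `stable` are hypothetical and not part of the tutorial. For `unstable`, `@code_warntype unstable(3)` would report a `Body::Union{Int64, String}` rather than a single concrete type. `Test.@inferred` turns that visual inspection into an automated check: it returns the call's result when the inferred and actual return types match, and throws an error otherwise.

```julia
using Test

# Hypothetical example: the return type depends on the *value* of x, so
# inference can only conclude Union{Int64, String} for an Int argument.
unstable(x::Int) = x > 0 ? x : "non-positive"

# Hypothetical example: both branches return an Int, so inference is concrete.
stable(x::Int) = x > 0 ? x : zero(x)

# Inspect the inferred return types (by argument type) without running the code.
@test only(Base.return_types(unstable, (Int,))) == Union{Int, String}
@test only(Base.return_types(stable, (Int,))) == Int

# @inferred passes the result through when inference matches the actual type...
@test @inferred(stable(3)) == 3
# ...and throws an ErrorException when it does not.
@test_throws ErrorException @inferred(unstable(3))
```

Note that `@inferred` only looks at argument types, not values, so `unstable(3)` fails even though this particular call happens to return an Int.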

Finally, let's check the compact model:

julia
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3
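The same kind of check can be applied to a Lux model so that a regression in type stability fails a test instead of relying on reading @code_warntype output by eye. The sketch below is a stand-in under stated assumptions rather than the tutorial's Neural ODE: it uses a small CPU `Chain` of `Dense` layers (the layer sizes and the names `model`, `rng`, `y`, `st_new` are arbitrary choices) so that it runs without a GPU or an ODE solver.

```julia
using Lux, Random, Test

# Hypothetical stand-in model: a small MLP on the CPU, not the Neural ODE above.
model = Chain(Dense(4 => 8, tanh), Dense(8 => 2))

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 4, 3)  # 4 input features, batch of 3 samples

# @inferred throws if the inferred return type of the forward pass differs
# from the concrete type that is actually returned.
y, st_new = @inferred model(x, ps, st)

@test size(y) == (2, 3)
@test eltype(y) == Float32
```

With the tutorial's own objects, the analogous call would be @inferred model_compact(x, ps_compact, st_compact); the concrete Body:: lines above suggest it should pass, but that has not been verified here.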

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.6.1
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.3
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.982 GiB / 4.750 GiB available)

This page was generated using Literate.jl.