MNIST Classification using Neural ODEs

To understand Neural ODEs, users should look up these lecture notes. We recommend that users directly use DiffEqFlux.jl instead of implementing Neural ODEs from scratch.

Package Imports

julia
using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5,
      Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs

# Error on (slow) scalar indexing of GPU arrays instead of silently falling back to it
CUDA.allowscalar(false)

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 images when running in CI, for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)
    )
end
loadmnist (generic function with 1 method)
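
As a quick sanity check (illustrative only, not part of the training pipeline below), we can fetch one minibatch from the train loader and confirm the data layout:

julia
train_loader, test_loader = loadmnist(128, 0.9)
x, y = first(train_loader)
size(x)  # (28, 28, 1, 128): (height, width, channels, batchsize)
size(y)  # (10, 128): one-hot encoded digits 0:9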

Define the Neural ODE Layer

First, we will use the @compact macro to define the Neural ODE layer.

julia
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
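
Before wiring it into a larger Chain, the layer can be exercised on its own. Here is a minimal sketch, where the 4-dimensional Dense network and the constant input are made up purely for illustration:

julia
# Hypothetical standalone usage of NeuralODECompact
node = NeuralODECompact(Dense(4 => 4, tanh); save_everystep=false, save_start=false)
ps, st = Lux.setup(Random.default_rng(), node)
sol, st_new = node(ones(Float32, 4, 2), ps, st)  # sol is an ODESolution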

We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from before @compact was part of the stable API. It also helps users understand how the Lux layer interface works.

The NeuralODE is a ContainerLayer that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.

julia
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE

OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.

julia
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
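
To illustrate what these helpers do, here is a toy sketch (the linear ODE is invented for illustration; it reuses the imports above). With save_everystep=false and save_start=false, the solution stores only the terminal state, which diffeqsol_to_array reshapes back into a (features, batchsize) matrix:

julia
# A 60-dimensional linear ODE standing in for a flattened (20, 3) batch
prob = ODEProblem((u, p, t) -> -u, ones(Float32, 60), (0.0f0, 1.0f0))
sol = solve(prob, Tsit5(); save_everystep=false, save_start=false)
size(diffeqsol_to_array(20, sol))  # (20, 3)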

Create and Initialize the Neural ODE Layer

julia
function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false,
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
    # Construct the Neural ODE Model
    model = Chain(FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false, reltol=1.0f-3,
            abstol=1.0f-3, save_start=false, sensealg),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10))

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev

    return model, ps, st
end
create_model (generic function with 2 methods)
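
As a sanity check (illustrative; passing dev=cpu_device() keeps it runnable without a GPU), we can build the model and count its trainable parameters:

julia
model, ps, st = create_model(NeuralODE; dev=cpu_device())
Lux.parameterlength(ps)  # 16450, matching the ComponentArray axes shown later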

Define Utility Functions

julia
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
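
The comparison above works because onecold is the inverse of onehotbatch: without an explicit label set it returns the index of the largest entry in each column, and since targets and predictions are decoded the same way, the indices line up. A tiny illustrative sketch:

julia
y = onehotbatch([3, 7], 0:9)
onecold(y)        # [4, 8]: 1-based indices into the label range
onecold(y, 0:9)   # [3, 7]: the original labels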

Training

julia
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate)
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
end

train(NeuralODECompact)
[1/9]	Time 141.1364s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6655s	Training Accuracy: 58.22222%	Test Accuracy: 57.33333%
[3/9]	Time 0.7698s	Training Accuracy: 67.85185%	Test Accuracy: 70.66667%
[4/9]	Time 0.5835s	Training Accuracy: 74.29630%	Test Accuracy: 74.66667%
[5/9]	Time 0.7761s	Training Accuracy: 76.29630%	Test Accuracy: 76.00000%
[6/9]	Time 0.5822s	Training Accuracy: 78.74074%	Test Accuracy: 80.00000%
[7/9]	Time 0.5991s	Training Accuracy: 82.22222%	Test Accuracy: 81.33333%
[8/9]	Time 0.5885s	Training Accuracy: 83.62963%	Test Accuracy: 83.33333%
[9/9]	Time 0.5881s	Training Accuracy: 85.18519%	Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9]	Time 33.1417s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.7340s	Training Accuracy: 57.18519%	Test Accuracy: 57.33333%
[3/9]	Time 0.6041s	Training Accuracy: 68.37037%	Test Accuracy: 68.00000%
[4/9]	Time 0.7873s	Training Accuracy: 73.77778%	Test Accuracy: 75.33333%
[5/9]	Time 0.5895s	Training Accuracy: 76.14815%	Test Accuracy: 77.33333%
[6/9]	Time 0.8318s	Training Accuracy: 79.48148%	Test Accuracy: 80.66667%
[7/9]	Time 0.5976s	Training Accuracy: 81.25926%	Test Accuracy: 80.66667%
[8/9]	Time 0.5905s	Training Accuracy: 83.40741%	Test Accuracy: 82.66667%
[9/9]	Time 0.8935s	Training Accuracy: 84.81481%	Test Accuracy: 82.00000%

We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).

julia
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9]	Time 41.0506s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5771s	Training Accuracy: 58.44444%	Test Accuracy: 58.00000%
[3/9]	Time 0.5568s	Training Accuracy: 66.96296%	Test Accuracy: 68.00000%
[4/9]	Time 0.5534s	Training Accuracy: 72.44444%	Test Accuracy: 73.33333%
[5/9]	Time 0.7472s	Training Accuracy: 76.37037%	Test Accuracy: 76.00000%
[6/9]	Time 0.5562s	Training Accuracy: 78.81481%	Test Accuracy: 79.33333%
[7/9]	Time 0.5533s	Training Accuracy: 80.51852%	Test Accuracy: 81.33333%
[8/9]	Time 0.7875s	Training Accuracy: 82.74074%	Test Accuracy: 83.33333%
[9/9]	Time 0.5482s	Training Accuracy: 85.25926%	Test Accuracy: 82.66667%

But remember, some AD backends like ReverseDiff are not GPU compatible, so the following runs use the CPU. For a model of this size, you will notice that training is significantly slower on the CPU than on the GPU.

julia
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
[1/9]	Time 99.8126s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 15.9640s	Training Accuracy: 58.74074%	Test Accuracy: 56.66667%
[3/9]	Time 15.5603s	Training Accuracy: 69.92593%	Test Accuracy: 71.33333%
[4/9]	Time 12.7875s	Training Accuracy: 72.81481%	Test Accuracy: 74.00000%
[5/9]	Time 18.3293s	Training Accuracy: 76.37037%	Test Accuracy: 78.66667%
[6/9]	Time 18.2118s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 13.6867s	Training Accuracy: 81.62963%	Test Accuracy: 80.66667%
[8/9]	Time 15.2422s	Training Accuracy: 83.33333%	Test Accuracy: 80.00000%
[9/9]	Time 16.6457s	Training Accuracy: 85.40741%	Test Accuracy: 82.00000%

For completeness, let's also test out discrete sensitivities!

julia
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9]	Time 54.5690s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 28.3915s	Training Accuracy: 58.66667%	Test Accuracy: 57.33333%
[3/9]	Time 26.6205s	Training Accuracy: 69.70370%	Test Accuracy: 71.33333%
[4/9]	Time 25.1649s	Training Accuracy: 72.74074%	Test Accuracy: 74.00000%
[5/9]	Time 22.8727s	Training Accuracy: 76.14815%	Test Accuracy: 78.66667%
[6/9]	Time 25.6663s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 23.8760s	Training Accuracy: 81.55556%	Test Accuracy: 80.66667%
[8/9]	Time 26.0145s	Training Accuracy: 83.40741%	Test Accuracy: 80.00000%
[9/9]	Time 20.3738s	Training Accuracy: 85.25926%	Test Accuracy: 81.33333%

Alternate Implementation using Stateful Layer

Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the boxing of st: the state is carried in a field of the wrapper instead of being rebound inside the dudt closure. Using the @compact API avoids this problem entirely.

julia
struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end

Train the new Stateful Neural ODE

julia
train(StatefulNeuralODE)
[1/9]	Time 36.2557s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5595s	Training Accuracy: 58.22222%	Test Accuracy: 55.33333%
[3/9]	Time 0.7871s	Training Accuracy: 68.29630%	Test Accuracy: 68.66667%
[4/9]	Time 0.5487s	Training Accuracy: 73.11111%	Test Accuracy: 76.00000%
[5/9]	Time 0.5565s	Training Accuracy: 75.92593%	Test Accuracy: 76.66667%
[6/9]	Time 0.8646s	Training Accuracy: 78.96296%	Test Accuracy: 80.66667%
[7/9]	Time 0.5478s	Training Accuracy: 80.81481%	Test Accuracy: 81.33333%
[8/9]	Time 0.5532s	Training Accuracy: 83.25926%	Test Accuracy: 82.66667%
[9/9]	Time 0.5573s	Training Accuracy: 84.59259%	Test Accuracy: 82.00000%

We might not see a significant difference in training time, but let us investigate the type stability of the layers.

Type Stability

julia
model, ps, st = create_model(NeuralODE)

model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)

x = gpu_device()(ones(Float32, 28, 28, 1, 3));

NeuralODE is not type stable due to the boxing of st. Note the ANY in the inferred return type below:

julia
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:480
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└──      return %3

We avoid the problem entirely by using StatefulNeuralODE:

julia
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:480
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Note that we still recommend using this layer internally and not exposing it as the default API to users.

Finally, let's check the compact model:

julia
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:480
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.6.1
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.3
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.014 GiB / 4.750 GiB available)

This page was generated using Literate.jl.