
MNIST Classification using Neural ODEs

To understand Neural ODEs, users should look up these lecture notes. We recommend that users directly use DiffEqFlux.jl instead of implementing Neural ODEs from scratch.
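For instance, here is a minimal sketch of what that looks like with DiffEqFlux.jl. This snippet is purely illustrative and is not used in the rest of this tutorial; the constructor and call signatures are assumptions based on DiffEqFlux's Lux-based API.

julia
using DiffEqFlux, Lux, OrdinaryDiffEqTsit5, Random

# Assumed constructor: NeuralODE(model, tspan, solver; solve kwargs...)
node = NeuralODE(Chain(Dense(2 => 16, tanh), Dense(16 => 2)),
    (0.0f0, 1.0f0), Tsit5(); save_everystep=false)
ps, st = Lux.setup(Random.default_rng(), node)
sol, st_ = node(rand(Float32, 2, 8), ps, st)  # an ODE solution plus the updated state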

Package Imports

julia
using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5,
      Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs

CUDA.allowscalar(false)

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 samples when running on CI, for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)
    )
end
loadmnist (generic function with 1 method)
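
A quick sanity check of the loader (this downloads MNIST on first use; the expected shapes below are inferred from the reshaping above):

julia
train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x), size(y)  # expect ((28, 28, 1, 128), (10, 128))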

Define the Neural ODE Layer

First, we will use the @compact macro to define the Neural ODE layer.

julia
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)

We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API; it also helps users understand how the Lux layer interface works.

The NeuralODE is a ContainerLayer, which stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.

julia
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE
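
As a quick check of the wrapper semantics: since NeuralODE is an AbstractLuxWrapperLayer{:model}, Lux.setup forwards to the wrapped model, so the parameter structure is exactly that of the inner Chain. A small illustrative sketch:

julia
core = Chain(Dense(20 => 10, tanh), Dense(10 => 20, tanh))
node = NeuralODE(core)

ps_node, _ = Lux.setup(Random.Xoshiro(0), node)
ps_core, _ = Lux.setup(Random.Xoshiro(0), core)
@assert keys(ps_node) == keys(ps_core)  # (:layer_1, :layer_2)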

OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
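
The flatten/restore round-trip is lossless, which is all the right-hand side below relies on:

julia
x = rand(Float32, 28, 28, 1, 3)
u = vec(x)                        # 2352-element state vector for the solver
@assert reshape(u, size(x)) == x  # restores the original (H, W, C, B) shape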

julia
function (n::NeuralODE)(x, ps, st)
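    # NOTE: reassigning `st` inside this closure boxes it, which makes the layer
    # type-unstable (see the Type Stability section below).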
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
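
In create_model below, Base.Fix1(diffeqsol_to_array, 20) partially applies the first argument, i.e. it is the callable sol -> diffeqsol_to_array(20, sol). A small illustration on a plain matrix:

julia
f = Base.Fix1(diffeqsol_to_array, 20)
A = rand(Float32, 20, 5)       # e.g. a (state × saved time points) matrix
@assert f(A) == A[:, end:end]  # keeps only the final column, reshaped to 20×1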

Create and Initialize the Neural ODE Layer

julia
function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false,
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
    # Construct the Neural ODE Model
    model = Chain(FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false, reltol=1.0f-3,
            abstol=1.0f-3, save_start=false, sensealg),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10))

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev

    return model, ps, st
end
create_model (generic function with 2 methods)

Define Utility Functions

julia
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
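
Note the index bookkeeping in accuracy: with no explicit label set, onecold returns argmax indices in 1:10 for both the one-hot targets and the raw logit matrix, so the two sides of the comparison are consistent:

julia
y = onehotbatch([0, 3], 0:9)  # 10×2 one-hot matrix
@assert onecold(y) == [1, 4]  # indices into 1:10, not the original labels
logits = zeros(Float32, 10, 2); logits[1, 1] = 1; logits[4, 2] = 1
@assert onecold(logits) == [1, 4]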

Training

julia
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate)
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
end

train(NeuralODECompact)
[1/9]	Time 141.4740s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5663s	Training Accuracy: 58.22222%	Test Accuracy: 57.33333%
[3/9]	Time 0.5114s	Training Accuracy: 67.85185%	Test Accuracy: 70.66667%
[4/9]	Time 0.6341s	Training Accuracy: 74.29630%	Test Accuracy: 74.66667%
[5/9]	Time 0.4868s	Training Accuracy: 76.29630%	Test Accuracy: 76.00000%
[6/9]	Time 0.5296s	Training Accuracy: 78.74074%	Test Accuracy: 80.00000%
[7/9]	Time 0.5241s	Training Accuracy: 82.22222%	Test Accuracy: 81.33333%
[8/9]	Time 0.5194s	Training Accuracy: 83.62963%	Test Accuracy: 83.33333%
[9/9]	Time 0.5298s	Training Accuracy: 85.18519%	Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9]	Time 36.3912s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5350s	Training Accuracy: 57.18519%	Test Accuracy: 57.33333%
[3/9]	Time 0.5248s	Training Accuracy: 68.37037%	Test Accuracy: 68.00000%
[4/9]	Time 0.5138s	Training Accuracy: 73.77778%	Test Accuracy: 75.33333%
[5/9]	Time 0.7870s	Training Accuracy: 76.14815%	Test Accuracy: 77.33333%
[6/9]	Time 0.5326s	Training Accuracy: 79.48148%	Test Accuracy: 80.66667%
[7/9]	Time 0.5433s	Training Accuracy: 81.25926%	Test Accuracy: 80.66667%
[8/9]	Time 0.5120s	Training Accuracy: 83.40741%	Test Accuracy: 82.66667%
[9/9]	Time 0.5301s	Training Accuracy: 84.81481%	Test Accuracy: 82.00000%

We can also change the sensealg and train the model! GaussAdjoint allows the use of arbitrary parameter structures, not just a flat vector (ComponentArray).

julia
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9]	Time 46.2201s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5961s	Training Accuracy: 58.44444%	Test Accuracy: 58.00000%
[3/9]	Time 0.4625s	Training Accuracy: 66.96296%	Test Accuracy: 68.00000%
[4/9]	Time 0.4556s	Training Accuracy: 72.44444%	Test Accuracy: 73.33333%
[5/9]	Time 0.4694s	Training Accuracy: 76.37037%	Test Accuracy: 76.00000%
[6/9]	Time 0.4636s	Training Accuracy: 78.81481%	Test Accuracy: 79.33333%
[7/9]	Time 0.4857s	Training Accuracy: 80.51852%	Test Accuracy: 81.33333%
[8/9]	Time 0.4957s	Training Accuracy: 82.74074%	Test Accuracy: 83.33333%
[9/9]	Time 0.4902s	Training Accuracy: 85.25926%	Test Accuracy: 82.66667%

But remember, some AD backends like ReverseDiff are not GPU compatible, so we train on the CPU here. For a model this size, you will notice that the per-epoch training times below are significantly higher than the GPU timings above.

julia
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
[1/9]	Time 104.0135s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 10.2240s	Training Accuracy: 58.74074%	Test Accuracy: 56.66667%
[3/9]	Time 12.6256s	Training Accuracy: 69.92593%	Test Accuracy: 71.33333%
[4/9]	Time 15.6875s	Training Accuracy: 72.81481%	Test Accuracy: 74.00000%
[5/9]	Time 13.7557s	Training Accuracy: 76.37037%	Test Accuracy: 78.66667%
[6/9]	Time 9.5143s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 11.7677s	Training Accuracy: 81.62963%	Test Accuracy: 80.66667%
[8/9]	Time 18.2477s	Training Accuracy: 83.33333%	Test Accuracy: 80.00000%
[9/9]	Time 15.0654s	Training Accuracy: 85.40741%	Test Accuracy: 82.00000%

For completeness, let's also test out discrete sensitivities!

julia
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9]	Time 57.1033s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 26.5368s	Training Accuracy: 58.66667%	Test Accuracy: 57.33333%
[3/9]	Time 27.5508s	Training Accuracy: 69.70370%	Test Accuracy: 71.33333%
[4/9]	Time 27.9761s	Training Accuracy: 72.74074%	Test Accuracy: 74.00000%
[5/9]	Time 28.6880s	Training Accuracy: 76.14815%	Test Accuracy: 78.66667%
[6/9]	Time 27.0568s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 27.1087s	Training Accuracy: 81.55556%	Test Accuracy: 80.66667%
[8/9]	Time 27.9895s	Training Accuracy: 83.40741%	Test Accuracy: 80.00000%
[9/9]	Time 28.5690s	Training Accuracy: 85.25926%	Test Accuracy: 81.33333%

Alternate Implementation using Stateful Layer

Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.
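
Here is a minimal sketch of how StatefulLuxLayer behaves, using a plain Dense layer for illustration: the wrapper closes over ps and st, and an explicitly passed parameter vector overrides the stored one, which is exactly what dudt below relies on.

julia
dense = Dense(2 => 2, tanh)
ps, st = Lux.setup(Random.default_rng(), dense)
sdense = StatefulLuxLayer{true}(dense, ps, st)
u = rand(Float32, 2, 4)
@assert sdense(u) == sdense(u, ps)  # stored vs. explicitly passed parameters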

julia
struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
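    # Wrapping the model in a StatefulLuxLayer avoids reassigning `st` inside a
    # closure, so nothing gets boxed here.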
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end

Train the new Stateful Neural ODE

julia
train(StatefulNeuralODE)
[1/9]	Time 40.1499s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.4769s	Training Accuracy: 58.22222%	Test Accuracy: 55.33333%
[3/9]	Time 0.6573s	Training Accuracy: 68.29630%	Test Accuracy: 68.66667%
[4/9]	Time 0.4516s	Training Accuracy: 73.11111%	Test Accuracy: 76.00000%
[5/9]	Time 0.4537s	Training Accuracy: 75.92593%	Test Accuracy: 76.66667%
[6/9]	Time 0.4658s	Training Accuracy: 78.96296%	Test Accuracy: 80.66667%
[7/9]	Time 0.7835s	Training Accuracy: 80.81481%	Test Accuracy: 81.33333%
[8/9]	Time 0.4535s	Training Accuracy: 83.25926%	Test Accuracy: 82.66667%
[9/9]	Time 0.4565s	Training Accuracy: 84.59259%	Test Accuracy: 82.00000%

We might not see a significant difference in training time, but let us investigate the type stability of the layers.

Type Stability

julia
model, ps, st = create_model(NeuralODE)

model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)

x = gpu_device()(ones(Float32, 28, 28, 1, 3));

NeuralODE is not type stable due to the boxing of st:

julia
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = 15681:15700)), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 201:220)))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/src/layers/containers.jl:480
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = 15681:15700)), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 201:220)))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└──      return %3

We avoid the problem entirely by using StatefulNeuralODE:

julia
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = 15681:15700)), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 201:220)))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/src/layers/containers.jl:480
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = 15681:15700)), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 201:220)))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Note that we still recommend using this layer internally and not exposing it as the default API to users.

Finally, let's check the compact model:

julia
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = 15681:15700)), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 201:220)))),)), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/src/layers/containers.jl:480
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = 15681:15700)), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = 201:220)))),)), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)))}}}
  st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.795 GiB / 4.750 GiB available)

This page was generated using Literate.jl.