MNIST Classification using Neural ODEs
To understand Neural ODEs, users should look up these lecture notes. We recommend using DiffEqFlux.jl directly, instead of implementing Neural ODEs from scratch.
Package Imports
using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5,
    Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
CUDA.allowscalar(false)
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST: only use 1500 samples when running in CI, for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false),
    )
end
loadmnist (generic function with 1 method)
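As a quick sanity check, we can look at the shapes the loaders produce. This is a small illustrative sketch (not part of the original tutorial); the batch size and split match the values used in training below.
train_loader, test_loader = loadmnist(128, 0.9)
x, y = first(train_loader)
size(x)  # (28, 28, 1, 128): H×W×C×batch images
size(y)  # (10, 128): one-hot encoded labels for the digits 0:9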
Define the Neural ODE Layer
First, we will use the @compact macro to define the Neural ODE layer.
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
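To see the layer in isolation, here is a minimal, hypothetical usage sketch (the 4-dimensional state and batch of 2 are arbitrary choices, not from the tutorial):
rng = Random.default_rng()
node = NeuralODECompact(Chain(Dense(4 => 4, tanh)))
ps, st = Lux.setup(rng, node)
x0 = randn(rng, Float32, 4, 2)             # 2 initial conditions
sol, _ = node(x0, ComponentArray(ps), st)  # `sol` is an ODESolution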
We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from before @compact was part of the stable API; it also helps users understand how Lux's layer interface works.
The NeuralODE is a ContainerLayer, which stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE
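Since AbstractLuxWrapperLayer{:model} forwards Lux.setup to the wrapped field, the wrapper's parameters and states coincide with those of the inner model, as claimed above. A small sketch (assuming a toy inner model) to verify this:
rng = Random.default_rng()
inner = Chain(Dense(2 => 2, tanh))
Random.seed!(rng, 0)
ps_node, st_node = Lux.setup(rng, NeuralODE(inner))
Random.seed!(rng, 0)
ps_inner, st_inner = Lux.setup(rng, inner)
ps_node == ps_inner  # true: identical structure and values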
OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end
@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
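As a toy illustration of the first method (not from the tutorial), solving a small ODE with save_everystep = false and save_start = false leaves a single saved state, which is then reshaped into an (l, batch) matrix:
prob = ODEProblem((u, p, t) -> -u, ones(Float32, 6), (0.0f0, 1.0f0))
sol = solve(prob, Tsit5(); save_everystep = false, save_start = false)
diffeqsol_to_array(3, sol)  # 3×2 matrix built from the 6-element final state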
Create and Initialize the Neural ODE Layer
function create_model(
        model_fn = NeuralODE; dev = gpu_device(), use_named_tuple::Bool = false,
        sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP())
    )
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep = false, reltol = 1.0f-3,
            abstol = 1.0f-3, save_start = false, sensealg
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10)
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev

    return model, ps, st
end
create_model (generic function with 2 methods)
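Before training, a quick forward-pass sanity check helps confirm the shapes line up. This sketch is illustrative and runs on the CPU so that it works without a GPU:
model, ps, st = create_model(NeuralODE; dev = cpu_device())
x = ones(Float32, 28, 28, 1, 2)  # a dummy batch of 2 images
y, _ = model(x, ps, st)
size(y)  # (10, 2): one logit vector per image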
Define Utility Functions
const logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
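Note that logitcrossentropy expects raw logits (the log-softmax is applied internally), which is why the final Dense(20 => 10) layer has no activation. A small hedged example on dummy data:
ŷ = randn(Float32, 10, 4)           # raw logits for 4 samples
y = onehotbatch(rand(0:9, 4), 0:9)  # one-hot targets
logitcrossentropy(ŷ, y)             # scalar mean loss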
Training
function train(model_function; cpu::Bool = false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return
end
train(NeuralODECompact)
[1/9] Time 145.3108s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.7453s Training Accuracy: 58.22222% Test Accuracy: 57.33333%
[3/9] Time 0.8750s Training Accuracy: 67.85185% Test Accuracy: 70.66667%
[4/9] Time 0.7282s Training Accuracy: 74.29630% Test Accuracy: 74.66667%
[5/9] Time 0.8601s Training Accuracy: 76.29630% Test Accuracy: 76.00000%
[6/9] Time 0.7159s Training Accuracy: 78.74074% Test Accuracy: 80.00000%
[7/9] Time 0.9108s Training Accuracy: 82.22222% Test Accuracy: 81.33333%
[8/9] Time 0.7170s Training Accuracy: 83.62963% Test Accuracy: 83.33333%
[9/9] Time 1.0314s Training Accuracy: 85.18519% Test Accuracy: 82.66667%
train(NeuralODE)
[1/9] Time 32.7108s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.7612s Training Accuracy: 57.18519% Test Accuracy: 57.33333%
[3/9] Time 0.6098s Training Accuracy: 68.37037% Test Accuracy: 68.00000%
[4/9] Time 0.8233s Training Accuracy: 73.77778% Test Accuracy: 75.33333%
[5/9] Time 0.6099s Training Accuracy: 76.14815% Test Accuracy: 77.33333%
[6/9] Time 0.6070s Training Accuracy: 79.48148% Test Accuracy: 80.66667%
[7/9] Time 0.8186s Training Accuracy: 81.25926% Test Accuracy: 80.66667%
[8/9] Time 0.6042s Training Accuracy: 83.40741% Test Accuracy: 82.66667%
[9/9] Time 0.5945s Training Accuracy: 84.81481% Test Accuracy: 82.00000%
We can also change the sensealg and train the model! GaussAdjoint allows you to use any arbitrary parameter structure and not just a flat vector (ComponentArray).
train(NeuralODE; sensealg = GaussAdjoint(; autojacvec = ZygoteVJP()), use_named_tuple = true)
[1/9] Time 41.1292s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.5687s Training Accuracy: 58.44444% Test Accuracy: 58.00000%
[3/9] Time 0.7158s Training Accuracy: 66.96296% Test Accuracy: 68.00000%
[4/9] Time 0.5678s Training Accuracy: 72.44444% Test Accuracy: 73.33333%
[5/9] Time 0.5838s Training Accuracy: 76.37037% Test Accuracy: 76.00000%
[6/9] Time 0.7886s Training Accuracy: 78.81481% Test Accuracy: 79.33333%
[7/9] Time 0.5685s Training Accuracy: 80.51852% Test Accuracy: 81.33333%
[8/9] Time 0.5817s Training Accuracy: 82.74074% Test Accuracy: 83.33333%
[9/9] Time 0.7789s Training Accuracy: 85.25926% Test Accuracy: 82.66667%
But remember, some AD backends like ReverseDiff are not GPU compatible. For a model of this size, you will notice that training is significantly faster on the CPU than on the GPU.
train(NeuralODE; sensealg = InterpolatingAdjoint(; autojacvec = ReverseDiffVJP()), cpu = true)
[1/9] Time 40.3517s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.4141s Training Accuracy: 58.74074% Test Accuracy: 56.66667%
[3/9] Time 0.3689s Training Accuracy: 69.92593% Test Accuracy: 71.33333%
[4/9] Time 0.3665s Training Accuracy: 72.81481% Test Accuracy: 74.00000%
[5/9] Time 0.3616s Training Accuracy: 76.37037% Test Accuracy: 78.66667%
[6/9] Time 0.3568s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 0.3605s Training Accuracy: 81.62963% Test Accuracy: 80.66667%
[8/9] Time 0.3617s Training Accuracy: 83.33333% Test Accuracy: 80.00000%
[9/9] Time 0.3760s Training Accuracy: 85.40741% Test Accuracy: 82.00000%
For completeness, let's also test out discrete sensitivities!
train(NeuralODE; sensealg = ReverseDiffAdjoint(), cpu = true)
[1/9] Time 38.4994s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 11.8771s Training Accuracy: 58.66667% Test Accuracy: 57.33333%
[3/9] Time 10.8679s Training Accuracy: 69.70370% Test Accuracy: 71.33333%
[4/9] Time 10.8193s Training Accuracy: 72.74074% Test Accuracy: 74.00000%
[5/9] Time 10.7418s Training Accuracy: 76.14815% Test Accuracy: 78.66667%
[6/9] Time 10.8277s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 10.6072s Training Accuracy: 81.55556% Test Accuracy: 80.66667%
[8/9] Time 10.7403s Training Accuracy: 83.40741% Test Accuracy: 80.00000%
[9/9] Time 10.6355s Training Accuracy: 85.25926% Test Accuracy: 81.33333%
Alternate Implementation using Stateful Layer
Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the Boxing of st. Using the @compact API avoids this problem entirely.
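To make the boxing issue concrete, here is a minimal sketch (not from the tutorial) of the pattern that triggers it: reassigning st inside a closure forces Julia to wrap the variable in a Core.Box, which defeats type inference.
function forward_boxed(model, x, ps, st)
    function dudt(u, p, t)
        u_, st = model(u, p, st)  # `st` is reassigned inside the closure -> Core.Box
        return u_
    end
    return dudt, st  # the captured, re-bound `st` is now boxed
end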
struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
        model::Lux.AbstractLuxLayer; solver = Tsit5(), tspan = (0.0f0, 1.0f0), kwargs...
    )
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end
Train the new Stateful Neural ODE
train(StatefulNeuralODE)
[1/9] Time 37.0941s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.5524s Training Accuracy: 58.22222% Test Accuracy: 55.33333%
[3/9] Time 0.7739s Training Accuracy: 68.29630% Test Accuracy: 68.66667%
[4/9] Time 0.5580s Training Accuracy: 73.11111% Test Accuracy: 76.00000%
[5/9] Time 0.5594s Training Accuracy: 75.92593% Test Accuracy: 76.66667%
[6/9] Time 0.5750s Training Accuracy: 78.96296% Test Accuracy: 80.66667%
[7/9] Time 0.8628s Training Accuracy: 80.81481% Test Accuracy: 81.33333%
[8/9] Time 0.5453s Training Accuracy: 83.25926% Test Accuracy: 82.66667%
[9/9] Time 0.5475s Training Accuracy: 84.59259% Test Accuracy: 82.00000%
We might not see a significant difference in training time, but let us investigate the type stability of the layers.
Type Stability
model, ps, st = create_model(NeuralODE)
model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)
x = gpu_device()(ones(Float32, 28, 28, 1, 3));
NeuralODE is not type stable due to the boxing of st
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└── return %3
We avoid the problem entirely by using StatefulNeuralODE
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
Note that we still recommend using this layer internally and not exposing it as the default API to users.
Finally, let's check the compact model
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/src/layers/containers.jl:513
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
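Beyond eyeballing @code_warntype, inference can be asserted programmatically with Test.@inferred, which throws if the return type is not inferred concretely. A hedged sketch:
using Test
@inferred model_stateful(x, ps_stateful, st_stateful)  # passes: concrete return type
@inferred model_compact(x, ps_compact, st_compact)     # passes as well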
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.6.1
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.3
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.982 GiB / 4.750 GiB available)
This page was generated using Literate.jl.