
MNIST Classification using Neural ODEs

To understand Neural ODEs, users should look up these lecture notes. We recommend using DiffEqFlux.jl directly instead of implementing Neural ODEs from scratch.
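
A Neural ODE parameterizes the derivative of a hidden state with a neural network, du/dt = f(u, p, t), and computes its output by solving this ODE over a time span. The following minimal sketch (self-contained, and not part of the original tutorial) illustrates the idea on a toy 2-dimensional state:

julia
using Lux, OrdinaryDiffEqTsit5, Random, ComponentArrays

# A network `f` acts as the ODE right-hand side du/dt = f(u, p, t)
f = Dense(2 => 2, tanh)
ps, st = Lux.setup(Random.default_rng(), f)
dudt(u, p, t) = first(f(u, p, st))  # drop the (empty) layer state

prob = ODEProblem(dudt, ones(Float32, 2), (0.0f0, 1.0f0), ComponentArray(ps))
sol = solve(prob, Tsit5())  # the Neural ODE "output" is the final state, last(sol.u)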

Package Imports

julia
using Lux,
    ComponentArrays,
    SciMLSensitivity,
    LuxCUDA,
    Optimisers,
    OrdinaryDiffEqTsit5,
    Random,
    Statistics,
    Zygote,
    OneHotArrays,
    InteractiveUtils,
    Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs

CUDA.allowscalar(false)

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 samples on CI for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false),
    )
end
loadmnist (generic function with 1 method)
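
As a quick sanity check (a hypothetical snippet, not part of the original tutorial), the train loader yields (H, W, C, B) image batches and one-hot label batches:

julia
train_loader, test_loader = loadmnist(128, 0.9)
x, y = first(train_loader)
size(x), size(y)  # expected: ((28, 28, 1, 128), (10, 128))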

Define the Neural ODE Layer

First, we will use the @compact macro to define the Neural ODE layer.

julia
function NeuralODECompact(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
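
Before wiring this layer into a full model, here is a hypothetical smoke test on a small CPU input (all solver settings left at their defaults):

julia
node = NeuralODECompact(Dense(4 => 4, tanh))
ps, st = Lux.setup(Random.default_rng(), node)
sol, _ = node(ones(Float32, 4, 2), ps, st)
length(last(sol.u))  # flat final state of length 4 * 2 = 8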

We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API. It also helps users understand how the Lux layer interface works.

The NeuralODE is a container layer that stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.

julia
struct NeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE

OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.

julia
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))    # final saved state -> (l, batch)
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :)) # final column -> (l, batch)
diffeqsol_to_array (generic function with 2 methods)
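
To see what these helpers compute, here is a hypothetical shape check for the AbstractMatrix method: given a (features × batch) × timesteps matrix, it takes the last column (the final time point) and reshapes it to (features, batch):

julia
X = rand(Float32, 20 * 3, 5)     # state length 60, saved at 5 time points
size(diffeqsol_to_array(20, X))  # expected: (20, 3)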

Create and Initialize the Neural ODE Layer

julia
function create_model(
    model_fn=NeuralODE;
    dev=gpu_device(),
    use_named_tuple::Bool=false,
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
)
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false,
            reltol=1.0f-3,
            abstol=1.0f-3,
            save_start=false,
            sensealg,
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10),
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = dev((use_named_tuple ? ps : ComponentArray(ps)))
    st = dev(st)

    return model, ps, st
end
create_model (generic function with 2 methods)

Define Utility Functions

julia
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
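
As a spot check (hypothetical, using random logits), the loss can be called directly on predictions and one-hot targets:

julia
ŷ = randn(Float32, 10, 4)           # logits for 4 samples
y = onehotbatch([0, 1, 2, 3], 0:9)  # one-hot targets
logitcrossentropy(ŷ, y)             # returns a scalar loss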

Training

julia
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
                 Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return nothing
end

train(NeuralODECompact)
[1/9]	Time 143.8989s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.7810s	Training Accuracy: 58.22222%	Test Accuracy: 57.33333%
[3/9]	Time 0.7163s	Training Accuracy: 67.85185%	Test Accuracy: 70.66667%
[4/9]	Time 0.6936s	Training Accuracy: 74.29630%	Test Accuracy: 74.66667%
[5/9]	Time 0.7209s	Training Accuracy: 76.29630%	Test Accuracy: 76.00000%
[6/9]	Time 0.7002s	Training Accuracy: 78.74074%	Test Accuracy: 80.00000%
[7/9]	Time 0.6936s	Training Accuracy: 82.22222%	Test Accuracy: 81.33333%
[8/9]	Time 0.9291s	Training Accuracy: 83.62963%	Test Accuracy: 83.33333%
[9/9]	Time 0.6965s	Training Accuracy: 85.18519%	Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9]	Time 35.0952s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5959s	Training Accuracy: 57.18519%	Test Accuracy: 57.33333%
[3/9]	Time 0.6160s	Training Accuracy: 68.37037%	Test Accuracy: 68.00000%
[4/9]	Time 0.7905s	Training Accuracy: 73.77778%	Test Accuracy: 75.33333%
[5/9]	Time 0.5913s	Training Accuracy: 76.14815%	Test Accuracy: 77.33333%
[6/9]	Time 0.6245s	Training Accuracy: 79.48148%	Test Accuracy: 80.66667%
[7/9]	Time 0.8594s	Training Accuracy: 81.25926%	Test Accuracy: 80.66667%
[8/9]	Time 0.5968s	Training Accuracy: 83.40741%	Test Accuracy: 82.66667%
[9/9]	Time 0.5916s	Training Accuracy: 84.81481%	Test Accuracy: 82.00000%

We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).

julia
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9]	Time 40.0036s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.5827s	Training Accuracy: 58.44444%	Test Accuracy: 58.00000%
[3/9]	Time 0.5791s	Training Accuracy: 66.96296%	Test Accuracy: 68.00000%
[4/9]	Time 0.5898s	Training Accuracy: 72.44444%	Test Accuracy: 73.33333%
[5/9]	Time 0.8134s	Training Accuracy: 76.37037%	Test Accuracy: 76.00000%
[6/9]	Time 0.5727s	Training Accuracy: 78.81481%	Test Accuracy: 79.33333%
[7/9]	Time 0.5744s	Training Accuracy: 80.51852%	Test Accuracy: 81.33333%
[8/9]	Time 0.6065s	Training Accuracy: 82.74074%	Test Accuracy: 83.33333%
[9/9]	Time 0.5724s	Training Accuracy: 85.25926%	Test Accuracy: 82.66667%

But remember, some AD backends like ReverseDiff are not GPU compatible. For a model of this size, you will notice that the training time is significantly lower when training on the CPU than on the GPU.

julia
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).

│ 1. If this was not the desired behavior overload the dispatch on `m`.

│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:9
[1/9]	Time 48.3887s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 6.4558s	Training Accuracy: 58.74074%	Test Accuracy: 56.66667%
[3/9]	Time 6.0971s	Training Accuracy: 69.92593%	Test Accuracy: 71.33333%
[4/9]	Time 6.1885s	Training Accuracy: 72.81481%	Test Accuracy: 74.00000%
[5/9]	Time 6.1188s	Training Accuracy: 76.37037%	Test Accuracy: 78.66667%
[6/9]	Time 6.3157s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 6.3433s	Training Accuracy: 81.62963%	Test Accuracy: 80.66667%
[8/9]	Time 6.3841s	Training Accuracy: 83.33333%	Test Accuracy: 80.00000%
[9/9]	Time 6.6278s	Training Accuracy: 85.40741%	Test Accuracy: 82.00000%

For completeness, let's also test out discrete sensitivities!

julia
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9]	Time 39.3309s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 11.3559s	Training Accuracy: 58.66667%	Test Accuracy: 57.33333%
[3/9]	Time 10.8336s	Training Accuracy: 69.70370%	Test Accuracy: 71.33333%
[4/9]	Time 10.5077s	Training Accuracy: 72.74074%	Test Accuracy: 74.00000%
[5/9]	Time 10.3781s	Training Accuracy: 76.14815%	Test Accuracy: 78.66667%
[6/9]	Time 10.4383s	Training Accuracy: 79.03704%	Test Accuracy: 80.66667%
[7/9]	Time 10.2801s	Training Accuracy: 81.55556%	Test Accuracy: 80.66667%
[8/9]	Time 10.3542s	Training Accuracy: 83.40741%	Test Accuracy: 80.00000%
[9/9]	Time 10.6291s	Training Accuracy: 85.25926%	Test Accuracy: 81.33333%

Alternate Implementation using Stateful Layer

Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.

julia
struct StatefulNeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end

Train the new Stateful Neural ODE

julia
train(StatefulNeuralODE)
[1/9]	Time 36.3082s	Training Accuracy: 37.48148%	Test Accuracy: 40.00000%
[2/9]	Time 0.6160s	Training Accuracy: 58.22222%	Test Accuracy: 55.33333%
[3/9]	Time 0.8826s	Training Accuracy: 68.29630%	Test Accuracy: 68.66667%
[4/9]	Time 0.6078s	Training Accuracy: 73.11111%	Test Accuracy: 76.00000%
[5/9]	Time 0.5947s	Training Accuracy: 75.92593%	Test Accuracy: 76.66667%
[6/9]	Time 0.6177s	Training Accuracy: 78.96296%	Test Accuracy: 80.66667%
[7/9]	Time 0.6261s	Training Accuracy: 80.81481%	Test Accuracy: 81.33333%
[8/9]	Time 0.6027s	Training Accuracy: 83.25926%	Test Accuracy: 82.66667%
[9/9]	Time 0.6036s	Training Accuracy: 84.59259%	Test Accuracy: 82.00000%

We might not see a significant difference in training time, but let us investigate the type stability of the layers.

Type Stability

julia
model, ps, st = create_model(NeuralODE)

model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)

x = gpu_device()(ones(Float32, 28, 28, 1, 3));

NeuralODE is not type stable due to the boxing of st (a minimal illustration of this boxing follows the @code_warntype output below).

julia
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└──      return %3
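
The ANY in the inferred state above comes from closure boxing: inside the NeuralODE functor, dudt reassigns st, so Julia boxes the variable and inference gives up on its type. A minimal illustration of the same effect, independent of Lux:

julia
# `s` is read after the closure reassigns it, so Julia stores it in a Core.Box
function boxed_example(x)
    s = 0
    f = y -> (s = y + 1; s)
    f(x)
    return s  # @code_warntype boxed_example(1) reports s::Core.Box
end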

We avoid the problem entirely by using StatefulNeuralODE.

julia
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Note that we still recommend using this layer internally and not exposing it to users as the default API.

Finally, let's check the compact model.

julia
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
  from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
  c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
  x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
  ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
  st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│   %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│   %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└──      return %3

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

CUDA runtime 12.8, artifact installation
CUDA driver 12.8
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.8.4
- CURAND: 10.3.9
- CUFFT: 11.3.3
- CUSOLVER: 11.7.3
- CUSPARSE: 12.5.8
- CUPTI: 2025.1.1 (API 26.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.7.3
- CUDA_Driver_jll: 0.12.1+1
- CUDA_Runtime_jll: 0.16.1+0

Toolchain:
- Julia: 1.11.5
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.889 GiB / 4.750 GiB available)

This page was generated using Literate.jl.