MNIST Classification using Neural ODEs
To understand Neural ODEs, users should look up these lecture notes. We recommend using DiffEqFlux.jl directly instead of implementing Neural ODEs from scratch.
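For reference, here is a minimal sketch of that route, assuming DiffEqFlux's NeuralODE layer (which takes the model, the time span, and the solver positionally). This is distinct from the NeuralODE we define from scratch below, and the tiny architecture is purely illustrative:
using DiffEqFlux, Lux, OrdinaryDiffEqTsit5, Random

dudt = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
node = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(); save_everystep=false, save_start=false)
ps, st = Lux.setup(Random.default_rng(), node)
sol, _ = node(ones(Float32, 2, 4), ps, st)  # ODESolution for a batch of 4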
Package Imports
using Lux,
    ComponentArrays,
    SciMLSensitivity,
    LuxCUDA,
    Optimisers,
    OrdinaryDiffEqTsit5,
    Random,
    Statistics,
    Zygote,
    OneHotArrays,
    InteractiveUtils,
    Printf
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
# Error out instead of silently falling back to slow scalar GPU indexing
CUDA.allowscalar(false)
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST: use only 1500 images on CI, for demonstration purposes
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false),
    )
end
loadmnist (generic function with 1 method)
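A quick usage sketch (the same call the train function makes later): each loader yields (x, y) batches ready for the model, with shapes following from the reshaping above.
# Sketch: build the loaders and peek at one batch
train_loader, test_loader = loadmnist(128, 0.9)
x, y = first(train_loader)
size(x)  # (28, 28, 1, 128) -- 28×28×1 images
size(y)  # (10, 128)        -- 10-class one-hot labels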
Define the Neural ODE Layer
First, we will use the @compact macro to define the Neural ODE layer.
function NeuralODECompact(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API; it also helps users understand how the Lux layer interface works.
The NeuralODE is a ContainerLayer, which stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
struct NeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <: Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function NeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return NeuralODE(model, solver, tspan, kwargs)
end
Main.var"##230".NeuralODE
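Because NeuralODE subtypes AbstractLuxWrapperLayer{:model}, Lux.setup forwards to the wrapped model, so the wrapper's parameters and states mirror the inner model's. A minimal sketch (the Dense layer is just an illustrative stand-in):
rng = Random.default_rng()
inner = Dense(2 => 2)
node = NeuralODE(inner)
ps, st = Lux.setup(rng, node)  # same structure as Lux.setup(rng, inner)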
OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        # `st` is captured and reassigned inside this closure, which boxes it
        # (see the type stability discussion below)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st
end
@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
diffeqsol_to_array (generic function with 2 methods)
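To see what these helpers do, here is a small sketch with made-up shapes: with save_everystep=false the solver keeps only a couple of states (the columns below), and we reshape the final one back to (features, batchsize).
# Hypothetical shapes: 20 features × batch of 4, flattened to length 80,
# with 2 saved time points as columns
x = rand(Float32, 80, 2)
diffeqsol_to_array(20, x)  # takes the last column, reshaped to 20×4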
Create and Initialize the Neural ODE Layer
function create_model(
    model_fn=NeuralODE;
    dev=gpu_device(),
    use_named_tuple::Bool=false,
    sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
)
    # Construct the Neural ODE Model
    model = Chain(
        FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(
            Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false,
            reltol=1.0f-3,
            abstol=1.0f-3,
            save_start=false,
            sensealg,
        ),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10),
    )

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = dev((use_named_tuple ? ps : ComponentArray(ps)))
    st = dev(st)

    return model, ps, st
end
create_model (generic function with 2 methods)
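As a sanity check, a quick forward-pass sketch (run on the CPU so it works without a GPU): the full model maps 28×28×1 images to 10 logits per sample.
# Sketch: forward pass on a dummy batch of 2 images
model, ps, st = create_model(NeuralODE; dev=cpu_device())
x = ones(Float32, 28, 28, 1, 2)
y, _ = model(x, ps, st)
size(y)  # (10, 2)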
Define Utility Functions
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
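A quick sketch of the onehotbatch/onecold round trip used above. Without an explicit label set, onecold returns 1-based indices, which is fine in accuracy since targets and predictions are decoded the same way:
labels = [3, 1, 4]
y = onehotbatch(labels, 0:9)  # 10×3 one-hot matrix
onecold(y, 0:9) == labels     # true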
Training
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
            Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return nothing
end
train(NeuralODECompact)
[1/9] Time 143.8989s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.7810s Training Accuracy: 58.22222% Test Accuracy: 57.33333%
[3/9] Time 0.7163s Training Accuracy: 67.85185% Test Accuracy: 70.66667%
[4/9] Time 0.6936s Training Accuracy: 74.29630% Test Accuracy: 74.66667%
[5/9] Time 0.7209s Training Accuracy: 76.29630% Test Accuracy: 76.00000%
[6/9] Time 0.7002s Training Accuracy: 78.74074% Test Accuracy: 80.00000%
[7/9] Time 0.6936s Training Accuracy: 82.22222% Test Accuracy: 81.33333%
[8/9] Time 0.9291s Training Accuracy: 83.62963% Test Accuracy: 83.33333%
[9/9] Time 0.6965s Training Accuracy: 85.18519% Test Accuracy: 82.66667%
train(NeuralODE)
[1/9] Time 35.0952s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.5959s Training Accuracy: 57.18519% Test Accuracy: 57.33333%
[3/9] Time 0.6160s Training Accuracy: 68.37037% Test Accuracy: 68.00000%
[4/9] Time 0.7905s Training Accuracy: 73.77778% Test Accuracy: 75.33333%
[5/9] Time 0.5913s Training Accuracy: 76.14815% Test Accuracy: 77.33333%
[6/9] Time 0.6245s Training Accuracy: 79.48148% Test Accuracy: 80.66667%
[7/9] Time 0.8594s Training Accuracy: 81.25926% Test Accuracy: 80.66667%
[8/9] Time 0.5968s Training Accuracy: 83.40741% Test Accuracy: 82.66667%
[9/9] Time 0.5916s Training Accuracy: 84.81481% Test Accuracy: 82.00000%
We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true)
[1/9] Time 40.0036s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.5827s Training Accuracy: 58.44444% Test Accuracy: 58.00000%
[3/9] Time 0.5791s Training Accuracy: 66.96296% Test Accuracy: 68.00000%
[4/9] Time 0.5898s Training Accuracy: 72.44444% Test Accuracy: 73.33333%
[5/9] Time 0.8134s Training Accuracy: 76.37037% Test Accuracy: 76.00000%
[6/9] Time 0.5727s Training Accuracy: 78.81481% Test Accuracy: 79.33333%
[7/9] Time 0.5744s Training Accuracy: 80.51852% Test Accuracy: 81.33333%
[8/9] Time 0.6065s Training Accuracy: 82.74074% Test Accuracy: 83.33333%
[9/9] Time 0.5724s Training Accuracy: 85.25926% Test Accuracy: 82.66667%
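As a structural check (a sketch; cpu_device() is used so it runs anywhere), use_named_tuple=true keeps the parameters as a nested NamedTuple instead of flattening them into a ComponentArray:
_, ps_nt, _ = create_model(
    NeuralODE; dev=cpu_device(), use_named_tuple=true,
    sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()),
)
ps_nt isa NamedTuple  # true: GaussAdjoint did not require a flat vector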
But remember, some AD backends like ReverseDiff are not GPU compatible, so here we train on the CPU. Even for a model of this size, you will notice from the timings below that the per-epoch training time is significantly higher on the CPU than on the GPU.
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:9
[1/9] Time 48.3887s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 6.4558s Training Accuracy: 58.74074% Test Accuracy: 56.66667%
[3/9] Time 6.0971s Training Accuracy: 69.92593% Test Accuracy: 71.33333%
[4/9] Time 6.1885s Training Accuracy: 72.81481% Test Accuracy: 74.00000%
[5/9] Time 6.1188s Training Accuracy: 76.37037% Test Accuracy: 78.66667%
[6/9] Time 6.3157s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 6.3433s Training Accuracy: 81.62963% Test Accuracy: 80.66667%
[8/9] Time 6.3841s Training Accuracy: 83.33333% Test Accuracy: 80.00000%
[9/9] Time 6.6278s Training Accuracy: 85.40741% Test Accuracy: 82.00000%
For completeness, let's also test out discrete sensitivities!
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)
[1/9] Time 39.3309s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 11.3559s Training Accuracy: 58.66667% Test Accuracy: 57.33333%
[3/9] Time 10.8336s Training Accuracy: 69.70370% Test Accuracy: 71.33333%
[4/9] Time 10.5077s Training Accuracy: 72.74074% Test Accuracy: 74.00000%
[5/9] Time 10.3781s Training Accuracy: 76.14815% Test Accuracy: 78.66667%
[6/9] Time 10.4383s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 10.2801s Training Accuracy: 81.55556% Test Accuracy: 80.66667%
[8/9] Time 10.3542s Training Accuracy: 83.40741% Test Accuracy: 80.00000%
[9/9] Time 10.6291s Training Accuracy: 85.25926% Test Accuracy: 81.33333%
Alternate Implementation using Stateful Layer
Starting with v0.5.5, Lux provides a StatefulLuxLayer that can be used to avoid the boxing of st. Using the @compact API avoids this problem entirely.
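A minimal sketch of StatefulLuxLayer on its own (a plain Dense layer as a stand-in; the two-argument call with explicit parameters is the form used in the ODE right-hand side below):
d = Dense(2 => 2)
ps, st = Lux.setup(Random.default_rng(), d)
sd = StatefulLuxLayer{true}(d, ps, st)
y = sd(ones(Float32, 2, 3), ps)  # updated states are tracked in `sd.st`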
struct StatefulNeuralODE{M<:Lux.AbstractLuxLayer,So,T,K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M
    solver::So
    tspan::T
    kwargs::K
end

function StatefulNeuralODE(
    model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...
)
    return StatefulNeuralODE(model, solver, tspan, kwargs)
end

function (n::StatefulNeuralODE)(x, ps, st)
    st_model = StatefulLuxLayer{true}(n.model, ps, st)
    # `dudt` only reads `st_model`; no captured variable is reassigned, so
    # nothing gets boxed
    dudt(u, p, t) = st_model(u, p)
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), st_model.st
end
Train the new Stateful Neural ODE
train(StatefulNeuralODE)
[1/9] Time 36.3082s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.6160s Training Accuracy: 58.22222% Test Accuracy: 55.33333%
[3/9] Time 0.8826s Training Accuracy: 68.29630% Test Accuracy: 68.66667%
[4/9] Time 0.6078s Training Accuracy: 73.11111% Test Accuracy: 76.00000%
[5/9] Time 0.5947s Training Accuracy: 75.92593% Test Accuracy: 76.66667%
[6/9] Time 0.6177s Training Accuracy: 78.96296% Test Accuracy: 80.66667%
[7/9] Time 0.6261s Training Accuracy: 80.81481% Test Accuracy: 81.33333%
[8/9] Time 0.6027s Training Accuracy: 83.25926% Test Accuracy: 82.66667%
[9/9] Time 0.6036s Training Accuracy: 84.59259% Test Accuracy: 82.00000%
We might not see a significant difference in training time, but let us investigate the type stability of the layers.
Type Stability
model, ps, st = create_model(NeuralODE)
model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE)
x = gpu_device()(ones(Float32, 28, 28, 1, 3));
NeuralODE is not type stable due to the boxing of st
@code_warntype model(x, ps, st)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.DEVICEMEMORY}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}}
└── return %3
We avoid the problem entirely by using StatefulNeuralODE
@code_warntype model_stateful(x, ps_stateful, st_stateful)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple()))
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Main.var"##230".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
Note that we still recommend using this layer internally and not exposing it as the default API to users.
Finally, let's check the compact model.
model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
@code_warntype model_compact(x, ps_compact, st_compact)
MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}})
from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/src/layers/containers.jl:509
Arguments
c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}
x::CUDA.CuArray{Float32, 4, CUDA.DeviceMemory}
ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784))), bias = ViewAxis(15681:15700, Shaped1DAxis((20,))))), layer_3 = ViewAxis(15701:16240, Axis(model = ViewAxis(1:540, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = ViewAxis(101:110, Shaped1DAxis((10,))))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10))), bias = ViewAxis(201:220, Shaped1DAxis((20,))))))),)), layer_4 = ViewAxis(16241:16240, Shaped1DAxis((0,))), layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = ViewAxis(201:210, Shaped1DAxis((10,))))))}}}
st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}
Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
1 ─ %1 = Lux.applychain::Core.Const(Lux.applychain)
│ %2 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer{Nothing}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Main.var"##230".var"#2#3", Nothing, @NamedTuple{model::Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{solver::Returns{OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, tspan::Returns{Tuple{Float32, Float32}}}}, Tuple{Tuple{Symbol}, Tuple{Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##230".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}
│ %3 = (%1)(%2, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.DeviceMemory}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{model::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, solver::OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, tspan::Tuple{Float32, Float32}, ₋₋₋kwargs₋₋₋::Lux.CompactMacroImpl.KwargsStorage{@NamedTuple{kwargs::Base.Pairs{Symbol, Any, NTuple{5, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool, sensealg::SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}}}}}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}}
└── return %3
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
CUDA runtime 12.8, artifact installation
CUDA driver 12.8
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.8.4
- CURAND: 10.3.9
- CUFFT: 11.3.3
- CUSOLVER: 11.7.3
- CUSPARSE: 12.5.8
- CUPTI: 2025.1.1 (API 26.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.7.3
- CUDA_Driver_jll: 0.12.1+1
- CUDA_Runtime_jll: 0.16.1+0
Toolchain:
- Julia: 1.11.5
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.889 GiB / 4.750 GiB available)
This page was generated using Literate.jl.