
MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a model with SimpleChains.jl using the same API as Lux.jl. We use the corresponding tutorial from SimpleChains.jl as a reference.

Package Imports

julia
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Reactant.set_default_backend("cpu")

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST; when running on CI (ENV["CI"] == "true"), use only the first 1500 images
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true, partial = false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false, partial = false),
    )
end
loadmnist (generic function with 1 method)
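This plain-Julia sketch (standard library only) makes the shapes in loadmnist concrete. It covers three steps:

- reshaping (H, W, N) images into (H, W, C, N) with a single channel;
- one-hot encoding the labels over the classes 0:9;
- counting how many full minibatches a DataLoader yields when partial = false.

It uses random stand-in data instead of MNIST. The helpers onehot_dense, onecold_dense and full_batches are hypothetical. The first two stand in for OneHotArrays' onehotbatch and onecold, but onehotbatch returns a lazy OneHotMatrix rather than a dense Bool matrix.

```julia
using Random

# Random stand-in for MNIST: 28×28 grayscale images stored as (H, W, N)
rng = Xoshiro(0)
imgs = rand(rng, Float32, 28, 28, 5)
labels = [3, 0, 9, 1, 3]

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C = 1, N)
x = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Hypothetical dense one-hot encoder: entry (i, j) is true iff labels[j] == classes[i]
onehot_dense(labels, classes) = [l == c for c in classes, l in labels]

# Hypothetical inverse: map each column back to the class at its largest entry
onecold_dense(y, classes) = [classes[argmax(col)] for col in eachcol(y)]

y = onehot_dense(labels, 0:9)

# Hypothetical helper: with partial = false, only full batches are yielded
full_batches(n, batchsize) = div(n, batchsize)
```

Assume splitobs puts 1350 of the 1500 CI images in the training split. Then full_batches(1350, 128) gives 10 training batches and full_batches(150, 128) gives a single test batch.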

Define the Model

julia
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
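The 256 in Dense(256 => 128) and the parameter counts printed above follow from simple arithmetic. The plain-Julia sketch below makes two assumptions:

- convolutions use stride 1 and no padding (the Conv defaults);
- max pooling uses a 2×2 window with stride equal to the window (the MaxPool default).

conv_out, pool_out, conv_params and dense_params are hypothetical helper names.

```julia
# Output length of a convolution along one spatial dimension
conv_out(n, k; stride = 1, pad = 0) = div(n + 2pad - k, stride) + 1
# Output length of non-overlapping pooling with window w
pool_out(n, w) = div(n, w)

h = 28               # MNIST images are 28×28
h = conv_out(h, 5)   # 24 after Conv((5, 5), 1 => 6)
h = pool_out(h, 2)   # 12 after MaxPool((2, 2))
h = conv_out(h, 5)   # 8 after Conv((5, 5), 6 => 16)
h = pool_out(h, 2)   # 4 after MaxPool((2, 2))
flat = h * h * 16    # FlattenLayer: 4 × 4 spatial positions × 16 channels

# Weights plus one bias per output channel (conv) or output unit (dense)
conv_params(k, cin, cout) = prod(k) * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params((5, 5), 1, 6),
    conv_params((5, 5), 6, 16),
    dense_params(flat, 128),
    dense_params(128, 84),
    dense_params(84, 10),
]
total = sum(counts)
```

The per-layer counts come out as 156, 2_416, 32_896, 10_836 and 850. Together they give the 47_154 parameters reported for lux_model.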

We now need to convert lux_model to SimpleChains.jl. To do this, we construct a ToSimpleChainsAdaptor with the input dimensions of a single sample (excluding the batch dimension), then apply it to the model.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
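The input dimensions given to ToSimpleChainsAdaptor describe one sample. A batch fed to simple_chains_model should therefore have exactly those leading dimensions, followed by one batch dimension. Here is a minimal plain-Julia sketch of that shape check. check_input_dims is a hypothetical helper, not part of Lux.

```julia
# Hypothetical helper: does the batched array `x` match the per-sample `input_dims`?
function check_input_dims(x::AbstractArray, input_dims::Tuple)
    # Expect exactly one extra (batch) dimension at the end
    ndims(x) == length(input_dims) + 1 || return false
    return size(x)[1:(end - 1)] == input_dims
end

input_dims = (28, 28, 1)
good = zeros(Float32, 28, 28, 1, 128)  # a batch of 128 single-channel images
bad = zeros(Float32, 28, 28, 128)      # missing the channel dimension
```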

Helper Functions

julia
const lossfn = CrossEntropyLoss(; logits = Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
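For reference, here is a plain-Julia sketch of what the loss and the accuracy helper compute on one batch. We assume CrossEntropyLoss(; logits = Val(true)) does two things:

- applies a numerically stable log-softmax to each column of logits;
- averages the per-sample cross-entropy over the batch, which is its default mean aggregation.

As with onecold in accuracy, the predicted class is the row index of the largest entry in each column. logsoftmax_cols, xent_logits and batch_accuracy are hypothetical names, not Lux functions.

```julia
# Numerically stable log-softmax over each column (one column per sample)
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims = 1)
    return z .- m .- log.(sum(exp.(z .- m); dims = 1))
end

# Cross-entropy from logits and one-hot targets y, averaged over the batch
xent_logits(logits, y) = -sum(y .* logsoftmax_cols(logits)) / size(y, 2)

# Fraction of columns whose predicted class (argmax of logits) matches the target
function batch_accuracy(logits, y)
    predicted = map(argmax, eachcol(logits))
    target = map(argmax, eachcol(y))
    return count(predicted .== target) / size(y, 2)
end

# Toy batch: 4 samples, 10 classes; sample j has target row j
y = zeros(Float32, 10, 4)
for j in 1:4
    y[j, j] = 1
end
logits = zeros(Float32, 10, 4)
logits[1, 1] = logits[2, 2] = logits[3, 3] = 5  # three correct predictions
logits[7, 4] = 5                                # one wrong prediction
```

With all-zero logits, every class gets probability 1/10, so the loss is log(10). For the toy batch above, batch_accuracy returns 0.75.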

Define the Training Loop

julia
function train(model, dev = cpu_device(); rng = Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
    ps, st = Lux.setup(rng, model) |> dev

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
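Each call to Training.single_train_step! computes gradients with the chosen AD backend, then applies the optimizer update. As a rough illustration of that update, here is a plain-Julia sketch of the Adam rule on a toy quadratic. We believe the defaults β₁ = 0.9, β₂ = 0.999 and ϵ = 1e-8 match what Optimisers.Adam uses, but the real implementation may differ in details. AdamState and adam_step! are hypothetical names.

```julia
# Running moment estimates and step counter for Adam (hypothetical type)
mutable struct AdamState
    m::Vector{Float64}  # first-moment (mean) estimate of the gradient
    v::Vector{Float64}  # second-moment (uncentered variance) estimate
    t::Int              # number of steps taken so far
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

# One in-place Adam update of parameters θ given gradient g
function adam_step!(θ, g, s::AdamState; eta = 3e-4, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8)
    s.t += 1
    m, v = s.m, s.v
    @. m = beta1 * m + (1 - beta1) * g
    @. v = beta2 * v + (1 - beta2) * g^2
    mhat = m ./ (1 - beta1^s.t)  # bias-corrected first moment
    vhat = v ./ (1 - beta2^s.t)  # bias-corrected second moment
    @. θ -= eta * mhat / (sqrt(vhat) + epsilon)
    return θ
end

# Toy problem: minimize f(θ) = sum(θ .^ 2) / 2, whose gradient is θ itself
θ = [1.0, -2.0]
s = AdamState(length(θ))
for _ in 1:500
    adam_step!(θ, copy(θ), s; eta = 0.1)
end
```

The tutorial uses Adam(3.0f-4), a much smaller step size. The larger eta here only makes the toy problem converge quickly.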

Finally Training the Model

First, we train the Lux model, compiled with Reactant:

julia
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] 	 Time 321.37s 	 Training Accuracy: 14.06% 	 Test Accuracy: 17.19%
[ 2/10] 	 Time 0.39s 	 Training Accuracy: 29.53% 	 Test Accuracy: 34.38%
[ 3/10] 	 Time 0.39s 	 Training Accuracy: 43.75% 	 Test Accuracy: 41.41%
[ 4/10] 	 Time 0.39s 	 Training Accuracy: 54.30% 	 Test Accuracy: 50.00%
[ 5/10] 	 Time 0.39s 	 Training Accuracy: 63.75% 	 Test Accuracy: 57.81%
[ 6/10] 	 Time 0.38s 	 Training Accuracy: 71.25% 	 Test Accuracy: 63.28%
[ 7/10] 	 Time 0.39s 	 Training Accuracy: 75.23% 	 Test Accuracy: 66.41%
[ 8/10] 	 Time 0.39s 	 Training Accuracy: 79.69% 	 Test Accuracy: 67.19%
[ 9/10] 	 Time 0.40s 	 Training Accuracy: 82.97% 	 Test Accuracy: 71.88%
[10/10] 	 Time 0.38s 	 Training Accuracy: 84.69% 	 Test Accuracy: 73.44%

Next, we train the SimpleChains model on the CPU, using the default cpu_device():

julia
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] 	 Time 903.97s 	 Training Accuracy: 42.03% 	 Test Accuracy: 41.41%
[ 2/10] 	 Time 12.28s 	 Training Accuracy: 52.73% 	 Test Accuracy: 52.34%
[ 3/10] 	 Time 12.23s 	 Training Accuracy: 60.47% 	 Test Accuracy: 58.59%
[ 4/10] 	 Time 12.23s 	 Training Accuracy: 69.38% 	 Test Accuracy: 66.41%
[ 5/10] 	 Time 12.22s 	 Training Accuracy: 74.53% 	 Test Accuracy: 68.75%
[ 6/10] 	 Time 12.22s 	 Training Accuracy: 78.44% 	 Test Accuracy: 75.00%
[ 7/10] 	 Time 12.22s 	 Training Accuracy: 80.47% 	 Test Accuracy: 78.91%
[ 8/10] 	 Time 12.24s 	 Training Accuracy: 82.50% 	 Test Accuracy: 79.69%
[ 9/10] 	 Time 12.23s 	 Training Accuracy: 84.61% 	 Test Accuracy: 82.03%
[10/10] 	 Time 12.23s 	 Training Accuracy: 87.03% 	 Test Accuracy: 82.03%

On my local machine, we see a 3-4x speedup when using SimpleChains.jl. The server on which this documentation is built is not well suited for CPU benchmarking, so the speedup here may be smaller, and there may even be regressions. Note also that in both runs the first epoch's time includes compilation.
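The first epoch includes one-time compilation, so steady-state timings are better compared after a warm-up run. Below is a minimal plain-Julia sketch of that practice. steady_state_time and work are hypothetical; for careful measurements, a dedicated package such as BenchmarkTools.jl is the usual choice.

```julia
using Statistics

# Hypothetical helper: median wall-clock time of `f()` after discarding warm-up runs,
# so that one-time costs such as JIT compilation are excluded from the estimate
function steady_state_time(f; nwarmup = 1, nruns = 5)
    for _ in 1:nwarmup
        f()
    end
    times = [@elapsed(f()) for _ in 1:nruns]
    return median(times)
end

# Stand-in workload for illustration
work() = sum(abs2, rand(10_000))
t = steady_state_time(work; nruns = 3)
```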

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.