MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a model with SimpleChains.jl while keeping the same API as Lux.jl. We use the corresponding SimpleChains.jl tutorial as a reference.
Package Imports
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Reactant.set_default_backend("cpu")
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST (only the first 1500 samples when running on CI)
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end
    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true, partial = false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false, partial = false),
    )
end
loadmnist (generic function with 1 method)
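The shapes and batch counts that `loadmnist` produces can be checked with plain Julia. This sketch uses random data in place of MNIST and assumes `splitobs` turns `at = 0.9` into 1350 of 1500 training samples (0.9 × 1500 is 1350 up to floating-point error, so rounding or flooring give the same result). `manual_onehot` is a hypothetical stand-in for `onehotbatch`.

```julia
# Base-only sketch of the array shapes and batch counts produced by `loadmnist`.
# Random data stands in for MNIST; `manual_onehot` is a hypothetical helper that
# mimics what `onehotbatch(labels, 0:9)` returns (a 10×N one-hot matrix).
N = 1500                                   # subset size used when CI=true
imgs = rand(Float32, 28, 28, N)            # MNIST features are stored as (H, W, N)
labels = rand(0:9, N)

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C = 1, N)
x = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Column j is true only in the row matching label j
manual_onehot(ys, classes) = [lbl == c for c in classes, lbl in ys]
y = manual_onehot(labels, 0:9)

# 90/10 split, then full minibatches only (partial = false drops the remainder)
batchsize = 128
ntrain = round(Int, 0.9 * N)
ntest = N - ntrain
ntrain_batches, ntest_batches = ntrain ÷ batchsize, ntest ÷ batchsize

println(size(x), " ", size(y))
println((ntrain, ntest, ntrain_batches, ntest_batches))
```

With `CI=true`, the test loader therefore yields a single batch of 128 images and the training loader 10 batches (1280 images). This is consistent with the logged accuracies further down, which are multiples of 1/128 for the test set (e.g. 17.19% ≈ 22/128) and of 1/1280 for the training set (e.g. 14.06% ≈ 180/1280).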
Define the Model
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)                   # Total: 47_154 parameters,
                    #        plus 0 states.
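The printed parameter counts, and the `256` input size of the first `Dense` layer, follow from simple arithmetic. The sketch below assumes the defaults of zero padding and stride 1 for `Conv`, and a stride equal to the window size for `MaxPool`; all helper names are hypothetical.

```julia
# Hypothetical helpers: weights plus biases for each parameterized layer
conv_params(k, cin, cout) = prod(k) * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

# Spatial size after an unpadded stride-1 convolution, and after a
# non-overlapping pooling window
conv_out(n, k) = n - k + 1
pool_out(n, p) = n ÷ p

side = pool_out(conv_out(pool_out(conv_out(28, 5), 2), 5), 2)  # 28 → 24 → 12 → 8 → 4
flat = side * side * 16                                          # input size of the first Dense

counts = [
    conv_params((5, 5), 1, 6),
    conv_params((5, 5), 6, 16),
    dense_params(flat, 128),
    dense_params(128, 84),
    dense_params(84, 10),
]
println(side, " ", flat)           # 4 256
println(counts, " ", sum(counts))  # [156, 2416, 32896, 10836, 850] 47154
```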
We now need to convert lux_model to SimpleChains.jl. We do this by constructing a ToSimpleChainsAdaptor with the input dimensions of a single sample, excluding the batch dimension: (28, 28, 1) for a 28×28 grayscale image.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),      # 850 parameters
        ),
    ),
)                   # Total: 47_154 parameters,
                    #        plus 0 states.
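Before training, a quick forward pass can confirm that the converted model accepts a batch of shape `(28, 28, 1, batch)` and returns 10 logits per image. This is a sketch assuming Lux.jl and SimpleChains.jl are installed (loading SimpleChains is what activates Lux's SimpleChains support); the batch of four random images is made up purely for illustration.

```julia
using Lux, Random
using SimpleChains: SimpleChains

# Same architecture as `lux_model` in this tutorial
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)

# Input dimensions exclude the batch dimension
sc_model = ToSimpleChainsAdaptor((28, 28, 1))(lux_model)

rng = Xoshiro(0)
ps, st = Lux.setup(rng, sc_model)
x = randn(rng, Float32, 28, 28, 1, 4)  # a batch of 4 random "images"
ŷ, _ = sc_model(x, ps, st)
logits = Array(ŷ)                      # same `Array(...)` conversion used in `accuracy`
println(size(logits))
```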
Helper Functions
const lossfn = CrossEntropyLoss(; logits = Val(true))
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
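What `lossfn` and `accuracy` compute can be illustrated without Lux. This Base-only sketch assumes `CrossEntropyLoss(; logits = Val(true))` with its default mean aggregation and no label smoothing amounts to the batch mean of −Σ y · logsoftmax(ŷ); `logit_crossentropy`, `manual_onecold` and `batch_accuracy` are hypothetical stand-ins, not Lux or OneHotArrays functions.

```julia
# Column-wise log-softmax, shifted by the column maximum for numerical stability
function logsoftmax_cols(z)
    m = maximum(z; dims = 1)
    return z .- m .- log.(sum(exp.(z .- m); dims = 1))
end

# Cross-entropy on raw logits, averaged over the batch (columns)
logit_crossentropy(ŷ, y) = sum(-sum(y .* logsoftmax_cols(ŷ); dims = 1)) / size(y, 2)

# Index of the largest entry in each column, like `onecold` with its default labels
manual_onecold(z) = [argmax(view(z, :, j)) for j in axes(z, 2)]
batch_accuracy(ŷ, y) = sum(manual_onecold(ŷ) .== manual_onecold(y)) / size(y, 2)

# Three samples whose true classes are 1, 2 and 3
y = zeros(Bool, 10, 3)
y[1, 1] = y[2, 2] = y[3, 3] = true

uniform = zeros(Float32, 10, 3)                          # equal logits: loss is log(10)
confident = zeros(Float32, 10, 3)
confident[1, 1] = confident[2, 2] = confident[7, 3] = 5  # third sample is misclassified

println(logit_crossentropy(uniform, y))
println(batch_accuracy(confident, y))                    # 0.6666666666666666
```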
Define the Training Loop
function train(model, dev = cpu_device(); rng = Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
    ps, st = Lux.setup(rng, model) |> dev
    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))
    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end
    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime
        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100
        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return tr_acc, te_acc
end
train (generic function with 2 methods)
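The core of the loop above is `Training.single_train_step!`, which computes gradients, applies the optimizer, and returns the updated train state. A minimal sketch of the same calls on a tiny `Dense` regression problem on the CPU with Zygote (the model, data, learning rate and the `fit!` helper are all made up for illustration; the tutorial itself uses Enzyme when running through Reactant):

```julia
using Lux, Optimisers, Zygote, Random

rng = Xoshiro(0)
model = Dense(4 => 2)  # tiny stand-in for the CNN
ps, st = Lux.setup(rng, model)
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

# Synthetic regression data: 32 samples, 4 features, 2 targets
x = randn(rng, Float32, 4, 32)
y = randn(rng, Float32, 2, 32)

# Hypothetical helper; wrapping the loop in a function avoids soft-scope issues
function fit!(train_state, data, nsteps)
    losses = Float32[]
    for _ in 1:nsteps
        # Returns (gradients, loss, stats, updated train state)
        _, loss, _, train_state = Training.single_train_step!(
            AutoZygote(), MSELoss(), data, train_state
        )
        push!(losses, loss)
    end
    return losses, train_state
end

losses, train_state = fit!(train_state, (x, y), 200)
println(first(losses), " -> ", last(losses))
```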
Finally Training the Model
First, we train the Lux model, compiled with Reactant (using the CPU backend selected above).
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] Time 321.37s Training Accuracy: 14.06% Test Accuracy: 17.19%
[ 2/10] Time 0.39s Training Accuracy: 29.53% Test Accuracy: 34.38%
[ 3/10] Time 0.39s Training Accuracy: 43.75% Test Accuracy: 41.41%
[ 4/10] Time 0.39s Training Accuracy: 54.30% Test Accuracy: 50.00%
[ 5/10] Time 0.39s Training Accuracy: 63.75% Test Accuracy: 57.81%
[ 6/10] Time 0.38s Training Accuracy: 71.25% Test Accuracy: 63.28%
[ 7/10] Time 0.39s Training Accuracy: 75.23% Test Accuracy: 66.41%
[ 8/10] Time 0.39s Training Accuracy: 79.69% Test Accuracy: 67.19%
[ 9/10] Time 0.40s Training Accuracy: 82.97% Test Accuracy: 71.88%
[10/10] Time 0.38s Training Accuracy: 84.69% Test Accuracy: 73.44%
Now we train the SimpleChains model, using the default CPU device and Zygote for gradients.
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] Time 903.97s Training Accuracy: 42.03% Test Accuracy: 41.41%
[ 2/10] Time 12.28s Training Accuracy: 52.73% Test Accuracy: 52.34%
[ 3/10] Time 12.23s Training Accuracy: 60.47% Test Accuracy: 58.59%
[ 4/10] Time 12.23s Training Accuracy: 69.38% Test Accuracy: 66.41%
[ 5/10] Time 12.22s Training Accuracy: 74.53% Test Accuracy: 68.75%
[ 6/10] Time 12.22s Training Accuracy: 78.44% Test Accuracy: 75.00%
[ 7/10] Time 12.22s Training Accuracy: 80.47% Test Accuracy: 78.91%
[ 8/10] Time 12.24s Training Accuracy: 82.50% Test Accuracy: 79.69%
[ 9/10] Time 12.23s Training Accuracy: 84.61% Test Accuracy: 82.03%
[10/10] Time 12.23s Training Accuracy: 87.03% Test Accuracy: 82.03%
On my local machine, we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not well suited to CPU benchmarking, so the speedup here may be smaller, and there may even be regressions.
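When comparing the two logs, it helps to leave out the first epoch, which is most likely dominated by one-time compilation. The arithmetic below averages the remaining per-epoch times copied from the logs above; `steady_mean` is a hypothetical helper.

```julia
# Per-epoch times copied from the two training logs above (seconds)
lux_reactant = [321.37, 0.39, 0.39, 0.39, 0.39, 0.38, 0.39, 0.39, 0.40, 0.38]
simplechains = [903.97, 12.28, 12.23, 12.23, 12.22, 12.22, 12.22, 12.24, 12.23, 12.23]

# Hypothetical helper: drop the first (warm-up) epoch and average the rest
steady_mean(times) = sum(times[2:end]) / (length(times) - 1)

t_lux = steady_mean(lux_reactant)
t_sc = steady_mean(simplechains)
println(round(t_lux; digits = 3), " s vs ", round(t_sc; digits = 3), " s per epoch")
println("ratio ≈ ", round(t_sc / t_lux; digits = 1))
```

In this particular build, the Reactant-compiled Lux model is roughly 30x faster per epoch than the SimpleChains model, so this log does not reproduce the speedup observed locally, in line with the caveat about the build server.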
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.