MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we will demonstrate how to use the same API as Lux.jl to train a model with SimpleChains.jl as the backend. We will use the tutorial from the SimpleChains.jl documentation as a reference.
Package Imports
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Reactant.set_default_backend("cpu")
Precompiling Lux...
431.4 ms ✓ ConcreteStructs
409.1 ms ✓ ArgCheck
558.0 ms ✓ ArrayInterface
478.5 ms ✓ LuxCore → LuxCoreSetfieldExt
387.6 ms ✓ ArrayInterface → ArrayInterfaceGPUArraysCoreExt
423.4 ms ✓ ArrayInterface → ArrayInterfaceChainRulesCoreExt
1284.5 ms ✓ StaticArrayInterface
384.1 ms ✓ ArrayInterface → ArrayInterfaceStaticArraysCoreExt
490.9 ms ✓ CloseOpenIntervals
680.0 ms ✓ StaticArrayInterface → StaticArrayInterfaceStaticArraysExt
617.2 ms ✓ LayoutPointers
2792.9 ms ✓ WeightInitializers
918.6 ms ✓ StrideArraysCore
892.2 ms ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
771.8 ms ✓ Polyester
5632.8 ms ✓ LuxLib
9548.8 ms ✓ Lux
17 dependencies successfully precompiled in 21 seconds. 88 already precompiled.
Precompiling MLUtils...
408.6 ms ✓ StatsAPI
455.0 ms ✓ InverseFunctions
385.8 ms ✓ PrettyPrint
448.9 ms ✓ ShowCases
911.7 ms ✓ InitialValues
339.2 ms ✓ CompositionsBase
355.6 ms ✓ PtrArrays
359.0 ms ✓ DefineSingletons
447.5 ms ✓ DelimitedFiles
447.1 ms ✓ Missings
1106.5 ms ✓ Baselet
448.5 ms ✓ ContextVariablesX
1034.2 ms ✓ SimpleTraits
1160.9 ms ✓ SplittablesBase
626.5 ms ✓ InverseFunctions → InverseFunctionsTestExt
1703.9 ms ✓ DataStructures
421.9 ms ✓ InverseFunctions → InverseFunctionsDatesExt
421.2 ms ✓ LogExpFunctions → LogExpFunctionsInverseFunctionsExt
425.7 ms ✓ NameResolution
411.2 ms ✓ CompositionsBase → CompositionsBaseInverseFunctionsExt
490.4 ms ✓ AliasTables
622.9 ms ✓ FLoopsBase
523.5 ms ✓ SortingAlgorithms
1045.4 ms ✓ MLCore
2467.3 ms ✓ Accessors
2278.0 ms ✓ StatsBase
673.2 ms ✓ Accessors → TestExt
895.7 ms ✓ Accessors → LinearAlgebraExt
718.0 ms ✓ Accessors → StaticArraysExt
815.1 ms ✓ BangBang
542.6 ms ✓ BangBang → BangBangChainRulesCoreExt
718.0 ms ✓ BangBang → BangBangStaticArraysExt
534.8 ms ✓ BangBang → BangBangTablesExt
960.6 ms ✓ MicroCollections
2830.9 ms ✓ Transducers
698.4 ms ✓ Transducers → TransducersAdaptExt
17977.0 ms ✓ MLStyle
4287.6 ms ✓ JuliaVariables
5025.4 ms ✓ FLoops
5845.3 ms ✓ MLUtils
40 dependencies successfully precompiled in 35 seconds. 57 already precompiled.
Precompiling ArrayInterfaceSparseArraysExt...
643.6 ms ✓ ArrayInterface → ArrayInterfaceSparseArraysExt
1 dependency successfully precompiled in 1 seconds. 8 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
1522.1 ms ✓ MLDataDevices → MLDataDevicesMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 101 already precompiled.
Precompiling LuxMLUtilsExt...
2033.1 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 164 already precompiled.
Precompiling StructArraysExt...
477.1 ms ✓ Accessors → StructArraysExt
1 dependency successfully precompiled in 1 seconds. 20 already precompiled.
Precompiling BangBangStructArraysExt...
501.3 ms ✓ BangBang → BangBangStructArraysExt
1 dependency successfully precompiled in 1 seconds. 24 already precompiled.
Precompiling ArrayInterfaceChainRulesExt...
787.7 ms ✓ ArrayInterface → ArrayInterfaceChainRulesExt
1 dependency successfully precompiled in 1 seconds. 40 already precompiled.
Precompiling LuxZygoteExt...
2647.1 ms ✓ Lux → LuxZygoteExt
1 dependency successfully precompiled in 3 seconds. 143 already precompiled.
Precompiling OneHotArrays...
1028.7 ms ✓ OneHotArrays
1 dependency successfully precompiled in 1 seconds. 31 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
772.3 ms ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
1 dependency successfully precompiled in 1 seconds. 38 already precompiled.
Precompiling Reactant...
397.5 ms ✓ EnumX
614.9 ms ✓ URIs
750.6 ms ✓ ExpressionExplorer
364.5 ms ✓ SimpleBufferStream
387.2 ms ✓ BitFlags
571.8 ms ✓ TranscodingStreams
724.5 ms ✓ ConcurrentUtilities
544.7 ms ✓ LoggingExtras
524.1 ms ✓ ExceptionUnwrapping
1076.0 ms ✓ MbedTLS
651.1 ms ✓ OpenSSL_jll
633.1 ms ✓ LLVMOpenMP_jll
652.9 ms ✓ ReactantCore
489.0 ms ✓ CodecZlib
1545.7 ms ✓ CUDA_Driver_jll
1946.7 ms ✓ OpenSSL
2717.6 ms ✓ Reactant_jll
18941.0 ms ✓ HTTP
92752.0 ms ✓ Reactant
19 dependencies successfully precompiled in 117 seconds. 61 already precompiled.
Precompiling LuxLibEnzymeExt...
1346.4 ms ✓ LuxLib → LuxLibEnzymeExt
1 dependency successfully precompiled in 2 seconds. 133 already precompiled.
Precompiling LuxEnzymeExt...
7966.1 ms ✓ Lux → LuxEnzymeExt
1 dependency successfully precompiled in 8 seconds. 149 already precompiled.
Precompiling OptimisersReactantExt...
21684.7 ms ✓ Reactant → ReactantStatisticsExt
25528.8 ms ✓ Optimisers → OptimisersReactantExt
2 dependencies successfully precompiled in 26 seconds. 88 already precompiled.
Precompiling LuxCoreReactantExt...
22115.4 ms ✓ LuxCore → LuxCoreReactantExt
1 dependency successfully precompiled in 23 seconds. 85 already precompiled.
Precompiling MLDataDevicesReactantExt...
21344.8 ms ✓ MLDataDevices → MLDataDevicesReactantExt
1 dependency successfully precompiled in 22 seconds. 82 already precompiled.
Precompiling WeightInitializersReactantExt...
21585.3 ms ✓ WeightInitializers → WeightInitializersReactantExt
21791.9 ms ✓ Reactant → ReactantSpecialFunctionsExt
2 dependencies successfully precompiled in 22 seconds. 95 already precompiled.
Precompiling ReactantAbstractFFTsExt...
22351.2 ms ✓ Reactant → ReactantAbstractFFTsExt
1 dependency successfully precompiled in 23 seconds. 82 already precompiled.
Precompiling ReactantKernelAbstractionsExt...
22353.0 ms ✓ Reactant → ReactantKernelAbstractionsExt
1 dependency successfully precompiled in 23 seconds. 92 already precompiled.
Precompiling ReactantArrayInterfaceExt...
21669.3 ms ✓ Reactant → ReactantArrayInterfaceExt
1 dependency successfully precompiled in 22 seconds. 83 already precompiled.
Precompiling ReactantNNlibExt...
23020.1 ms ✓ Reactant → ReactantNNlibExt
1 dependency successfully precompiled in 23 seconds. 103 already precompiled.
Precompiling LuxReactantExt...
13487.0 ms ✓ Lux → LuxReactantExt
1 dependency successfully precompiled in 14 seconds. 180 already precompiled.
Precompiling MLDatasets...
389.2 ms ✓ LaTeXStrings
409.7 ms ✓ Glob
439.1 ms ✓ TensorCore
442.6 ms ✓ WorkerUtilities
517.3 ms ✓ BufferedStreams
971.2 ms ✓ OffsetArrays
413.1 ms ✓ LazyModules
647.1 ms ✓ InlineStrings
384.4 ms ✓ InvertedIndices
381.9 ms ✓ PackageExtensionCompat
410.1 ms ✓ MappedArrays
705.0 ms ✓ GZip
1183.4 ms ✓ Crayons
638.5 ms ✓ ZipFile
589.0 ms ✓ BFloat16s
506.5 ms ✓ PooledArrays
865.5 ms ✓ StructTypes
1371.3 ms ✓ SentinelArrays
2427.9 ms ✓ FixedPointNumbers
385.0 ms ✓ InternedStrings
637.1 ms ✓ MPIPreferences
671.1 ms ✓ Chemfiles_jll
674.4 ms ✓ libaec_jll
632.9 ms ✓ Hwloc_jll
513.0 ms ✓ MicrosoftMPI_jll
691.2 ms ✓ Libiconv_jll
1147.5 ms ✓ FilePathsBase
2095.2 ms ✓ StringManipulation
3349.5 ms ✓ DataDeps
455.0 ms ✓ OffsetArrays → OffsetArraysAdaptExt
448.7 ms ✓ StackViews
4700.8 ms ✓ FileIO
484.0 ms ✓ PaddedViews
542.7 ms ✓ InlineStrings → ParsersExt
473.9 ms ✓ StridedViews
1489.2 ms ✓ ColorTypes
1139.4 ms ✓ MPItrampoline_jll
1207.3 ms ✓ OpenMPI_jll
1500.5 ms ✓ MPICH_jll
564.6 ms ✓ StringEncodings
583.9 ms ✓ FilePathsBase → FilePathsBaseMmapExt
1273.6 ms ✓ FilePathsBase → FilePathsBaseTestExt
9590.7 ms ✓ JSON3
1902.2 ms ✓ FileIO → HTTPExt
1619.0 ms ✓ NPZ
22465.0 ms ✓ Unitful
479.7 ms ✓ MosaicViews
823.1 ms ✓ WeakRefStrings
2131.8 ms ✓ ColorVectorSpace
4803.8 ms ✓ Colors
2101.6 ms ✓ HDF5_jll
2461.7 ms ✓ Pickle
587.0 ms ✓ Unitful → ConstructionBaseUnitfulExt
595.7 ms ✓ Unitful → InverseFunctionsUnitfulExt
20377.8 ms ✓ PrettyTables
2636.8 ms ✓ UnitfulAtomic
638.4 ms ✓ Accessors → UnitfulExt
2489.5 ms ✓ PeriodicTable
33889.1 ms ✓ JLD2
18570.2 ms ✓ CSV
19329.1 ms ✓ ImageCore
3554.0 ms ✓ ColorSchemes
2227.9 ms ✓ AtomsBase
2096.6 ms ✓ ImageBase
7370.4 ms ✓ HDF5
2315.7 ms ✓ Chemfiles
1964.6 ms ✓ ImageShow
2317.8 ms ✓ MAT
48234.4 ms ✓ DataFrames
1515.5 ms ✓ Transducers → TransducersDataFramesExt
1641.7 ms ✓ BangBang → BangBangDataFramesExt
10438.0 ms ✓ MLDatasets
72 dependencies successfully precompiled in 124 seconds. 130 already precompiled.
Precompiling SimpleChains...
352.3 ms ✓ UnPack
520.1 ms ✓ StaticArrayInterface → StaticArrayInterfaceOffsetArraysExt
828.2 ms ✓ HostCPUFeatures
7777.4 ms ✓ VectorizationBase
1037.8 ms ✓ SLEEFPirates
1301.0 ms ✓ VectorizedRNG
727.1 ms ✓ VectorizedRNG → VectorizedRNGStaticArraysExt
28536.4 ms ✓ LoopVectorization
1057.7 ms ✓ LoopVectorization → SpecialFunctionsExt
1244.0 ms ✓ LoopVectorization → ForwardDiffExt
6069.4 ms ✓ SimpleChains
11 dependencies successfully precompiled in 46 seconds. 58 already precompiled.
Precompiling LuxLibSLEEFPiratesExt...
2373.5 ms ✓ LuxLib → LuxLibSLEEFPiratesExt
1 dependency successfully precompiled in 3 seconds. 93 already precompiled.
Precompiling ReactantOffsetArraysExt...
21885.2 ms ✓ Reactant → ReactantOffsetArraysExt
1 dependency successfully precompiled in 22 seconds. 82 already precompiled.
Precompiling LuxLibLoopVectorizationExt...
5012.8 ms ✓ LuxLib → LuxLibLoopVectorizationExt
1 dependency successfully precompiled in 5 seconds. 101 already precompiled.
Precompiling LuxSimpleChainsExt...
2126.6 ms ✓ Lux → LuxSimpleChainsExt
1 dependency successfully precompiled in 2 seconds. 122 already precompiled.
2025-05-08 13:45:38.679856: I external/xla/xla/service/service.cc:152] XLA service 0x4f879090 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-08 13:45:38.680453: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1746711938.683242 3123647 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1746711938.683538 3123647 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1746711938.684854 3123647 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1746711938.740060 3123647 cuda_dnn.cc:529] Loaded cuDNN version 90400
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
    )
end
loadmnist (generic function with 1 method)
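As a quick sanity check (this snippet is not part of the original tutorial; the names train_dl and test_dl are purely illustrative), we can load the data and inspect the shape of the first training batch:

# Hypothetical usage sketch: inspect the first training batch.
train_dl, test_dl = loadmnist(128, 0.9)
x, y = first(train_dl)
size(x)  # (28, 28, 1, 128): height × width × channels × batchsize
size(y)  # (10, 128): one-hot encoded labels for the 10 digit classes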
Define the Model
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),    # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),   # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),   # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),    # 10_836 parameters
        layer_3 = Dense(84 => 10),           # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
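The 256 inputs to the first Dense layer come from the convolutional stack: each unpadded 5×5 convolution shrinks the spatial size by 4 and each 2×2 max-pool halves it, so 28 → 24 → 12 → 8 → 4, giving 4 × 4 × 16 = 256 flattened features. A minimal shape check on a dummy input (a hedged sketch, not part of the original tutorial):

# Illustrative shape check on a single dummy MNIST-sized image.
ps, st = Lux.setup(Random.default_rng(), lux_model)   # initialize parameters and states
x_dummy = randn(Float32, 28, 28, 1, 1)                # one 28×28 grayscale image
y_dummy, _ = lux_model(x_dummy, ps, st)               # forward pass
size(y_dummy)                                         # expected: (10, 1)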
We now need to convert the lux_model to a SimpleChains.jl model. We do this by constructing a ToSimpleChainsAdaptor with the input dimensions and applying it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),    # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),   # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),   # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),    # 10_836 parameters
            layer_3 = Dense(84 => 10),           # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
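Because the adapted model keeps the Lux interface, parameter setup and the forward pass look exactly as before, only the computation now runs through SimpleChains.jl. A hedged sketch on dummy data (the names ps_sc and st_sc are illustrative, not from the tutorial):

# Illustrative forward pass through the SimpleChains-backed model.
ps_sc, st_sc = Lux.setup(Random.default_rng(), simple_chains_model)
x_dummy = randn(Float32, 28, 28, 1, 1)
logits, _ = simple_chains_model(x_dummy, ps_sc, st_sc)   # runs the SimpleChains backend
size(logits)                                             # expected: (10, 1)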
Helper Functions
const lossfn = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(Array(y))
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
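Note that logits=Val(true) tells the loss that the model outputs raw, unnormalized scores, so the softmax is folded into the loss computation. A small illustrative example on dummy data (y_true and y_pred are made-up names, not part of the tutorial):

# Illustrative sketch: evaluating the loss on two dummy samples.
y_true = onehotbatch([0, 3], 0:9)      # (10, 2) one-hot targets
y_pred = randn(Float32, 10, 2)         # raw logits for 2 samples
lossfn(y_pred, y_true)                 # scalar mean cross-entropy over the batch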
Define the Training Loop
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
    ps, st = dev(Lux.setup(rng, model))

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
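Training.single_train_step! performs the forward pass, the backward pass, and the optimizer update in a single call and returns the gradients, the loss, auxiliary statistics, and the updated TrainState. A hedged single-step sketch outside the full loop, using the Zygote backend on a random batch (xb, yb, and ts are illustrative names, not from the tutorial):

# Illustrative single gradient step on a dummy batch.
ps, st = Lux.setup(Random.default_rng(), lux_model)
ts = Training.TrainState(lux_model, ps, st, Adam(3.0f-4))
xb = randn(Float32, 28, 28, 1, 128)                # dummy image batch
yb = onehotbatch(rand(0:9, 128), 0:9)              # dummy one-hot labels
grads, loss, stats, ts = Training.single_train_step!(AutoZygote(), lossfn, (xb, yb), ts)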
Finally Training the Model
First we will train the Lux model
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] Time 269.70s Training Accuracy: 9.84% Test Accuracy: 8.59%
[ 2/10] Time 0.20s Training Accuracy: 22.34% Test Accuracy: 18.75%
[ 3/10] Time 0.23s Training Accuracy: 31.64% Test Accuracy: 28.91%
[ 4/10] Time 0.19s Training Accuracy: 43.59% Test Accuracy: 39.06%
[ 5/10] Time 0.20s Training Accuracy: 53.67% Test Accuracy: 50.00%
[ 6/10] Time 0.23s Training Accuracy: 61.33% Test Accuracy: 58.59%
[ 7/10] Time 0.21s Training Accuracy: 67.19% Test Accuracy: 59.38%
[ 8/10] Time 0.19s Training Accuracy: 73.12% Test Accuracy: 69.53%
[ 9/10] Time 0.22s Training Accuracy: 75.86% Test Accuracy: 75.00%
[10/10] Time 0.23s Training Accuracy: 79.30% Test Accuracy: 76.56%
Now we will train the SimpleChains model
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] Time 869.90s Training Accuracy: 28.83% Test Accuracy: 26.56%
[ 2/10] Time 12.10s Training Accuracy: 41.02% Test Accuracy: 35.94%
[ 3/10] Time 12.10s Training Accuracy: 55.08% Test Accuracy: 54.69%
[ 4/10] Time 12.09s Training Accuracy: 64.45% Test Accuracy: 66.41%
[ 5/10] Time 12.13s Training Accuracy: 74.45% Test Accuracy: 75.78%
[ 6/10] Time 12.12s Training Accuracy: 81.80% Test Accuracy: 80.47%
[ 7/10] Time 12.12s Training Accuracy: 83.12% Test Accuracy: 82.81%
[ 8/10] Time 12.25s Training Accuracy: 84.30% Test Accuracy: 84.38%
[ 9/10] Time 12.13s Training Accuracy: 87.03% Test Accuracy: 85.16%
[10/10] Time 12.13s Training Accuracy: 88.36% Test Accuracy: 85.94%
On my local machine, we see a 3-4x speedup when using SimpleChains.jl. The server on which this documentation is built is not ideal for CPU benchmarking, so the speedup may be less pronounced, and there may even be regressions.
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.