MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we show how to train a model with SimpleChains.jl while keeping the same API as Lux.jl. We use the corresponding tutorial from SimpleChains.jl as a reference.
Package Imports
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Reactant.set_default_backend("cpu")
Loading MNIST
function loadmnist(batchsize, train_split)
# Load MNIST
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
dataset = MNIST(; split=:train)
if N !== nothing
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
else
imgs = dataset.features
labels_raw = dataset.targets
end
# Process images into (H, W, C, BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
# Don't shuffle the test data
DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
)
end
loadmnist (generic function with 1 method)
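As a quick sanity check on what loadmnist(128, 0.9) produces in CI mode (N = 1500), the sketch below reproduces the array shapes and minibatch counts in plain Julia. The helpers num_batches and onehot_dense are hypothetical stand-ins for DataLoader(...; partial=false) and onehotbatch. The split point round(Int, 0.9 * N) is an assumption about how splitobs rounds, and random data replaces the MNIST images.

```julia
# Hypothetical helper: number of minibatches a loader yields over `n`
# observations. With `partial=false` the last, incomplete batch is dropped.
num_batches(n, batchsize; partial) = partial ? cld(n, batchsize) : fld(n, batchsize)

# Hypothetical equivalent of `onehotbatch(labels, 0:9)` as a dense Bool matrix.
onehot_dense(labels, classes) = [l == c for c in classes, l in labels]

N = 1500                                        # CI mode in `loadmnist`
imgs = rand(Float32, 28, 28, N)                 # stand-in for `dataset.features`
x = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))  # (H, W, C, N)
y = onehot_dense(rand(0:9, N), 0:9)             # 10 × N one-hot targets

n_train = round(Int, 0.9 * N)                   # assumed split point for train_split = 0.9
n_test = N - n_train
println("x: ", size(x), "  y: ", size(y))
println("train: ", n_train, " samples -> ", num_batches(n_train, 128; partial=false), " batches")
println("test:  ", n_test, " samples -> ", num_batches(n_test, 128; partial=false), " batches")
```

It reports 1350 training samples in 10 batches and 150 test samples in a single batch. This is consistent with the training logs further down, where every test accuracy is a multiple of 1/128 (for example, 82.03% = 105/128).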
Define the Model
lux_model = Chain(
Conv((5, 5), 1 => 6, relu),
MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu),
MaxPool((2, 2)),
FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
) # Total: 47_154 parameters,
# plus 0 states.
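The 256 input features of the first Dense layer and the 47_154 total in the printed summary can be checked by hand. This sketch assumes the Lux defaults of no padding and stride 1 for Conv, and a stride equal to the window for MaxPool; the helper names are hypothetical.

```julia
# Output side length of a convolution with no padding and stride 1, and of a
# max-pool whose stride equals its window (assumed defaults, as used above).
conv_out(n, k) = n - k + 1
pool_out(n, w) = fld(n, w)

# Trainable parameters: weights plus one bias per output channel/unit.
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

side = pool_out(conv_out(pool_out(conv_out(28, 5), 2), 5), 2)  # 28 → 24 → 12 → 8 → 4
flat = side * side * 16                                         # features per sample after flattening

total = conv_params(5, 1, 6) + conv_params(5, 6, 16) +
        dense_params(flat, 128) + dense_params(128, 84) + dense_params(84, 10)
println("flattened size = ", flat, ", total parameters = ", total)
```

It prints flattened size = 256 and total parameters = 47154, matching the per-layer counts in the summary above.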
Next, we convert lux_model to a SimpleChains.jl model. To do this, we construct a ToSimpleChainsAdaptor with the per-sample input dimensions (28, 28, 1), which exclude the batch dimension, and apply it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
),
) # Total: 47_154 parameters,
# plus 0 states.
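A minimal usage sketch for the converted model, assuming the same package versions as this tutorial: the layer returned by the adaptor is set up and called like any other Lux layer. Only the per-sample input size (28, 28, 1) is fixed, so a batch of 4 random images should yield a 10 × 4 matrix of logits. The names model and sc_model exist only for this example.

```julia
using Lux, Random
using SimpleChains: SimpleChains   # loading SimpleChains activates the Lux extension

# Same architecture as `lux_model` above.
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)

# Per-sample input size (H, W, C); the batch dimension is not included.
sc_model = ToSimpleChainsAdaptor((28, 28, 1))(model)

ps, st = Lux.setup(Random.default_rng(), sc_model)
x = rand(Float32, 28, 28, 1, 4)   # a batch of 4 random "images"
y, _ = sc_model(x, ps, st)        # 10 logits per sample
println(size(y))
```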
Helper Functions
const lossfn = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(Array(y))
predicted_class = onecold(Array(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy (generic function with 1 method)
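To make these helpers concrete, here is a plain-Julia sketch of what lossfn and accuracy compute for one batch. It assumes CrossEntropyLoss(; logits=Val(true)) applies a log-softmax over the class dimension and averages over the batch (its default aggregation, with no label smoothing). The names logsoftmax_cols, onecold_idx, and accuracy_frac are hypothetical, not Lux or OneHotArrays functions.

```julia
using Statistics: mean

# Numerically stable log-softmax down each column (classes along dims=1).
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Cross-entropy computed from raw logits, averaged over the batch.
logit_crossentropy(ŷ, y) = mean(-sum(y .* logsoftmax_cols(ŷ); dims=1))

# Like `onecold` with default labels: row index of the largest entry per column.
onecold_idx(z::AbstractMatrix) = [argmax(view(z, :, j)) for j in axes(z, 2)]

# Fraction of samples whose predicted class equals the target class.
accuracy_frac(ŷ, y) = mean(onecold_idx(ŷ) .== onecold_idx(y))

logits  = Float32[2 0 0; 0 3 0; 0 0 -1]   # 3 classes × 3 samples
targets = Float32[1 0 0; 0 1 1; 0 0 0]    # one-hot targets, one column per sample
println("loss = ", logit_crossentropy(logits, targets))
println("accuracy = ", accuracy_frac(logits, targets))
```

Here the third sample is misclassified (class 1 predicted, class 2 expected), so the accuracy is 2/3.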
Define the Training Loop
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
ps, st = dev(Lux.setup(rng, model))
vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))
if dev isa ReactantDevice
x_ra = first(test_dataloader)[1]
model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
else
model_compiled = model
end
# Let's train the model
nepochs = 10
tr_acc, te_acc = 0.0, 0.0
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
_, _, _, train_state = Training.single_train_step!(
vjp, lossfn, (x, y), train_state
)
end
ttime = time() - stime
tr_acc =
accuracy(
model_compiled, train_state.parameters, train_state.states, train_dataloader
) * 100
te_acc =
accuracy(
model_compiled, train_state.parameters, train_state.states, test_dataloader
) * 100
@printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
%.2f%%\n" epoch nepochs ttime tr_acc te_acc
end
return tr_acc, te_acc
end
train (generic function with 2 methods)
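Each call to Training.single_train_step! computes the loss gradient with the chosen AD backend and then applies the optimiser update to the parameters. The sketch below spells out that pattern with a hand-written gradient and the standard Adam update rule on a toy least-squares problem. It assumes Adam's usual defaults (β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8); AdamState and adam_step! are hypothetical names, not Optimisers.jl internals. A larger step size than the tutorial's 3.0f-4 is used so the toy problem converges quickly.

```julia
using Random, Statistics

# Hypothetical container for Adam's running moment estimates.
mutable struct AdamState
    m::Vector{Float64}   # first-moment (mean) estimate
    v::Vector{Float64}   # second-moment (uncentered variance) estimate
    t::Int               # step counter used for bias correction
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

# One Adam update of θ in place, given the gradient g.
function adam_step!(θ, g, s::AdamState; η=3e-4, β1=0.9, β2=0.999, ϵ=1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2
    m̂ = s.m ./ (1 - β1^s.t)          # bias-corrected moments
    v̂ = s.v ./ (1 - β2^s.t)
    θ .-= η .* m̂ ./ (sqrt.(v̂) .+ ϵ)
    return θ
end

# Toy problem: least squares, loss(θ) = mean((Xθ - y).^2).
rng = MersenneTwister(0)
X = randn(rng, 100, 3)
y = X * [1.0, -2.0, 0.5]
loss(θ) = mean(abs2, X * θ .- y)
grad(θ) = (2 / size(X, 1)) .* (X' * (X * θ .- y))

θ = zeros(3)
state = AdamState(3)
initial_loss = loss(θ)
for _ in 1:500                       # each step: compute gradient, then update
    adam_step!(θ, grad(θ), state; η=1e-2)
end
final_loss = loss(θ)
println("loss: ", initial_loss, " -> ", final_loss)
```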
Finally Training the Model
First, we train the Lux model on Reactant's CPU backend by passing reactant_device().
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] Time 267.57s Training Accuracy: 17.27% Test Accuracy: 18.75%
[ 2/10] Time 0.21s Training Accuracy: 35.08% Test Accuracy: 32.03%
[ 3/10] Time 0.20s Training Accuracy: 51.88% Test Accuracy: 45.31%
[ 4/10] Time 0.21s Training Accuracy: 65.78% Test Accuracy: 54.69%
[ 5/10] Time 0.20s Training Accuracy: 70.78% Test Accuracy: 58.59%
[ 6/10] Time 0.21s Training Accuracy: 75.62% Test Accuracy: 67.19%
[ 7/10] Time 0.21s Training Accuracy: 78.75% Test Accuracy: 75.78%
[ 8/10] Time 0.20s Training Accuracy: 80.94% Test Accuracy: 78.12%
[ 9/10] Time 0.21s Training Accuracy: 84.22% Test Accuracy: 81.25%
[10/10] Time 0.22s Training Accuracy: 86.56% Test Accuracy: 82.03%
Next, we train the SimpleChains model. No device is passed, so train falls back to cpu_device() and uses Zygote for the gradients.
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] Time 870.34s Training Accuracy: 12.42% Test Accuracy: 12.50%
[ 2/10] Time 12.24s Training Accuracy: 36.72% Test Accuracy: 32.81%
[ 3/10] Time 12.36s Training Accuracy: 55.62% Test Accuracy: 50.00%
[ 4/10] Time 12.23s Training Accuracy: 69.06% Test Accuracy: 60.16%
[ 5/10] Time 12.21s Training Accuracy: 76.64% Test Accuracy: 71.09%
[ 6/10] Time 12.20s Training Accuracy: 78.75% Test Accuracy: 73.44%
[ 7/10] Time 12.22s Training Accuracy: 81.02% Test Accuracy: 78.12%
[ 8/10] Time 12.26s Training Accuracy: 81.25% Test Accuracy: 78.91%
[ 9/10] Time 12.22s Training Accuracy: 84.06% Test Accuracy: 81.25%
[10/10] Time 12.27s Training Accuracy: 85.78% Test Accuracy: 83.59%
On my local machine, SimpleChains.jl gives a 3-4x speedup. The server this documentation is built on is not well suited to CPU benchmarking, so the speedup here may be smaller, and there may even be regressions.
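In both training logs, the first epoch is far slower than the rest (267.57s and 870.34s, versus roughly 0.2s and 12s afterwards), which most likely reflects one-time compilation. When measuring a speedup on your own machine, one simple approach is to discard a warm-up call and take the fastest of several timed runs, as in this sketch. The helper steady_state_time is hypothetical; BenchmarkTools.jl offers a more careful version of the same idea.

```julia
# Hypothetical helper: call `f` once so compilation is not timed, then
# return the fastest of `n` timed calls, in seconds.
function steady_state_time(f; n=5)
    f()                                           # warm-up (includes compilation)
    return minimum(@elapsed(f()) for _ in 1:n)
end

data = rand(Float32, 10^6)

# Two implementations of the same reduction, compared at steady state.
function sumsq_loop(xs)
    s = zero(eltype(xs))
    for v in xs
        s += v * v
    end
    return s
end

t_loop = steady_state_time(() -> sumsq_loop(data))
t_builtin = steady_state_time(() -> sum(abs2, data))
println("loop: ", t_loop, " s; sum(abs2, data): ", t_builtin, " s; ratio: ", t_loop / t_builtin)
```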
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.