MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks, and it is particularly fast on CPU. In this tutorial we will demonstrate how to train a SimpleChains.jl model through the familiar Lux.jl API, using the MNIST tutorial from SimpleChains.jl as a reference.
Package Imports
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Reactant.set_default_backend("cpu")
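Reactant compiles models through XLA, so the same code can target CPUs, GPUs, and TPUs. We pin the backend to the CPU for this tutorial; on a machine with a supported accelerator you could select it instead. An illustrative (commented-out) alternative:
# Reactant.set_default_backend("gpu")  # only if a supported GPU is available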
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST; subsample the dataset in CI to keep the docs build fast
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
    )
end
loadmnist (generic function with 1 method)
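As a quick sanity check (illustrative, not part of the original tutorial), we can instantiate the loaders and inspect the shape of one batch:

train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x)  # (28, 28, 1, 128): H × W × C × batch
size(y)  # (10, 128): one-hot encoded labels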
Define the Model
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),  # 850 parameters
    ),
)  # Total: 47_154 parameters,
   #        plus 0 states.
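The 256 input features of the first dense layer follow from the convolution and pooling arithmetic: each 5×5 convolution (no padding) shrinks the spatial size by 4 (28 → 24, then 12 → 8), and each 2×2 max-pool halves it (24 → 12, then 8 → 4), leaving a 4 × 4 × 16 volume, i.e. 4 · 4 · 16 = 256 features after flattening.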
Next, we convert the lux_model to a SimpleChains.jl model. We do this by constructing a ToSimpleChainsAdaptor with the input dimensions and applying it to the Lux model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)  # Total: 47_154 parameters,
   #        plus 0 states.
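The returned SimpleChainsLayer is a regular Lux layer: SimpleChains.jl handles the computation while the Lux API stays unchanged. A minimal illustrative forward pass:

rng = Random.default_rng()
ps, st = Lux.setup(rng, simple_chains_model)
x = randn(rng, Float32, 28, 28, 1, 1)
y, _ = simple_chains_model(x, ps, st)
size(y)  # (10, 1): logits for a single image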
Helper Functions
const lossfn = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(Array(y))
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
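To make the helper behavior concrete, here is a small illustrative example of how onecold and the logit-based cross-entropy loss operate:

y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(y)                       # [1, 4, 10]: row indices of the hot entries
ŷ = randn(Float32, 10, 3)        # raw logits; no softmax needed with logits=Val(true)
lossfn(ŷ, y)                     # mean cross-entropy over the batch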
Define the Training Loop
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
    ps, st = dev(Lux.setup(rng, model))

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(x_ra, ps, Lux.testmode(st))
        end
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            # single_train_step! returns (grads, loss, stats, train_state);
            # we only need the updated train_state here
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
Finally Training the Model
First we will train the Lux model
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] Time 388.15s Training Accuracy: 24.92% Test Accuracy: 28.91%
[ 2/10] Time 0.10s Training Accuracy: 32.50% Test Accuracy: 36.72%
[ 3/10] Time 0.10s Training Accuracy: 44.53% Test Accuracy: 42.97%
[ 4/10] Time 0.10s Training Accuracy: 55.55% Test Accuracy: 54.69%
[ 5/10] Time 0.10s Training Accuracy: 63.12% Test Accuracy: 60.94%
[ 6/10] Time 0.10s Training Accuracy: 67.97% Test Accuracy: 60.94%
[ 7/10] Time 0.09s Training Accuracy: 73.52% Test Accuracy: 67.97%
[ 8/10] Time 0.10s Training Accuracy: 76.56% Test Accuracy: 74.22%
[ 9/10] Time 0.10s Training Accuracy: 80.39% Test Accuracy: 77.34%
[10/10] Time 0.10s Training Accuracy: 81.33% Test Accuracy: 81.25%
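Note that the first epoch is dominated by one-time compilation: for the Reactant run, XLA compiles the training step on its first invocation, and the SimpleChains run below pays a similar Julia compilation cost. Subsequent epochs reflect the steady-state time per epoch.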
Now we will train the SimpleChains model
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] Time 861.97s Training Accuracy: 29.77% Test Accuracy: 32.03%
[ 2/10] Time 12.09s Training Accuracy: 41.09% Test Accuracy: 38.28%
[ 3/10] Time 12.09s Training Accuracy: 49.45% Test Accuracy: 46.09%
[ 4/10] Time 12.09s Training Accuracy: 55.55% Test Accuracy: 47.66%
[ 5/10] Time 12.10s Training Accuracy: 69.30% Test Accuracy: 64.06%
[ 6/10] Time 12.09s Training Accuracy: 73.36% Test Accuracy: 69.53%
[ 7/10] Time 12.09s Training Accuracy: 76.02% Test Accuracy: 71.09%
[ 8/10] Time 12.09s Training Accuracy: 80.47% Test Accuracy: 78.91%
[ 9/10] Time 12.15s Training Accuracy: 83.05% Test Accuracy: 80.47%
[10/10] Time 12.09s Training Accuracy: 85.70% Test Accuracy: 82.81%
On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup may be less pronounced here, and there may even be regressions.
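If you want to measure the forward-pass difference yourself, the following sketch compares a single CPU batch. It assumes BenchmarkTools.jl is installed (it is not among this tutorial's imports):

using BenchmarkTools

rng = Random.default_rng()
x = randn(rng, Float32, 28, 28, 1, 128)
ps_lux, st_lux = Lux.setup(rng, lux_model)
ps_sc, st_sc = Lux.setup(rng, simple_chains_model)

@btime first($lux_model($x, $ps_lux, $(Lux.testmode(st_lux))))
@btime first($simple_chains_model($x, $ps_sc, $(Lux.testmode(st_sc))))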
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.