MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial we demonstrate how to train a SimpleChains.jl model through the same API that Lux.jl provides. We use the tutorial from SimpleChains.jl as a reference.
Package Imports
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Reactant.set_default_backend("cpu")
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
    )
end
loadmnist (generic function with 1 method)
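The shape handling in `loadmnist` can be checked without downloading the dataset. The following is a minimal sketch in plain Julia, using random stand-in images and made-up labels instead of MNIST. The helper `full_batches` and the hand-rolled one-hot comparison are illustrative assumptions, not part of MLUtils or OneHotArrays, and the integer split point only approximates what `splitobs` does with `at=0.9`.

```julia
# Stand-in for `dataset.features`: 5 fake 28×28 grayscale images laid out as (H, W, N).
imgs = rand(Float32, 28, 28, 5)
labels_raw = [3, 0, 9, 1, 3]   # made-up digit labels

# Insert a singleton channel dimension to obtain (H, W, C, N), as `loadmnist` does.
x_data = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Hand-rolled one-hot encoding: row i is `true` where the label equals (0:9)[i].
# This mirrors the 10×N pattern of `onehotbatch(labels_raw, 0:9)` with a plain BitMatrix.
y_data = (0:9) .== permutedims(labels_raw)

# Hypothetical helper: number of minibatches when `partial=false` drops the remainder.
full_batches(nobs, batchsize) = nobs ÷ batchsize

n_ci = 1500                # subset size used when ENV["CI"] is "true"
n_train = n_ci * 9 ÷ 10    # integer approximation of the 0.9 train split

println(size(x_data))
println(size(y_data))
println(full_batches(n_train, 128), " training batches, ",
        full_batches(n_ci - n_train, 128), " test batch(es)")
```

With `partial=false`, observations that do not fill a complete batch are skipped, so the accuracies reported later are computed over whole batches only.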
Define the Model
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
) # Total: 47_154 parameters,
# plus 0 states.
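The input width of `Dense(256 => 128, relu)` and the parameter counts printed above follow from simple shape arithmetic. The sketch below reproduces them in plain Julia, assuming stride-1 convolutions without padding and non-overlapping 2×2 pooling (what the layer constructors above use when no extra options are passed). The helper names are hypothetical and exist only for this illustration.

```julia
# Output side length of a stride-1 convolution without padding.
conv_out(n, k) = n - k + 1
# Output side length of non-overlapping pooling with window `w`.
pool_out(n, w) = n ÷ w

side = 28
side = conv_out(side, 5)   # after the first 5×5 convolution
side = pool_out(side, 2)   # after the first 2×2 max-pool
side = conv_out(side, 5)   # after the second 5×5 convolution
side = pool_out(side, 2)   # after the second 2×2 max-pool
flat_features = side * side * 16   # 16 output channels, flattened

# Weights plus one bias per output channel / output unit.
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params(5, 1, 6),
    conv_params(5, 6, 16),
    dense_params(256, 128),
    dense_params(128, 84),
    dense_params(84, 10),
]
total = sum(counts)

println("flattened features: ", flat_features)
println("per-layer parameters: ", counts)
println("total parameters: ", total)
```

The flattened feature count matches the `256` in the first dense layer, and the total agrees with the `47_154` reported in the model summary.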
Next, we convert lux_model to a SimpleChains.jl model. To do this, we construct a ToSimpleChainsAdaptor with the input dimensions and apply it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
),
) # Total: 47_154 parameters,
# plus 0 states.
Helper Functions
const lossfn = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(Array(y))
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
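To make the two helpers concrete, here is a minimal plain-Julia sketch of what a cross-entropy loss on logits and a one-cold accuracy compute for a (classes × batch) matrix. It is not the Lux or OneHotArrays implementation: the function names are hypothetical, and averaging the loss over the batch is an assumption about the default behaviour of `CrossEntropyLoss`.

```julia
# Numerically stable log-softmax over each column (one column per sample).
function logsoftmax_cols(logits)
    shifted = logits .- maximum(logits; dims=1)
    return shifted .- log.(sum(exp.(shifted); dims=1))
end

# Cross-entropy between logits and one-hot targets, averaged over the batch (assumed).
cross_entropy_from_logits(logits, y) =
    -sum(y .* logsoftmax_cols(logits)) / size(logits, 2)

# Index of the largest entry in each column, the role `onecold` plays above.
onecold_cols(m) = [argmax(col) for col in eachcol(m)]

# Fraction of samples whose predicted class matches the target class.
accuracy_cols(logits, y) = sum(onecold_cols(logits) .== onecold_cols(y)) / size(y, 2)

logits = Float32[2 0 0; 0 3 1; 0 0 0]   # 3 classes × 3 samples
targets = Bool[1 0 0; 0 1 0; 0 0 1]     # true classes: 1, 2, 3

println("loss     = ", cross_entropy_from_logits(logits, targets))
println("accuracy = ", accuracy_cols(logits, targets))
```

In this toy batch the first two samples are classified correctly and the third is not, so the accuracy is two thirds. Working on logits directly avoids computing a softmax and then taking its logarithm, which is less stable numerically.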
Define the Training Loop
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
    ps, st = dev(Lux.setup(rng, model))

    # Use Enzyme for differentiation on Reactant devices, Zygote otherwise
    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    # On Reactant devices, compile the forward pass once for use in `accuracy`
    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, train_dataloader
            ) * 100
        te_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, test_dataloader
            ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
Finally Training the Model
First, we train the Lux model:
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] Time 312.09s Training Accuracy: 16.09% Test Accuracy: 19.53%
[ 2/10] Time 0.23s Training Accuracy: 28.12% Test Accuracy: 23.44%
[ 3/10] Time 0.24s Training Accuracy: 46.72% Test Accuracy: 42.97%
[ 4/10] Time 0.23s Training Accuracy: 57.42% Test Accuracy: 56.25%
[ 5/10] Time 0.23s Training Accuracy: 66.48% Test Accuracy: 63.28%
[ 6/10] Time 0.21s Training Accuracy: 73.05% Test Accuracy: 68.75%
[ 7/10] Time 0.23s Training Accuracy: 76.25% Test Accuracy: 70.31%
[ 8/10] Time 0.22s Training Accuracy: 79.45% Test Accuracy: 74.22%
[ 9/10] Time 0.23s Training Accuracy: 81.25% Test Accuracy: 80.47%
[10/10] Time 0.22s Training Accuracy: 82.89% Test Accuracy: 81.25%
Now we train the SimpleChains model:
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] Time 1081.64s Training Accuracy: 29.77% Test Accuracy: 27.34%
[ 2/10] Time 12.20s Training Accuracy: 50.23% Test Accuracy: 51.56%
[ 3/10] Time 12.19s Training Accuracy: 57.81% Test Accuracy: 54.69%
[ 4/10] Time 12.21s Training Accuracy: 62.97% Test Accuracy: 57.81%
[ 5/10] Time 12.21s Training Accuracy: 70.70% Test Accuracy: 62.50%
[ 6/10] Time 12.17s Training Accuracy: 74.84% Test Accuracy: 67.19%
[ 7/10] Time 12.16s Training Accuracy: 77.89% Test Accuracy: 74.22%
[ 8/10] Time 12.25s Training Accuracy: 80.78% Test Accuracy: 71.88%
[ 9/10] Time 12.16s Training Accuracy: 82.27% Test Accuracy: 79.69%
[10/10] Time 12.18s Training Accuracy: 85.00% Test Accuracy: 78.12%
On my local machine, SimpleChains.jl gives a 3-4x speedup. The server that builds this documentation is not ideal for CPU benchmarking, so the speedup here may be smaller, and there may even be regressions.
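When reading the two logs above, it helps to separate the first epoch from the rest. The sketch below copies the logged epoch times into plain Julia vectors and averages the remaining epochs. Treating the first epoch as dominated by one-time compilation is an assumption, and `steady_state` is a hypothetical helper written for this illustration. Note also that the two runs use different devices (`reactant_device()` versus the default `cpu_device()`), so the ratio describes this particular build rather than SimpleChains.jl in general.

```julia
# Per-epoch wall-clock times in seconds, copied from the two training logs above.
lux_reactant = [312.09, 0.23, 0.24, 0.23, 0.23, 0.21, 0.23, 0.22, 0.23, 0.22]
simple_chains = [1081.64, 12.20, 12.19, 12.21, 12.21, 12.17, 12.16, 12.25, 12.16, 12.18]

# Mean epoch time after dropping the first `warmup` epochs.
# Assumption: the dropped epochs mostly measure compilation, not training.
steady_state(times; warmup=1) =
    sum(times[(warmup + 1):end]) / (length(times) - warmup)

lux_steady = steady_state(lux_reactant)
sc_steady = steady_state(simple_chains)

println("Lux + Reactant: ", round(lux_steady; digits=3), " s/epoch")
println("SimpleChains:   ", round(sc_steady; digits=3), " s/epoch")
println("ratio:          ", round(sc_steady / lux_steady; digits=1))
```

On this build server the SimpleChains run is the slower of the two once the first epoch is excluded, which is the kind of regression the paragraph above warns about.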
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.