MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a SimpleChains.jl model through the same API that Lux.jl provides. We use the tutorial from SimpleChains.jl as a reference.
Package Imports
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Reactant.set_default_backend("cpu")
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end
    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true, partial = false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false, partial = false),
    )
end
loadmnist (generic function with 1 method)
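The reshape and one-hot steps inside `loadmnist` can be checked on synthetic data without downloading MNIST. The sketch below uses only Base Julia. The names `fake_imgs`, `fake_labels`, and `onehot_sketch` are illustrative stand-ins (assumptions), not MLDatasets or OneHotArrays calls. The batch arithmetic at the end assumes the `CI` subset of 1500 samples, a 0.9 split, and a batch size of 128.

```julia
# Synthetic stand-in for `dataset.features`: five 28×28 images.
fake_imgs = rand(Float32, 28, 28, 5)
fake_labels = [3, 0, 9, 3, 7]

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C, N).
x_data = reshape(fake_imgs, size(fake_imgs, 1), size(fake_imgs, 2), 1, size(fake_imgs, 3))

# Hand-written one-hot encoding: one row per class, one column per sample.
onehot_sketch(labels, classes) = [c == l for c in classes, l in labels]
y_data = onehot_sketch(fake_labels, 0:9)

# With `partial = false`, an incomplete final batch is dropped, so the number
# of batches is the floor of (samples / batchsize).
n_total = 1500
n_train = floor(Int, 0.9 * n_total)        # 1350 training samples
n_test = n_total - n_train                 # 150 test samples
train_batches = fld(n_train, 128)          # 10 full batches
test_batches = fld(n_test, 128)            # 1 full batch
```

Under these assumptions the test loader yields a single batch of 128 samples, so test accuracies move in steps of 1/128.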
Define the Model
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
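The `256` input width of the first `Dense` layer and the parameter counts printed above follow from the layer shapes. The sketch below reproduces that arithmetic in Base Julia, assuming unpadded stride-1 convolutions and non-overlapping pooling windows. The helper names (`conv_out`, `pool_out`, `conv_params`, `dense_params`) are made up for illustration.

```julia
# Output length along one spatial dimension for an unpadded, stride-1 convolution.
conv_out(n, k) = n - k + 1
# Output length for non-overlapping pooling with window `w`.
pool_out(n, w) = n ÷ w

h = 28
h = conv_out(h, 5)   # 24
h = pool_out(h, 2)   # 12
h = conv_out(h, 5)   # 8
h = pool_out(h, 2)   # 4
flat = h * h * 16    # 4 * 4 * 16 = 256 features after flattening

# Parameter counts: weights plus one bias per output channel or unit.
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params(5, 1, 6),     # 156
    conv_params(5, 6, 16),    # 2_416
    dense_params(256, 128),   # 32_896
    dense_params(128, 84),    # 10_836
    dense_params(84, 10),     # 850
]
total = sum(counts)           # 47_154
```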
We now need to convert lux_model to a SimpleChains.jl model. To do this, we construct a ToSimpleChainsAdaptor with the input dimensions (excluding the batch dimension) and apply it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
Helper Functions
const lossfn = CrossEntropyLoss(; logits = Val(true))
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
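To make the two helpers concrete, the sketch below computes a cross-entropy on raw logits and an argmax-based accuracy in Base Julia. It is an illustration under assumptions, not the Lux or OneHotArrays implementation: the functions `logsoftmax_sketch`, `crossentropy_sketch`, and `onecold_sketch` are hypothetical, and averaging the loss over the batch is assumed.

```julia
# Numerically stable log-softmax over the class dimension (dims = 1).
function logsoftmax_sketch(logits)
    shifted = logits .- maximum(logits; dims = 1)
    return shifted .- log.(sum(exp.(shifted); dims = 1))
end

# Cross-entropy on raw logits, averaged over the batch (assumed aggregation).
function crossentropy_sketch(logits, y_onehot)
    per_sample = -sum(y_onehot .* logsoftmax_sketch(logits); dims = 1)
    return sum(per_sample) / length(per_sample)
end

# Index of the largest entry in each column.
onecold_sketch(m) = [argmax(col) for col in eachcol(m)]

# 3 classes (rows) × 3 samples (columns).
logits = Float32[2 0 0; 0 3 0; 0 0 1]   # predicted classes: 1, 2, 3
y = Bool[1 0 1; 0 1 0; 0 0 0]           # target classes:    1, 2, 1

loss = crossentropy_sketch(logits, y)
acc = sum(onecold_sketch(logits) .== onecold_sketch(y)) / size(y, 2)
```

Two of the three predictions match their targets here, so `acc` is 2/3. With all-zero logits the loss reduces to `log(3)`, the value for a uniform guess over three classes.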
Define the Training Loop
function train(model, dev = cpu_device(); rng = Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
    ps, st = Lux.setup(rng, model) |> dev
    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))
    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end
    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime
        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100
        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end
    return tr_acc, te_acc
end
train (generic function with 2 methods)
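`Training.TrainState` and `Training.single_train_step!` can also be tried in isolation on a toy problem. The following is a minimal sketch, assuming Lux, Optimisers, and Zygote are installed. The `Dense(4 => 2)` model, the random data, the `MSELoss` objective, and the helper `run_steps` are illustrative choices that do not come from this tutorial.

```julia
using Lux, Optimisers, Zygote, Random

rng = Random.Xoshiro(0)
toy_model = Dense(4 => 2)                  # hypothetical toy model
ps, st = Lux.setup(rng, toy_model)
toy_state = Training.TrainState(toy_model, ps, st, Adam(3.0f-4))

x = randn(rng, Float32, 4, 8)              # 8 random input samples
y = randn(rng, Float32, 2, 8)              # random regression targets

# Run `n` optimisation steps on the same batch and record the loss of each step.
function run_steps(state, data, n)
    losses = Float32[]
    for _ in 1:n
        # Returns (gradients, loss, stats, updated state), as in `train` above.
        _, loss, _, state = Training.single_train_step!(
            AutoZygote(), MSELoss(), data, state
        )
        push!(losses, loss)
    end
    return state, losses
end

toy_state, losses = run_steps(toy_state, (x, y), 3)
```

The updated parameters and states are available as `toy_state.parameters` and `toy_state.states`, the same fields the training loop passes to `accuracy`.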
Finally Training the Model
First, we train the Lux model on the Reactant device.
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] Time 458.64s Training Accuracy: 16.88% Test Accuracy: 13.28%
[ 2/10] Time 0.39s Training Accuracy: 30.47% Test Accuracy: 25.00%
[ 3/10] Time 0.37s Training Accuracy: 41.33% Test Accuracy: 36.72%
[ 4/10] Time 0.38s Training Accuracy: 55.47% Test Accuracy: 43.75%
[ 5/10] Time 0.38s Training Accuracy: 63.67% Test Accuracy: 56.25%
[ 6/10] Time 0.39s Training Accuracy: 69.92% Test Accuracy: 64.84%
[ 7/10] Time 0.39s Training Accuracy: 74.30% Test Accuracy: 70.31%
[ 8/10] Time 0.39s Training Accuracy: 78.05% Test Accuracy: 75.78%
[ 9/10] Time 0.40s Training Accuracy: 80.16% Test Accuracy: 74.22%
[10/10] Time 0.39s Training Accuracy: 82.11% Test Accuracy: 76.56%
Now we train the SimpleChains model on the default CPU device.
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] Time 919.45s Training Accuracy: 26.95% Test Accuracy: 19.53%
[ 2/10] Time 12.46s Training Accuracy: 40.62% Test Accuracy: 35.16%
[ 3/10] Time 12.47s Training Accuracy: 52.34% Test Accuracy: 50.78%
[ 4/10] Time 12.44s Training Accuracy: 63.28% Test Accuracy: 61.72%
[ 5/10] Time 12.41s Training Accuracy: 74.06% Test Accuracy: 70.31%
[ 6/10] Time 12.39s Training Accuracy: 79.77% Test Accuracy: 73.44%
[ 7/10] Time 12.40s Training Accuracy: 84.06% Test Accuracy: 78.12%
[ 8/10] Time 12.51s Training Accuracy: 85.86% Test Accuracy: 82.03%
[ 9/10] Time 12.46s Training Accuracy: 86.48% Test Accuracy: 82.81%
[10/10] Time 12.43s Training Accuracy: 88.98% Test Accuracy: 82.81%
On my local machine, we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup there may be smaller, and there may even be regressions.
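In the logs above, the first epoch takes far longer than the rest (458.64s versus about 0.39s per epoch for the Lux model), which is consistent with one-time compilation being counted in that epoch. When comparing backends, it helps to time a warm-up call separately from steady-state calls. The sketch below shows the pattern in Base Julia, where `work` is a hypothetical stand-in for one training epoch.

```julia
# Hypothetical workload standing in for one training epoch.
work(x) = sum(abs2, x .* 2.0f0)

x = rand(Float32, 1_000)

# The first call includes JIT compilation of `work` for this argument type.
t_first = @elapsed work(x)

# Later calls reuse the compiled code; take the best of several runs.
t_steady = minimum(@elapsed(work(x)) for _ in 1:5)

println("first call: ", t_first, " s, steady state: ", t_steady, " s")
```

For the epoch timings in this tutorial, the equivalent is to disregard the first epoch and compare the remaining ones.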
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.