
MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks on the CPU. In this tutorial, we demonstrate how to train a SimpleChains.jl model through the same API that Lux.jl uses, following the corresponding tutorial from SimpleChains.jl as a reference.

Package Imports

julia
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Reactant.set_default_backend("cpu")

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
    )
end
loadmnist (generic function with 1 method)
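The shape handling in loadmnist can be sketched with the standard library alone. The snippet below uses random stand-in data instead of MLDatasets.MNIST, a hand-rolled dense one-hot encoder in place of OneHotArrays.onehotbatch, and a hypothetical batch_counts helper that assumes splitobs places the split point at round(Int, at * n). It illustrates why, when the CI variable is set (N = 1500), a 0.9 split with batchsize = 128 and partial = false leaves 10 training batches and a single 128-sample test batch.

```julia
using Random

# Stand-in data (NOT real MNIST): 1500 random 28×28 "images" and labels in 0:9.
rng = Xoshiro(0)
imgs = rand(rng, Float32, 28, 28, 1500)
labels = rand(rng, 0:9, 1500)

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C = 1, N)
x = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Hypothetical dense one-hot encoder (stand-in for OneHotArrays.onehotbatch):
# column j has a single `true` in the row of class labels[j].
function onehot_dense(labels, classes)
    y = zeros(Bool, length(classes), length(labels))
    for (j, l) in enumerate(labels)
        y[findfirst(==(l), classes), j] = true
    end
    return y
end
y = onehot_dense(labels, 0:9)

# Hypothetical helper: split sizes and number of FULL minibatches per epoch when
# the last partial batch is dropped (partial = false).
# Assumes the split point is round(Int, at * n).
function batch_counts(n, at, batchsize)
    ntrain = round(Int, at * n)
    ntest = n - ntrain
    return (ntrain, ntest, div(ntrain, batchsize), div(ntest, batchsize))
end

batch_counts(1500, 0.9, 128)    # (1350, 150, 10, 1)
batch_counts(60000, 0.9, 128)   # (54000, 6000, 421, 46)
```

With the full 60,000-image training set, the same arithmetic gives 421 training batches and 46 test batches per epoch.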

Define the Model

julia
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
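The parameter counts in the printed summary follow from the usual formulas, assuming every Conv and Dense layer carries a bias (the Lux default); MaxPool and FlattenLayer have no parameters. A quick check, with conv_params and dense_params as illustrative helper names:

```julia
# 2D convolution: one kernel of size k per (input, output) channel pair,
# plus one bias per output channel.
conv_params(k, cin, cout) = prod(k) * cin * cout + cout
# Dense layer: nin × nout weight matrix plus a bias vector of length nout.
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params((5, 5), 1, 6),   # layer_1
    conv_params((5, 5), 6, 16),  # layer_3
    dense_params(256, 128),      # layer_6.layer_1
    dense_params(128, 84),       # layer_6.layer_2
    dense_params(84, 10),        # layer_6.layer_3
]
total = sum(counts)              # 47154
```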

Next, we convert lux_model to a SimpleChains.jl model. To do this, we construct a ToSimpleChainsAdaptor with the input dimensions of a single sample (excluding the batch dimension) and apply it to the model.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
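As we understand it, SimpleChains.jl works with layer sizes that are known ahead of time, which is why the adaptor needs the input dimensions: every intermediate shape, including the 256 inputs of the first Dense layer, follows from the (28, 28, 1) input. A minimal sketch of that shape propagation, assuming stride-1 unpadded convolutions and non-overlapping pooling (the defaults for the layers above); conv_out, pool_out, and feature_shape are hypothetical helpers:

```julia
# Spatial size after a stride-1, unpadded ("valid") convolution with kernel size k.
conv_out(n, k) = n - k + 1
# Spatial size after non-overlapping pooling (window = stride = p).
pool_out(n, p) = div(n, p)

function feature_shape(input_dims)
    h, w, _ = input_dims
    h, w = conv_out(h, 5), conv_out(w, 5)   # Conv((5, 5), 1 => 6):  28 -> 24
    h, w = pool_out(h, 2), pool_out(w, 2)   # MaxPool((2, 2)):       24 -> 12
    h, w = conv_out(h, 5), conv_out(w, 5)   # Conv((5, 5), 6 => 16): 12 -> 8
    h, w = pool_out(h, 2), pool_out(w, 2)   # MaxPool((2, 2)):        8 -> 4
    c = 16                                  # output channels of the last Conv
    return (h, w, c), h * w * c             # shape, and flattened length for Dense
end

feature_shape((28, 28, 1))   # ((4, 4, 16), 256)
```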

Helper Functions

julia
const lossfn = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(Array(y))
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
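CrossEntropyLoss(; logits=Val(true)) expects raw scores and applies a log-softmax internally, while onecold recovers class indices by taking the argmax of each column. The following plain-Julia sketch shows the underlying math, assuming a mean over the batch; the function names are illustrative and are not part of Lux or OneHotArrays:

```julia
using Statistics

# Numerically stable log-softmax over the class dimension (rows = classes,
# columns = samples): subtract the column maximum before exponentiating.
function logsoftmax_cols(z)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Cross-entropy on logits with one-hot targets, averaged over the batch.
crossentropy_logits(z, y) = mean(-sum(y .* logsoftmax_cols(z); dims=1))

# onecold-style decoding: row index of the largest entry in each column.
onecold_cols(z) = [argmax(view(z, :, j)) for j in axes(z, 2)]

# Fraction of samples whose predicted class matches the target class.
accuracy_from_logits(z, y) = mean(onecold_cols(z) .== onecold_cols(y))

z = Float64[2 0; 1 3; 0 1]   # logits: 3 classes × 2 samples
y = Bool[1 0; 0 0; 0 1]      # one-hot targets: classes 1 and 3
accuracy_from_logits(z, y)   # 0.5: only the first sample is classified correctly
```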

Define the Training Loop

julia
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
    ps, st = dev(Lux.setup(rng, model))

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = Reactant.with_config(;
            dot_general_precision=PrecisionConfig.HIGH,
            convolution_precision=PrecisionConfig.HIGH,
        ) do
            @compile model(x_ra, ps, Lux.testmode(st))
        end
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, train_dataloader
            ) * 100
        te_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, test_dataloader
            ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
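Each call to Training.single_train_step! computes gradients (with Enzyme on the Reactant device, Zygote otherwise) and applies the Adam(3.0f-4) optimizer from Optimisers.jl. For reference, here is a sketch of the textbook Adam update on a flat parameter vector. The β and ϵ values are the commonly used defaults and are assumptions here, and AdamState and adam_step! are hypothetical names, not Optimisers.jl internals:

```julia
# Hypothetical optimizer state for a textbook Adam update (not Optimisers.jl's).
mutable struct AdamState
    m::Vector{Float32}   # running first-moment (mean) estimate of the gradient
    v::Vector{Float32}   # running second-moment (uncentered variance) estimate
    t::Int               # step counter, used for bias correction
end
AdamState(n::Int) = AdamState(zeros(Float32, n), zeros(Float32, n), 0)

# One in-place Adam step. η matches the tutorial; β1, β2, ϵ are assumed defaults.
function adam_step!(θ, g, s::AdamState; η=3f-4, β1=0.9f0, β2=0.999f0, ϵ=1f-8)
    s.t += 1
    @. s.m = β1 * s.m + (1 - β1) * g        # update biased first moment
    @. s.v = β2 * s.v + (1 - β2) * g^2      # update biased second moment
    m̂ = s.m ./ (1 - β1^s.t)                 # bias-corrected first moment
    v̂ = s.v ./ (1 - β2^s.t)                 # bias-corrected second moment
    @. θ -= η * m̂ / (sqrt(v̂) + ϵ)           # parameter update
    return θ
end

θ = zeros(Float32, 3)
adam_step!(θ, Float32[1, -2, 0.5], AdamState(3))   # θ ≈ [-3f-4, 3f-4, -3f-4]
```

On the very first step, the bias corrections make each update approximately η times the sign of the corresponding gradient component.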

Finally Training the Model

First, we train the Lux model with Reactant, which we configured above to use the CPU backend.

julia
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] 	 Time 398.39s 	 Training Accuracy: 12.66% 	 Test Accuracy: 11.72%
[ 2/10] 	 Time 0.10s 	 Training Accuracy: 23.28% 	 Test Accuracy: 20.31%
[ 3/10] 	 Time 0.10s 	 Training Accuracy: 35.94% 	 Test Accuracy: 29.69%
[ 4/10] 	 Time 0.22s 	 Training Accuracy: 48.75% 	 Test Accuracy: 35.94%
[ 5/10] 	 Time 0.09s 	 Training Accuracy: 57.11% 	 Test Accuracy: 47.66%
[ 6/10] 	 Time 0.10s 	 Training Accuracy: 62.42% 	 Test Accuracy: 52.34%
[ 7/10] 	 Time 0.09s 	 Training Accuracy: 68.75% 	 Test Accuracy: 60.16%
[ 8/10] 	 Time 0.09s 	 Training Accuracy: 71.48% 	 Test Accuracy: 66.41%
[ 9/10] 	 Time 0.10s 	 Training Accuracy: 74.38% 	 Test Accuracy: 72.66%
[10/10] 	 Time 0.09s 	 Training Accuracy: 76.80% 	 Test Accuracy: 75.78%

Now, we train the SimpleChains model on the CPU using the default cpu_device().

julia
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] 	 Time 946.20s 	 Training Accuracy: 22.19% 	 Test Accuracy: 20.31%
[ 2/10] 	 Time 12.07s 	 Training Accuracy: 38.28% 	 Test Accuracy: 36.72%
[ 3/10] 	 Time 12.11s 	 Training Accuracy: 52.19% 	 Test Accuracy: 47.66%
[ 4/10] 	 Time 12.07s 	 Training Accuracy: 67.66% 	 Test Accuracy: 64.84%
[ 5/10] 	 Time 12.03s 	 Training Accuracy: 75.47% 	 Test Accuracy: 71.88%
[ 6/10] 	 Time 12.01s 	 Training Accuracy: 79.06% 	 Test Accuracy: 78.91%
[ 7/10] 	 Time 12.01s 	 Training Accuracy: 83.52% 	 Test Accuracy: 80.47%
[ 8/10] 	 Time 12.03s 	 Training Accuracy: 83.98% 	 Test Accuracy: 76.56%
[ 9/10] 	 Time 12.14s 	 Training Accuracy: 86.41% 	 Test Accuracy: 82.81%
[10/10] 	 Time 12.04s 	 Training Accuracy: 87.27% 	 Test Accuracy: 82.81%

On my local machine, we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup here may be smaller, and there may even be regressions. Note also that the time reported for the first epoch includes compilation.

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.