
MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a model with SimpleChains.jl through the same API used for Lux.jl models. We use the tutorial from SimpleChains.jl as a reference.

Package Imports

julia
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Reactant.set_default_backend("cpu")

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
    )
end
loadmnist (generic function with 1 method)
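The reshape in loadmnist adds a singleton channel dimension, so the images match the (H, W, C, BS) layout that Conv expects. onehotbatch turns each label in 0:9 into a length-10 column. The following plain-Julia sketch reproduces both transformations on random stand-in data, just to show the resulting shapes. It uses a hand-written comprehension (the hypothetical helper `onehot_sketch`) instead of OneHotArrays.

```julia
using Random

# Stand-in for dataset.features / dataset.targets: 5 random 28×28 "images"
# and 5 labels in 0:9 (not real MNIST data).
rng = Xoshiro(0)
imgs = rand(rng, Float32, 28, 28, 5)
labels_raw = [3, 0, 9, 3, 7]

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C=1, N)
x_data = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Hypothetical helper: dense one-hot encoding, one column per sample,
# row k corresponds to class classes[k] (mirrors onehotbatch(labels, 0:9))
onehot_sketch(labels, classes) = [l == c for c in classes, l in labels]
y_data = onehot_sketch(labels_raw, 0:9)

println(size(x_data))   # (28, 28, 1, 5)
println(size(y_data))   # (10, 5)
```

Because classes start at 0, label `l` lands in row `l + 1`. onecold undoes this mapping when the accuracy is computed.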

Define the Model

julia
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
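The per-layer parameter counts in the printout follow from two formulas. A Conv layer with a k×k kernel, `cin` input channels and `cout` output channels has k·k·cin·cout weights plus cout biases. A Dense(in => out) layer has in·out weights plus out biases. MaxPool and FlattenLayer have no parameters. Here is a quick check in plain Julia (the helper names are illustrative, not Lux functions):

```julia
# Illustrative helpers (not Lux functions): parameter counts including biases
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params(5, 1, 6),     # layer_1: Conv((5, 5), 1 => 6)
    conv_params(5, 6, 16),    # layer_3: Conv((5, 5), 6 => 16)
    dense_params(256, 128),   # layer_6.layer_1
    dense_params(128, 84),    # layer_6.layer_2
    dense_params(84, 10),     # layer_6.layer_3
]
println(counts)       # [156, 2416, 32896, 10836, 850]
println(sum(counts))  # 47154
```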

To run this model with SimpleChains.jl, we convert lux_model using a ToSimpleChainsAdaptor. The adaptor needs the input dimensions: here (28, 28, 1), i.e. height, width, and channels, excluding the batch dimension.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
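As we understand it, the adaptor needs the input dimensions because SimpleChains.jl must know array sizes when the chain is built. Every intermediate shape therefore has to follow from (28, 28, 1). The sketch below traces those shapes by hand. It assumes unpadded, stride-1 convolutions (Lux's Conv defaults) and 2×2 max-pooling whose stride equals the window (MaxPool's default). It also shows where the 256 in Dense(256 => 128) comes from. `window_out` and `trace_shapes` are illustrative helpers, not library functions.

```julia
# Output length of an unpadded window of size k with the given stride
window_out(n, k; stride=1) = (n - k) ÷ stride + 1

# Trace (H, W, C) through the layers of lux_model, ignoring the batch dimension
function trace_shapes(h, w, c)
    shapes = Tuple[(h, w, c)]
    h, w, c = window_out(h, 5), window_out(w, 5), 6                # Conv((5, 5), 1 => 6)
    push!(shapes, (h, w, c))
    h, w = window_out(h, 2; stride=2), window_out(w, 2; stride=2)  # MaxPool((2, 2))
    push!(shapes, (h, w, c))
    h, w, c = window_out(h, 5), window_out(w, 5), 16               # Conv((5, 5), 6 => 16)
    push!(shapes, (h, w, c))
    h, w = window_out(h, 2; stride=2), window_out(w, 2; stride=2)  # MaxPool((2, 2))
    push!(shapes, (h, w, c))
    push!(shapes, (h * w * c,))                                    # FlattenLayer(3)
    return shapes
end

foreach(println, trace_shapes(28, 28, 1))
```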

Helper Functions

julia
const lossfn = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(Array(y))
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
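`CrossEntropyLoss(; logits=Val(true))` takes raw scores (logits) rather than probabilities, so a log-softmax is applied inside the loss. The `accuracy` helper compares the arg-max row of each prediction column with that of the one-hot target, which is what onecold extracts. The plain-Julia sketch below illustrates both ideas. It assumes the loss is averaged over the batch (mean aggregation). `logitce_sketch`, `accuracy_sketch` and `logsoftmax_cols` are hypothetical names, not Lux functions.

```julia
# Numerically stable log-softmax down each column (classes are rows)
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Softmax cross-entropy on logits, averaged over the batch (columns)
logitce_sketch(ŷ, y) = sum(-sum(y .* logsoftmax_cols(ŷ); dims=1)) / size(y, 2)

# Fraction of columns whose arg-max row matches (what onecold-based accuracy measures)
accuracy_sketch(ŷ, y) = count(argmax(ŷ; dims=1) .== argmax(y; dims=1)) / size(y, 2)

logits = Float32[2.0 0.1; 0.5 3.0; -1.0 0.2]   # 3 classes × 2 samples
targets = Float32[1 0; 0 0; 0 1]               # sample 1 is class 1, sample 2 is class 3
println(logitce_sketch(logits, targets))
println(accuracy_sketch(logits, targets))      # 0.5
```

Working on logits with the max-subtraction trick avoids computing `log(softmax(x))` in two steps, which can underflow for large negative scores.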

Define the Training Loop

julia
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
    ps, st = dev(Lux.setup(rng, model))

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, train_dataloader
            ) * 100
        te_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, test_dataloader
            ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
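Each call to `Training.single_train_step!` computes the gradient of `lossfn` with the selected AD backend (Enzyme on the Reactant device, Zygote otherwise) and then applies one optimiser update to the parameters. To make the optimiser half concrete, here is a plain-Julia sketch of the standard Adam update rule that `Adam(3.0f-4)` selects. It assumes Optimisers.jl's default β = (0.9, 0.999) and ε = 1e-8. The hypothetical `adam_sketch!` minimises a toy quadratic with an analytic gradient instead of a network loss.

```julia
# One Adam step applied to θ in place; m and v hold the running
# first- and second-moment estimates, t is the 1-based step count.
function adam_sketch!(θ, m, v, g, t; η=3.0f-4, β1=0.9f0, β2=0.999f0, ϵ=1.0f-8)
    @. m = β1 * m + (1 - β1) * g
    @. v = β2 * v + (1 - β2) * g^2
    m̂ = m ./ (1 - β1^t)          # bias-corrected first moment
    v̂ = v ./ (1 - β2^t)          # bias-corrected second moment
    @. θ -= η * m̂ / (sqrt(v̂) + ϵ)
    return θ
end

# Toy problem: minimise f(θ) = sum((θ .- 3).^2), whose gradient is 2(θ .- 3).
# A larger step size than 3f-4 is used so the toy run converges quickly.
θ = zeros(Float32, 4)
m, v = zero(θ), zero(θ)
for t in 1:2000
    g = 2 .* (θ .- 3)
    adam_sketch!(θ, m, v, g, t; η=0.05f0)
end
println(θ)   # each entry should end up close to 3
```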

Finally Training the Model

First, we train the Lux model on the Reactant device.

julia
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] 	 Time 267.57s 	 Training Accuracy: 17.27% 	 Test Accuracy: 18.75%
[ 2/10] 	 Time 0.21s 	 Training Accuracy: 35.08% 	 Test Accuracy: 32.03%
[ 3/10] 	 Time 0.20s 	 Training Accuracy: 51.88% 	 Test Accuracy: 45.31%
[ 4/10] 	 Time 0.21s 	 Training Accuracy: 65.78% 	 Test Accuracy: 54.69%
[ 5/10] 	 Time 0.20s 	 Training Accuracy: 70.78% 	 Test Accuracy: 58.59%
[ 6/10] 	 Time 0.21s 	 Training Accuracy: 75.62% 	 Test Accuracy: 67.19%
[ 7/10] 	 Time 0.21s 	 Training Accuracy: 78.75% 	 Test Accuracy: 75.78%
[ 8/10] 	 Time 0.20s 	 Training Accuracy: 80.94% 	 Test Accuracy: 78.12%
[ 9/10] 	 Time 0.21s 	 Training Accuracy: 84.22% 	 Test Accuracy: 81.25%
[10/10] 	 Time 0.22s 	 Training Accuracy: 86.56% 	 Test Accuracy: 82.03%
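The accuracies above change in discrete steps. A short calculation suggests why. It assumes the docs were built with the `CI` environment variable set, so loadmnist keeps only N = 1500 images. A 0.9 split then leaves 1350 training and 150 test images. With `batchsize = 128` and `partial = false`, the loaders yield only full batches, so accuracy is measured on 1280 training and 128 test images.

```julia
N, train_split, batchsize = 1500, 0.9, 128

n_train = round(Int, train_split * N)   # 1350 (exact here, so the rounding mode is irrelevant)
n_test = N - n_train                    # 150

# partial=false drops the last incomplete batch
n_eval_train = (n_train ÷ batchsize) * batchsize   # 10 batches -> 1280 images
n_eval_test = (n_test ÷ batchsize) * batchsize     # 1 batch    -> 128 images

# Each reported test accuracy (from the log above) should be a whole
# number of correct predictions divided by n_eval_test.
reported_test = [18.75, 32.03, 45.31, 54.69, 58.59, 67.19, 75.78, 78.12, 81.25, 82.03]
counts = round.(Int, reported_test ./ 100 .* n_eval_test)
println(counts)   # correct test predictions per epoch
```

Under that assumption, one extra correct test prediction moves the test accuracy by 1/128 ≈ 0.78 percentage points.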

Next, we train the SimpleChains model on the default CPU device.

julia
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] 	 Time 870.34s 	 Training Accuracy: 12.42% 	 Test Accuracy: 12.50%
[ 2/10] 	 Time 12.24s 	 Training Accuracy: 36.72% 	 Test Accuracy: 32.81%
[ 3/10] 	 Time 12.36s 	 Training Accuracy: 55.62% 	 Test Accuracy: 50.00%
[ 4/10] 	 Time 12.23s 	 Training Accuracy: 69.06% 	 Test Accuracy: 60.16%
[ 5/10] 	 Time 12.21s 	 Training Accuracy: 76.64% 	 Test Accuracy: 71.09%
[ 6/10] 	 Time 12.20s 	 Training Accuracy: 78.75% 	 Test Accuracy: 73.44%
[ 7/10] 	 Time 12.22s 	 Training Accuracy: 81.02% 	 Test Accuracy: 78.12%
[ 8/10] 	 Time 12.26s 	 Training Accuracy: 81.25% 	 Test Accuracy: 78.91%
[ 9/10] 	 Time 12.22s 	 Training Accuracy: 84.06% 	 Test Accuracy: 81.25%
[10/10] 	 Time 12.27s 	 Training Accuracy: 85.78% 	 Test Accuracy: 83.59%

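On my local machine, SimpleChains.jl gives a 3-4x speedup. The server this documentation is built on is not well suited to CPU benchmarking, so the speedup here may be smaller, and there may even be regressions. Note also that the Lux run above uses reactant_device(), whereas the SimpleChains run uses the default cpu_device(), so the two logs are not a like-for-like comparison.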
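The first epoch of each run is presumably dominated by one-time compilation (267.57 s and 870.34 s in the logs above). Per-epoch figures are therefore more informative without it. The small sketch below averages epochs 2-10, with the times copied from the two training logs. `steady` is an illustrative helper.

```julia
using Statistics

# Per-epoch times (seconds) copied from the two training logs above
lux_reactant = [267.57, 0.21, 0.20, 0.21, 0.20, 0.21, 0.21, 0.20, 0.21, 0.22]
simple_chains = [870.34, 12.24, 12.36, 12.23, 12.21, 12.20, 12.22, 12.26, 12.22, 12.27]

# Drop epoch 1, which includes compilation, before averaging
steady(ts) = mean(ts[2:end])

println(round(steady(lux_reactant); digits=3))
println(round(steady(simple_chains); digits=3))
```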

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.