
MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to use the familiar Lux.jl API to train a model that runs on SimpleChains.jl. We use the corresponding tutorial from SimpleChains.jl as a reference.

Package Imports

```julia
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Reactant.set_default_backend("cpu")
```

Loading MNIST

```julia
function loadmnist(batchsize, train_split)
    # Load MNIST; when running on CI, only the first 1500 images are used
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Reshape images to (H, W, C, N) arrays and one-hot encode the labels
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to minibatch and shuffle the data (partial=false drops the
        # last incomplete batch)
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false),
    )
end
```

```
loadmnist (generic function with 1 method)
```
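To make the effect of `train_split`, `batchsize`, and `partial=false` concrete, the following stdlib-only sketch counts how many observations each DataLoader actually yields per epoch. The helper `batch_counts` is hypothetical (it is not part of MLUtils), and it assumes that `splitobs` places the split point at `round(Int, at * n)`; MLUtils's exact rounding rule may differ slightly.

```julia
# Hypothetical helper (not part of MLUtils): per-epoch observation counts for the
# train/test DataLoaders built by `loadmnist`. Assumes the split point is
# round(Int, train_split * n_total), which may differ from MLUtils' exact rule.
function batch_counts(n_total::Int, train_split::Real, batchsize::Int)
    n_train = round(Int, train_split * n_total)
    n_test = n_total - n_train
    # partial=false drops the last incomplete batch, hence integer division
    train_batches = n_train ÷ batchsize
    test_batches = n_test ÷ batchsize
    return (; n_train, n_test, train_batches, test_batches,
        train_used=train_batches * batchsize, test_used=test_batches * batchsize)
end

ci = batch_counts(1500, 0.9, 128)      # CI setting: first 1500 images only
full = batch_counts(60_000, 0.9, 128)  # full MNIST training split

println("CI:   ", ci)
println("full: ", full)
```

Under these assumptions, the CI setting uses only 1280 training images and a single test batch of 128 images per epoch. This is consistent with the training logs on this page, where every test accuracy is a multiple of 1/128 (for example, 25/128 ≈ 19.53%), which suggests those runs used the reduced CI subset.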

Define the Model

```julia
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)
```

```
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
```
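The `256` in `Dense(256 => 128, relu)` and the parameter counts printed above follow from simple arithmetic. The stdlib-only sketch below traces a 28×28 image through the feature extractor and recomputes the per-layer counts. The helper names are hypothetical, and it assumes stride-1 convolutions without padding and non-overlapping 2×2 pooling, which is what the layer definitions above use.

```julia
# Hypothetical helpers: spatial size after a stride-1, unpadded convolution and
# after non-overlapping pooling
conv_out(n, k) = n - k + 1
pool_out(n, p) = n ÷ p

# Trace one 28×28 image through the feature extractor
h = 28
h = pool_out(conv_out(h, 5), 2)   # Conv 5×5 -> 24, MaxPool 2×2 -> 12
h = pool_out(conv_out(h, 5), 2)   # Conv 5×5 ->  8, MaxPool 2×2 ->  4
flat = h * h * 16                 # FlattenLayer: 4 × 4 spatial × 16 channels

# Parameter counts: weights plus one bias per output channel/feature
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params(5, 1, 6),
    conv_params(5, 6, 16),
    dense_params(flat, 128),
    dense_params(128, 84),
    dense_params(84, 10),
]
println("flattened features = ", flat)
println("per-layer parameters = ", counts, ", total = ", sum(counts))
```

These numbers match the 47_154 total reported for the Lux model and, after conversion, for its SimpleChains counterpart.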

Next, we convert `lux_model` to SimpleChains.jl. To do so, we construct a `ToSimpleChainsAdaptor` with the input dimensions of a single sample, `(28, 28, 1)` (the batch dimension is excluded), and apply it to the model.

```julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
```

```
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
```

Helper Functions

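```julia
# Cross-entropy loss computed directly on the model's raw outputs (logits)
const lossfn = CrossEntropyLoss(; logits=Val(true))

# Fraction of samples in `dataloader` whose predicted class matches the label
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        # Copy to host memory and turn one-hot/logit columns into class indices
        target_class = onecold(Array(y))
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
```

```
accuracy (generic function with 1 method)
```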
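For readers unfamiliar with OneHotArrays.jl or logit-based cross-entropy, here is a minimal stdlib-only sketch of what `onehotbatch`, `onecold`, `CrossEntropyLoss(; logits=Val(true))`, and `accuracy` compute. The `*_sketch` functions are hypothetical stand-ins, not the library implementations, and the loss assumes mean aggregation over the batch and no label smoothing.

```julia
# Hypothetical stand-ins; layout is (classes × batch), as in Lux
onehot_sketch(labels, classes) = [l == c for c in classes, l in labels]  # Bool matrix
onecold_sketch(scores) = [argmax(col) for col in eachcol(scores)]        # row index of each column's max

# Softmax cross-entropy on raw logits, computed stably via log-sum-exp,
# summed over classes and averaged over the batch
function logit_crossentropy_sketch(logits, y)
    m = maximum(logits; dims=1)
    lse = m .+ log.(sum(exp.(logits .- m); dims=1))
    logprobs = logits .- lse
    return -sum(y .* logprobs) / size(logits, 2)
end

# Fraction of columns whose predicted class matches the target class
accuracy_sketch(logits, y) = sum(onecold_sketch(logits) .== onecold_sketch(y)) / size(y, 2)

# Usage: three samples with labels 3, 0 and 7
y = onehot_sketch([3, 0, 7], 0:9)   # 10×3 one-hot matrix
logits = zeros(Float32, 10, 3)
logits[4, 1] = 5   # predicts class 3 (row 4): correct
logits[1, 2] = 5   # predicts class 0 (row 1): correct
logits[2, 3] = 5   # predicts class 1 (row 2): wrong, the label is 7
println("accuracy = ", accuracy_sketch(logits, y))
println("loss     = ", logit_crossentropy_sketch(logits, y))
```

Like `onecold(Array(y))` in the tutorial's `accuracy`, the sketch compares row indices rather than the original labels, which is enough for counting matches.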

Define the Training Loop

```julia
function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...)
    # Load the data and initialize parameters/states, moving both to `dev`
    train_dataloader, test_dataloader = dev(loadmnist(128, 0.9))
    ps, st = dev(Lux.setup(rng, model))

    # Use Enzyme when training through Reactant, Zygote otherwise
    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    # With Reactant, compile the forward pass once (in test mode) for evaluation
    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            # Compute the loss and gradients, then apply the Adam update
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, train_dataloader
            ) * 100
        te_acc =
            accuracy(
                model_compiled, train_state.parameters, train_state.states, test_dataloader
            ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
```

```
train (generic function with 2 methods)
```
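Each call to `Training.single_train_step!` computes the loss and its gradients with the chosen AD backend (`vjp`) and then applies the optimizer update held in `train_state`. To illustrate that update, here is a stdlib-only sketch of the standard Adam rule on a toy quadratic. The names `AdamSketchState` and `adam_step!` are hypothetical; the default β = (0.9, 0.999) and ϵ = 1e-8 are assumed to match Optimisers.jl's `Adam`, whose implementation may differ in details.

```julia
# Hypothetical, stdlib-only sketch of the standard Adam update rule
mutable struct AdamSketchState
    m::Vector{Float64}   # running mean of gradients (first moment)
    v::Vector{Float64}   # running mean of squared gradients (second moment)
    t::Int               # step counter, used for bias correction
end
AdamSketchState(n::Int) = AdamSketchState(zeros(n), zeros(n), 0)

function adam_step!(θ, g, s::AdamSketchState; η=3e-4, β1=0.9, β2=0.999, ϵ=1e-8)
    s.t += 1
    @. s.m = β1 * s.m + (1 - β1) * g
    @. s.v = β2 * s.v + (1 - β2) * g^2
    bc1 = 1 - β1^s.t   # bias corrections for the zero-initialised moments
    bc2 = 1 - β2^s.t
    @. θ -= η * (s.m / bc1) / (sqrt(s.v / bc2) + ϵ)
    return θ
end

# Usage: minimise f(θ) = sum((θ .- 3) .^ 2), whose gradient is 2(θ .- 3).
# A larger step size than the tutorial's 3e-4 lets this toy problem converge quickly.
θ = zeros(2)
state = AdamSketchState(length(θ))
for _ in 1:2000
    g = 2 .* (θ .- 3)
    adam_step!(θ, g, state; η=0.1)
end
println("θ after 2000 steps: ", θ)
```

In the tutorial, the same kind of update is applied to every array in the model's parameters, with gradients coming from Enzyme or Zygote instead of a closed-form expression.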

Finally Training the Model

First, we train the Lux model using Reactant:

```julia
tr_acc, te_acc = train(lux_model, reactant_device())
```

```
[ 1/10] 	 Time 312.09s 	 Training Accuracy: 16.09% 	 Test Accuracy: 19.53%
[ 2/10] 	 Time 0.23s 	 Training Accuracy: 28.12% 	 Test Accuracy: 23.44%
[ 3/10] 	 Time 0.24s 	 Training Accuracy: 46.72% 	 Test Accuracy: 42.97%
[ 4/10] 	 Time 0.23s 	 Training Accuracy: 57.42% 	 Test Accuracy: 56.25%
[ 5/10] 	 Time 0.23s 	 Training Accuracy: 66.48% 	 Test Accuracy: 63.28%
[ 6/10] 	 Time 0.21s 	 Training Accuracy: 73.05% 	 Test Accuracy: 68.75%
[ 7/10] 	 Time 0.23s 	 Training Accuracy: 76.25% 	 Test Accuracy: 70.31%
[ 8/10] 	 Time 0.22s 	 Training Accuracy: 79.45% 	 Test Accuracy: 74.22%
[ 9/10] 	 Time 0.23s 	 Training Accuracy: 81.25% 	 Test Accuracy: 80.47%
[10/10] 	 Time 0.22s 	 Training Accuracy: 82.89% 	 Test Accuracy: 81.25%
```

Next, we train the SimpleChains model on the CPU, the default device of train:

```julia
tr_acc, te_acc = train(simple_chains_model)
```

```
[ 1/10] 	 Time 1081.64s 	 Training Accuracy: 29.77% 	 Test Accuracy: 27.34%
[ 2/10] 	 Time 12.20s 	 Training Accuracy: 50.23% 	 Test Accuracy: 51.56%
[ 3/10] 	 Time 12.19s 	 Training Accuracy: 57.81% 	 Test Accuracy: 54.69%
[ 4/10] 	 Time 12.21s 	 Training Accuracy: 62.97% 	 Test Accuracy: 57.81%
[ 5/10] 	 Time 12.21s 	 Training Accuracy: 70.70% 	 Test Accuracy: 62.50%
[ 6/10] 	 Time 12.17s 	 Training Accuracy: 74.84% 	 Test Accuracy: 67.19%
[ 7/10] 	 Time 12.16s 	 Training Accuracy: 77.89% 	 Test Accuracy: 74.22%
[ 8/10] 	 Time 12.25s 	 Training Accuracy: 80.78% 	 Test Accuracy: 71.88%
[ 9/10] 	 Time 12.16s 	 Training Accuracy: 82.27% 	 Test Accuracy: 79.69%
[10/10] 	 Time 12.18s 	 Training Accuracy: 85.00% 	 Test Accuracy: 78.12%
```

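On my local machine, we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not well suited to CPU benchmarking, so the speedup here may be smaller, or there may even be a regression. Note also that the two runs above are not a like-for-like comparison: the Lux model was compiled with Reactant and differentiated with Enzyme, while the SimpleChains model was trained on the CPU with Zygote.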
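To put rough numbers on the two runs above, the snippet below averages the per-epoch training times copied from the logs, skipping the first epoch of each run, which (judging by its much longer time) is dominated by compilation. It only restates figures printed on this page for this particular build server and should not be read as a general benchmark.

```julia
using Statistics: mean

# Per-epoch training times (seconds) for epochs 2-10, copied from the logs above;
# epoch 1 is excluded because it is dominated by compilation.
lux_reactant_times = [0.23, 0.24, 0.23, 0.23, 0.21, 0.23, 0.22, 0.23, 0.22]
simplechains_times = [12.20, 12.19, 12.21, 12.21, 12.17, 12.16, 12.25, 12.16, 12.18]

lux_mean = mean(lux_reactant_times)
sc_mean = mean(simplechains_times)
ratio = sc_mean / lux_mean   # > 1 means the SimpleChains run was slower on this server

println("mean epoch time, Lux + Reactant: ", round(lux_mean; digits=3), " s")
println("mean epoch time, SimpleChains:   ", round(sc_mean; digits=3), " s")
println("ratio (SimpleChains / Lux):      ", round(ratio; digits=1))
```

On this server, the SimpleChains epochs took more than 50 times as long as the Reactant-compiled Lux epochs, which is the kind of regression the caveat above allows for.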

Appendix

```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
```
This page was generated using Literate.jl.