MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks quickly. In this tutorial we demonstrate how to train a model with SimpleChains.jl while using the same API as Lux.jl, following the MNIST tutorial from SimpleChains.jl as a reference.

Package Imports

julia
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Reactant.set_default_backend("cpu")
Precompiling Reactant...
   1050.9 ms  ✓ CUDA_Driver_jll
   2228.4 ms  ✓ Reactant_jll
  75211.7 ms  ✓ Reactant
  3 dependencies successfully precompiled in 79 seconds. 61 already precompiled.
Precompiling LuxCoreReactantExt...
  12014.1 ms  ✓ LuxCore → LuxCoreReactantExt
  1 dependency successfully precompiled in 12 seconds. 69 already precompiled.
Precompiling MLDataDevicesReactantExt...
  12352.1 ms  ✓ MLDataDevices → MLDataDevicesReactantExt
  1 dependency successfully precompiled in 13 seconds. 66 already precompiled.
Precompiling LuxLibReactantExt...
  12534.3 ms  ✓ Reactant → ReactantStatisticsExt
  13706.1 ms  ✓ Reactant → ReactantNNlibExt
  14093.4 ms  ✓ LuxLib → LuxLibReactantExt
  13797.1 ms  ✓ Reactant → ReactantKernelAbstractionsExt
  13010.5 ms  ✓ Reactant → ReactantSpecialFunctionsExt
  12864.1 ms  ✓ Reactant → ReactantArrayInterfaceExt
  6 dependencies successfully precompiled in 27 seconds. 143 already precompiled.
Precompiling WeightInitializersReactantExt...
  12293.7 ms  ✓ WeightInitializers → WeightInitializersReactantExt
  1 dependency successfully precompiled in 13 seconds. 80 already precompiled.
Precompiling ReactantAbstractFFTsExt...
  12277.8 ms  ✓ Reactant → ReactantAbstractFFTsExt
  1 dependency successfully precompiled in 13 seconds. 65 already precompiled.
Precompiling LuxReactantExt...
  10189.8 ms  ✓ Lux → LuxReactantExt
  1 dependency successfully precompiled in 11 seconds. 166 already precompiled.
Precompiling MLDatasets...
    624.8 ms  ✓ ZipFile
    808.3 ms  ✓ StructTypes
   1136.7 ms  ✓ MbedTLS
    530.4 ms  ✓ ExceptionUnwrapping
    634.8 ms  ✓ Chemfiles_jll
   2151.6 ms  ✓ ColorVectorSpace
    655.6 ms  ✓ libaec_jll
   1061.1 ms  ✓ FilePathsBase
   1904.3 ms  ✓ OpenSSL
   1193.9 ms  ✓ OpenMPI_jll
   4481.0 ms  ✓ FileIO
    563.3 ms  ✓ StringEncodings
    803.9 ms  ✓ WeakRefStrings
  10344.6 ms  ✓ JSON3
   3649.6 ms  ✓ ColorSchemes
  21360.4 ms  ✓ Unitful
    543.7 ms  ✓ FilePathsBase → FilePathsBaseMmapExt
   1281.4 ms  ✓ FilePathsBase → FilePathsBaseTestExt
   1886.1 ms  ✓ HDF5_jll
   1522.0 ms  ✓ NPZ
  19575.7 ms  ✓ ImageCore
   2313.2 ms  ✓ Pickle
    566.9 ms  ✓ Unitful → ConstructionBaseUnitfulExt
    586.5 ms  ✓ Unitful → InverseFunctionsUnitfulExt
   2919.0 ms  ✓ UnitfulAtomic
   2309.8 ms  ✓ PeriodicTable
    620.8 ms  ✓ Accessors → UnitfulExt
  19194.1 ms  ✓ HTTP
   7425.1 ms  ✓ HDF5
   2067.3 ms  ✓ ImageBase
   2213.0 ms  ✓ AtomsBase
  17147.2 ms  ✓ CSV
   3090.6 ms  ✓ DataDeps
   1935.0 ms  ✓ FileIO → HTTPExt
   1891.4 ms  ✓ ImageShow
   2422.4 ms  ✓ MAT
  33652.1 ms  ✓ JLD2
   2322.0 ms  ✓ Chemfiles
   9027.8 ms  ✓ MLDatasets
  39 dependencies successfully precompiled in 70 seconds. 161 already precompiled.
Precompiling ReactantOffsetArraysExt...
  12422.7 ms  ✓ Reactant → ReactantOffsetArraysExt
  1 dependency successfully precompiled in 13 seconds. 66 already precompiled.
2025-03-07 22:03:58.056525: I external/xla/xla/service/service.cc:152] XLA service 0xa586010 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-07 22:03:58.056741: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741385038.057451  558870 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741385038.057518  558870 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741385038.057547  558870 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741385038.081646  558870 cuda_dnn.cc:529] Loaded cuDNN version 90400

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true, partial = false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false, partial = false),
    )
end
loadmnist (generic function with 1 method)
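
As a quick sanity check (an illustrative snippet, not part of the original tutorial), we can inspect the shapes of a single batch produced by the loaders:

julia
train_loader, test_loader = loadmnist(128, 0.9)
x, y = first(train_loader)
size(x)  # (28, 28, 1, 128): H × W × C × batchsize
size(y)  # (10, 128): one-hot targets over the classes 0:9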

Define the Model

julia
lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
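
We can verify the parameter count reported in the summary programmatically (a small check added here for illustration):

julia
Lux.parameterlength(lux_model)  # 47_154, matching the summary above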

Next we convert lux_model to a SimpleChains.jl model by constructing a ToSimpleChainsAdaptor with the input dimensions.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
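
The returned SimpleChainsLayer is itself a Lux layer, so it can be initialized and called through the usual Lux API. As an illustrative check (the exact output shape is an assumption based on the model definition above):

julia
ps, st = Lux.setup(Random.default_rng(), simple_chains_model)
x = rand(Float32, 28, 28, 1, 128)       # a dummy batch of MNIST-sized inputs
ŷ, _ = simple_chains_model(x, ps, st)
size(ŷ)  # expected: (10, 128)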

Helper Functions

julia
const lossfn = CrossEntropyLoss(; logits = Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
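
To make the accuracy computation concrete, here is how onecold behaves on a one-hot batch (an illustrative snippet using the OneHotArrays API):

julia
y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(y)                       # [1, 4, 10]: 1-based class indices
onecold(y, 0:9)                  # [0, 3, 9]: recover the original labels

Since both the targets and the predictions pass through onecold with the same default labels, comparing the resulting indices is equivalent to comparing class labels.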

Define the Training Loop

julia
function train(model, dev = cpu_device(); rng = Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
    ps, st = Lux.setup(rng, model) |> dev

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
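
For reference, Training.single_train_step! returns the gradients, the loss, auxiliary statistics, and the updated TrainState; the loop above keeps only the state. A hypothetical standalone step on the CPU backend looks like:

julia
train_loader, _ = loadmnist(128, 0.9)
ps, st = Lux.setup(Random.default_rng(), lux_model)
ts = Training.TrainState(lux_model, ps, st, Adam(3.0f-4))
x, y = first(train_loader)
grads, loss, stats, ts = Training.single_train_step!(AutoZygote(), lossfn, (x, y), ts)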

Finally Training the Model

First we will train the Lux model

julia
tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] 	 Time 458.64s 	 Training Accuracy: 16.88% 	 Test Accuracy: 13.28%
[ 2/10] 	 Time 0.39s 	 Training Accuracy: 30.47% 	 Test Accuracy: 25.00%
[ 3/10] 	 Time 0.37s 	 Training Accuracy: 41.33% 	 Test Accuracy: 36.72%
[ 4/10] 	 Time 0.38s 	 Training Accuracy: 55.47% 	 Test Accuracy: 43.75%
[ 5/10] 	 Time 0.38s 	 Training Accuracy: 63.67% 	 Test Accuracy: 56.25%
[ 6/10] 	 Time 0.39s 	 Training Accuracy: 69.92% 	 Test Accuracy: 64.84%
[ 7/10] 	 Time 0.39s 	 Training Accuracy: 74.30% 	 Test Accuracy: 70.31%
[ 8/10] 	 Time 0.39s 	 Training Accuracy: 78.05% 	 Test Accuracy: 75.78%
[ 9/10] 	 Time 0.40s 	 Training Accuracy: 80.16% 	 Test Accuracy: 74.22%
[10/10] 	 Time 0.39s 	 Training Accuracy: 82.11% 	 Test Accuracy: 76.56%

Now we will train the SimpleChains model

julia
tr_acc, te_acc = train(simple_chains_model)
[ 1/10] 	 Time 919.45s 	 Training Accuracy: 26.95% 	 Test Accuracy: 19.53%
[ 2/10] 	 Time 12.46s 	 Training Accuracy: 40.62% 	 Test Accuracy: 35.16%
[ 3/10] 	 Time 12.47s 	 Training Accuracy: 52.34% 	 Test Accuracy: 50.78%
[ 4/10] 	 Time 12.44s 	 Training Accuracy: 63.28% 	 Test Accuracy: 61.72%
[ 5/10] 	 Time 12.41s 	 Training Accuracy: 74.06% 	 Test Accuracy: 70.31%
[ 6/10] 	 Time 12.39s 	 Training Accuracy: 79.77% 	 Test Accuracy: 73.44%
[ 7/10] 	 Time 12.40s 	 Training Accuracy: 84.06% 	 Test Accuracy: 78.12%
[ 8/10] 	 Time 12.51s 	 Training Accuracy: 85.86% 	 Test Accuracy: 82.03%
[ 9/10] 	 Time 12.46s 	 Training Accuracy: 86.48% 	 Test Accuracy: 82.81%
[10/10] 	 Time 12.43s 	 Training Accuracy: 88.98% 	 Test Accuracy: 82.81%

On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup may be less significant here, and there may even be regressions.

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.