MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a SimpleChains.jl model through the same API that Lux.jl provides. We follow the SimpleChains.jl tutorial as a reference.

Package Imports

using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Loading MNIST

function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split = :train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle = true, partial = false),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle = false, partial = false),
    )
end
loadmnist (generic function with 1 method)
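
To make the preprocessing above concrete, here is a minimal sketch on synthetic data. It assumes random arrays in place of the MNIST images, and `manual_onehot` is a hypothetical stand-in for `onehotbatch`, written out only to show the resulting layout; it is not how OneHotArrays.jl implements it.

```julia
# Synthetic stand-in for MNIST: 5 "images" of size 28×28 with labels in 0:9.
# (Assumption: random data, used only to illustrate array shapes.)
imgs = rand(Float32, 28, 28, 5)
labels_raw = [3, 0, 9, 3, 7]

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C, N).
x_data = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Hypothetical helper mimicking the layout produced by `onehotbatch`:
# one row per class, one column per sample.
function manual_onehot(labels, classes)
    return [label == class for class in classes, label in labels]
end

y_data = manual_onehot(labels_raw, 0:9)

println(size(x_data))  # (28, 28, 1, 5)
println(size(y_data))  # (10, 5)
```

Each column of `y_data` contains exactly one `true`, in the row of the sample's class; label `3` maps to row 4 because the classes start at `0`.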

Define the Model

lux_model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
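
The `256` input features of the first `Dense` layer and the parameter counts printed above can be checked by hand. The sketch below assumes unpadded, stride-1 convolutions and non-overlapping pooling windows; the helper names (`conv_out`, `pool_out`, `conv_params`, `dense_params`) are hypothetical and not part of Lux.jl.

```julia
# Output side length of an unpadded, stride-1 convolution with kernel size k.
conv_out(n, k) = n - k + 1
# Output side length of non-overlapping pooling with window size w.
pool_out(n, w) = n ÷ w

side = 28
side = conv_out(side, 5)   # 24
side = pool_out(side, 2)   # 12
side = conv_out(side, 5)   # 8
side = pool_out(side, 2)   # 4

# Flattening a 4×4 feature map with 16 channels gives the Dense input size.
flattened = side * side * 16
println(flattened)  # 256

# Parameter counts: weights plus one bias per output channel or neuron.
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params(5, 1, 6),     # 156
    conv_params(5, 6, 16),    # 2416
    dense_params(256, 128),   # 32896
    dense_params(128, 84),    # 10836
    dense_params(84, 10),     # 850
]
println(sum(counts))  # 47154
```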

To convert lux_model to a SimpleChains.jl model, construct a ToSimpleChainsAdaptor with the input dimensions and apply it to the model.

adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.

Helper Functions

const lossfn = CrossEntropyLoss(; logits = Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
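
The two helpers above can be illustrated on a tiny batch without any packages beyond the standard library. This is a sketch under stated assumptions: `logsoftmax_cols`, `crossentropy_from_logits`, and `manual_onecold` are hypothetical re-implementations written for illustration, and the loss is assumed to be averaged over the samples in the batch.

```julia
using Statistics: mean

# Numerically stable log-softmax over the class dimension (rows).
function logsoftmax_cols(logits)
    shifted = logits .- maximum(logits; dims = 1)
    return shifted .- log.(sum(exp.(shifted); dims = 1))
end

# Cross-entropy computed directly from logits.
# (Assumption: mean aggregation over the samples, i.e. the columns.)
function crossentropy_from_logits(logits, y_onehot)
    return mean(-sum(y_onehot .* logsoftmax_cols(logits); dims = 1))
end

# Hypothetical equivalent of `onecold`: index of the largest entry per column.
manual_onecold(m) = [argmax(col) for col in eachcol(m)]

# Three classes, four samples; each column holds one sample's logits.
logits = Float32[
    2.0 0.1 0.3 1.0
    0.5 3.0 0.2 2.0
    0.1 0.2 4.0 0.5
]
y_onehot = Bool[
    1 0 0 1
    0 1 0 0
    0 0 1 0
]

predicted = manual_onecold(logits)   # [1, 2, 3, 2]
target = manual_onecold(y_onehot)    # [1, 2, 3, 1]
acc = sum(predicted .== target) / length(target)

println(acc)  # 0.75
println(crossentropy_from_logits(logits, y_onehot))
```

Working from logits and applying log-softmax inside the loss avoids taking the logarithm of a probability that has underflowed to zero, which is the usual reason to prefer a logits-based loss.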

Define the Training Loop

function train(model, dev = cpu_device(); rng = Random.default_rng(), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
    ps, st = Lux.setup(rng, model) |> dev

    vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    if dev isa ReactantDevice
        x_ra = first(test_dataloader)[1]
        model_compiled = @compile model(x_ra, ps, Lux.testmode(st))
    else
        model_compiled = model
    end

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                vjp, lossfn, (x, y), train_state
            )
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, train_dataloader
        ) * 100
        te_acc = accuracy(
            model_compiled, train_state.parameters, train_state.states, test_dataloader
        ) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 2 methods)
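
The structure of the loop above (an outer loop over epochs, an inner loop over minibatches, and a state that is updated after every batch) can be sketched on a toy problem. Note the substitutions: this sketch uses plain gradient descent with a hand-derived gradient in place of Adam and automatic differentiation, and `toy_train`, `loss`, and `grad` are hypothetical names introduced only for illustration.

```julia
# Toy data for y = 3x, split into minibatches of size 2.
xs = Float32[1, 2, 3, 4, 5, 6, 7, 8]
ys = 3 .* xs
batches = [(xs[i:(i + 1)], ys[i:(i + 1)]) for i in 1:2:length(xs)]

# Mean squared error of the one-parameter model ŷ = w * x, and its gradient in w.
loss(w, x, y) = sum(abs2, w .* x .- y) / length(x)
grad(w, x, y) = 2 * sum((w .* x .- y) .* x) / length(x)

# Run `nepochs` passes over the minibatches, updating `w` after each batch.
# (Assumption: plain gradient descent with step size `lr`, not Adam.)
function toy_train(w, batches; nepochs = 10, lr = 0.01f0)
    for epoch in 1:nepochs
        for (x, y) in batches
            w -= lr * grad(w, x, y)
        end
    end
    return w
end

w0 = 0.0f0
w = toy_train(w0, batches)
println(loss(w0, xs, ys) > loss(w, xs, ys))  # true
```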

Finally Training the Model

First, we train the Lux model on the Reactant device:

tr_acc, te_acc = train(lux_model, reactant_device())
[ 1/10] 	 Time 458.64s 	 Training Accuracy: 16.88% 	 Test Accuracy: 13.28%
[ 2/10] 	 Time 0.39s 	 Training Accuracy: 30.47% 	 Test Accuracy: 25.00%
[ 3/10] 	 Time 0.37s 	 Training Accuracy: 41.33% 	 Test Accuracy: 36.72%
[ 4/10] 	 Time 0.38s 	 Training Accuracy: 55.47% 	 Test Accuracy: 43.75%
[ 5/10] 	 Time 0.38s 	 Training Accuracy: 63.67% 	 Test Accuracy: 56.25%
[ 6/10] 	 Time 0.39s 	 Training Accuracy: 69.92% 	 Test Accuracy: 64.84%
[ 7/10] 	 Time 0.39s 	 Training Accuracy: 74.30% 	 Test Accuracy: 70.31%
[ 8/10] 	 Time 0.39s 	 Training Accuracy: 78.05% 	 Test Accuracy: 75.78%
[ 9/10] 	 Time 0.40s 	 Training Accuracy: 80.16% 	 Test Accuracy: 74.22%
[10/10] 	 Time 0.39s 	 Training Accuracy: 82.11% 	 Test Accuracy: 76.56%

Now we train the SimpleChains model on the CPU, which is the default device of train:

tr_acc, te_acc = train(simple_chains_model)
[ 1/10] 	 Time 919.45s 	 Training Accuracy: 26.95% 	 Test Accuracy: 19.53%
[ 2/10] 	 Time 12.46s 	 Training Accuracy: 40.62% 	 Test Accuracy: 35.16%
[ 3/10] 	 Time 12.47s 	 Training Accuracy: 52.34% 	 Test Accuracy: 50.78%
[ 4/10] 	 Time 12.44s 	 Training Accuracy: 63.28% 	 Test Accuracy: 61.72%
[ 5/10] 	 Time 12.41s 	 Training Accuracy: 74.06% 	 Test Accuracy: 70.31%
[ 6/10] 	 Time 12.39s 	 Training Accuracy: 79.77% 	 Test Accuracy: 73.44%
[ 7/10] 	 Time 12.40s 	 Training Accuracy: 84.06% 	 Test Accuracy: 78.12%
[ 8/10] 	 Time 12.51s 	 Training Accuracy: 85.86% 	 Test Accuracy: 82.03%
[ 9/10] 	 Time 12.46s 	 Training Accuracy: 86.48% 	 Test Accuracy: 82.81%
[10/10] 	 Time 12.43s 	 Training Accuracy: 88.98% 	 Test Accuracy: 82.81%

On my local machine, SimpleChains.jl gives a 3-4x speedup. The server that builds this documentation is not ideal for CPU benchmarking, so the speedup here may be smaller, and there may even be regressions.


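Appendix

using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end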
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEBUG = Literate

