
MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial we show how to train a SimpleChains.jl model through the same API that Lux.jl uses for its own models. We use the tutorial from SimpleChains.jl as a reference.

Package Imports

julia
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Precompiling LuxSimpleChainsExt...
   2369.5 ms  ✓ Lux → LuxSimpleChainsExt
  1 dependency successfully precompiled in 3 seconds. 126 already precompiled.

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
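As a quick check of what `loadmnist(128, 0.9)` yields in the CI configuration (where `N = 1500`), the sketch below works out the split and batch counts with plain integer arithmetic. It assumes that `splitobs` assigns the first 90% of observations to the training set and that `DataLoader` keeps the final partial batch; the variable names are illustrative only and are not part of the tutorial code.

```julia
# Illustrative arithmetic only; no packages required.
n_total = 1500          # samples kept when ENV["CI"] is "true"
train_split = 0.9
batchsize = 128

n_train = round(Int, train_split * n_total)   # observations in the training split
n_test = n_total - n_train                    # remaining observations for testing

# Number of minibatches, counting a final partial batch
train_batches = cld(n_train, batchsize)
test_batches = cld(n_test, batchsize)
last_train_batch = n_train - (train_batches - 1) * batchsize

println((; n_train, n_test, train_batches, test_batches, last_train_batch))
```

With only 150 test samples, every reported test accuracy is a multiple of 1/150, which is consistent with values such as 17.33% (26/150) in the training logs further down.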

Define the Model

julia
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
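The parameter counts printed above can be reproduced by hand. The sketch below assumes each convolution and dense layer has one bias per output channel or feature; `conv_params` and `dense_params` are hypothetical helper names introduced here for illustration.

```julia
# Parameter counts: weights plus one bias per output channel/feature.
conv_params(k, cin, cout) = prod(k) * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params((5, 5), 1, 6),     # layer_1
    conv_params((5, 5), 6, 16),    # layer_3
    dense_params(256, 128),        # layer_6.layer_1
    dense_params(128, 84),         # layer_6.layer_2
    dense_params(84, 10),          # layer_6.layer_3
]

println(counts, " total = ", sum(counts))
```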

Next, we convert lux_model to a SimpleChains.jl model. To do this, we construct a ToSimpleChainsAdaptor with the input dimensions (without the batch dimension) and apply it to the model.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
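The input dimensions `(28, 28, 1)` fix every intermediate shape in the network, which is why the first dense layer expects 256 features. The sketch below traces the shapes, assuming stride-1 convolutions without padding and non-overlapping 2×2 pooling windows; `conv_out`, `pool_out`, and `trace_shapes` are hypothetical helpers, not part of Lux.jl or SimpleChains.jl.

```julia
# Spatial size after a stride-1 convolution without padding
conv_out(hw, k) = hw .- k .+ 1
# Spatial size after pooling with a non-overlapping window
pool_out(hw, win) = hw .÷ win

function trace_shapes(input_dims)
    hw, c = input_dims[1:2], input_dims[3]
    shapes = [input_dims]
    hw, c = conv_out(hw, (5, 5)), 6      # Conv((5, 5), 1 => 6)
    push!(shapes, (hw..., c))
    hw = pool_out(hw, (2, 2))            # MaxPool((2, 2))
    push!(shapes, (hw..., c))
    hw, c = conv_out(hw, (5, 5)), 16     # Conv((5, 5), 6 => 16)
    push!(shapes, (hw..., c))
    hw = pool_out(hw, (2, 2))            # MaxPool((2, 2))
    push!(shapes, (hw..., c))
    return shapes
end

shapes = trace_shapes((28, 28, 1))
n_features = prod(last(shapes))          # size seen by the first Dense layer

println(shapes)
println("flattened features: ", n_features)
```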

Helper Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
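To make the two helpers more concrete, here is a plain-Julia sketch of the underlying computations: cross entropy evaluated on raw logits, and accuracy as the fraction of columns whose largest logit matches the one-hot target. It assumes the loss is averaged over the batch and that columns are samples; the function names are hypothetical and this is not how Lux.jl or OneHotArrays.jl implement them.

```julia
using Statistics: mean

# Numerically stable log-softmax over each column (one column per sample)
function logsoftmax_cols(logits)
    shifted = logits .- maximum(logits; dims=1)
    return shifted .- log.(sum(exp.(shifted); dims=1))
end

# Cross entropy computed from logits, averaged over the batch
xent_from_logits(logits, y) = mean(-sum(y .* logsoftmax_cols(logits); dims=1))

# Row index of the largest entry in each column
argmax_cols(m) = [argmax(col) for col in eachcol(m)]

# Fraction of samples whose predicted class equals the target class
accuracy_sketch(logits, y) = mean(argmax_cols(logits) .== argmax_cols(y))

logits = Float32[2 0 0; 0 3 1; 0 0 0.5]   # 3 classes × 3 samples
y = Float32[1 0 0; 0 1 0; 0 0 1]          # one-hot targets for classes 1, 2, 3

println("loss = ", xent_from_logits(logits, y))
println("accuracy = ", accuracy_sketch(logits, y))
```

In this toy batch the third sample is misclassified (its largest logit is in row 2), so two of three predictions are correct.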

Define the Training Loop

julia
function train(model; rng=Xoshiro(0), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    ### Warm up the model (one gradient computation, so compilation is not timed)
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            gs, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, (x, y), train_state)
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 1 method)

Finally Training the Model

First we will train the Lux model

julia
tr_acc, te_acc = train(lux_model)
[ 1/10] 	 Time 67.11s 	 Training Accuracy: 17.70% 	 Test Accuracy: 17.33%
[ 2/10] 	 Time 76.19s 	 Training Accuracy: 37.41% 	 Test Accuracy: 30.00%
[ 3/10] 	 Time 60.52s 	 Training Accuracy: 51.48% 	 Test Accuracy: 46.67%
[ 4/10] 	 Time 64.75s 	 Training Accuracy: 60.96% 	 Test Accuracy: 54.00%
[ 5/10] 	 Time 65.85s 	 Training Accuracy: 68.52% 	 Test Accuracy: 60.00%
[ 6/10] 	 Time 55.95s 	 Training Accuracy: 73.70% 	 Test Accuracy: 66.00%
[ 7/10] 	 Time 59.55s 	 Training Accuracy: 77.26% 	 Test Accuracy: 67.33%
[ 8/10] 	 Time 57.79s 	 Training Accuracy: 79.33% 	 Test Accuracy: 72.67%
[ 9/10] 	 Time 58.76s 	 Training Accuracy: 81.85% 	 Test Accuracy: 75.33%
[10/10] 	 Time 67.79s 	 Training Accuracy: 83.33% 	 Test Accuracy: 76.67%

Now we will train the SimpleChains model

julia
train(simple_chains_model)
[ 1/10] 	 Time 14.22s 	 Training Accuracy: 33.19% 	 Test Accuracy: 26.67%
[ 2/10] 	 Time 12.90s 	 Training Accuracy: 41.63% 	 Test Accuracy: 33.33%
[ 3/10] 	 Time 12.89s 	 Training Accuracy: 54.37% 	 Test Accuracy: 47.33%
[ 4/10] 	 Time 12.94s 	 Training Accuracy: 59.33% 	 Test Accuracy: 54.67%
[ 5/10] 	 Time 12.90s 	 Training Accuracy: 68.07% 	 Test Accuracy: 60.67%
[ 6/10] 	 Time 12.89s 	 Training Accuracy: 73.19% 	 Test Accuracy: 68.67%
[ 7/10] 	 Time 12.90s 	 Training Accuracy: 75.41% 	 Test Accuracy: 73.33%
[ 8/10] 	 Time 12.90s 	 Training Accuracy: 77.85% 	 Test Accuracy: 72.67%
[ 9/10] 	 Time 12.89s 	 Training Accuracy: 82.59% 	 Test Accuracy: 80.00%
[10/10] 	 Time 12.89s 	 Training Accuracy: 85.63% 	 Test Accuracy: 80.67%

On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup there may not be as significant, and there might even be regressions.
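The per-epoch times logged above can be summarized as a single ratio. The sketch below copies the timings from the two training runs on this page and compares their means; it is a rough summary of one documentation build, not a controlled benchmark.

```julia
using Statistics: mean

# Per-epoch wall-clock times (seconds) copied from the logs above
lux_times = [67.11, 76.19, 60.52, 64.75, 65.85, 55.95, 59.55, 57.79, 58.76, 67.79]
sc_times = [14.22, 12.90, 12.89, 12.94, 12.90, 12.89, 12.90, 12.90, 12.89, 12.89]

# Ratio of mean epoch times: values above 1 mean SimpleChains.jl was faster
speedup = mean(lux_times) / mean(sc_times)

println("mean Lux epoch:          ", round(mean(lux_times); digits=2), " s")
println("mean SimpleChains epoch: ", round(mean(sc_times); digits=2), " s")
println("speedup:                 ", round(speedup; digits=2), "x")
```

For this particular build the ratio comes out a little under 5x. Timings on a shared build server vary from run to run, so the figure should be read as indicative only.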

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.