MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial we demonstrate how to train a SimpleChains.jl model through the same API that Lux.jl uses for its own models. We use the tutorial from SimpleChains.jl as a reference.

Package Imports

julia
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    dataset = MNIST(; split=:train)
    if N !== nothing
        imgs = dataset.features[:, :, 1:N]
        labels_raw = dataset.targets[1:N]
    else
        imgs = dataset.features
        labels_raw = dataset.targets
    end

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
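
The array shapes that loadmnist produces can be checked without downloading MNIST. The sketch below is a pure-Julia stand-in, not the tutorial's implementation: it uses random data in place of the dataset, a broadcast comparison in place of onehotbatch, and plain index ranges in place of splitobs (the rounding used to pick the split point is an assumption).

```julia
# Synthetic stand-in for MNIST: 20 grayscale 28×28 images with labels in 0:9.
imgs = rand(Float32, 28, 28, 20)
labels_raw = rand(0:9, 20)

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C, N).
x_data = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Pure-Julia one-hot encoding: row k is true where the label equals k - 1.
y_data = (0:9) .== permutedims(labels_raw)

# Split along the last (observation) dimension; the first 90% is the training set.
n_train = round(Int, 0.9 * size(x_data, 4))
x_train, x_test = x_data[:, :, :, 1:n_train], x_data[:, :, :, (n_train + 1):end]
y_train, y_test = y_data[:, 1:n_train], y_data[:, (n_train + 1):end]

@show size(x_train) size(y_train) size(x_test) size(y_test)
```

Every column of y_data contains exactly one true entry, and the last dimension of each array indexes observations, which is the layout the DataLoader batches over.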

Define the Model

julia
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
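
The 256 inputs of the first Dense layer and the parameter counts printed above follow from the layer sizes. A small arithmetic sketch, assuming unpadded stride-1 convolutions and non-overlapping pooling windows (the helper names are made up for this illustration):

```julia
# Output size of an unpadded, stride-1 convolution along one spatial axis.
conv_out(n, k) = n - k + 1
# Output size of non-overlapping pooling with window p.
pool_out(n, p) = n ÷ p

n = 28
n = conv_out(n, 5)   # after Conv((5, 5), 1 => 6)
n = pool_out(n, 2)   # after MaxPool((2, 2))
n = conv_out(n, 5)   # after Conv((5, 5), 6 => 16)
n = pool_out(n, 2)   # after MaxPool((2, 2))
flat = n * n * 16    # features entering Dense(256 => 128)

# Weights plus one bias per output channel / output unit.
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

total = conv_params(5, 1, 6) + conv_params(5, 6, 16) +
        dense_params(256, 128) + dense_params(128, 84) + dense_params(84, 10)

@show flat total
```

The spatial size shrinks 28, 24, 12, 8, 4, so flattening 4×4×16 gives 256 features, and the per-layer counts sum to the 47_154 parameters reported for the model.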

Next, we convert lux_model to a SimpleChains.jl model. To do so, we construct a ToSimpleChainsAdaptor with the input dimensions (here a 28×28 image with 1 channel, without the batch dimension) and apply it to the model.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.

Helper Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
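
To make the two helpers concrete, here is a pure-Julia sketch of what a cross-entropy loss on logits and an argmax-based accuracy compute. It is an illustration under the assumption that the loss is averaged over the batch; the function names are hypothetical and this is not how Lux implements CrossEntropyLoss internally.

```julia
using Statistics: mean

# Numerically stable log-softmax over the class dimension (rows).
function logsoftmax_cols(logits)
    shifted = logits .- maximum(logits; dims=1)
    return shifted .- log.(sum(exp.(shifted); dims=1))
end

# Mean cross-entropy between one-hot targets `y` and raw scores `logits`.
crossentropy_logits(logits, y) = mean(-sum(y .* logsoftmax_cols(logits); dims=1))

# Fraction of columns whose highest-scoring class matches the target class.
function accuracy_sketch(logits, y)
    pred = [argmax(col) for col in eachcol(logits)]
    target = [argmax(col) for col in eachcol(y)]
    return mean(pred .== target)
end

logits = Float32[2 0; 0 0; 0 3]   # 3 classes × 2 samples
y = Bool[1 0; 0 1; 0 0]           # targets: class 1, then class 2

@show crossentropy_logits(logits, y) accuracy_sketch(logits, y)
```

The first sample is classified correctly and the second is not, so the accuracy is 0.5. Working on logits rather than probabilities lets the loss use the stable log-softmax form instead of taking the log of a softmax output.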

Define the Training Loop

julia
function train(model; rng=Xoshiro(0), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    ### Warm up the model
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            gs, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, (x, y), train_state)
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 1 method)

Finally Training the Model

First, we train the Lux model:

julia
tr_acc, te_acc = train(lux_model)
[ 1/10] 	 Time 76.85s 	 Training Accuracy: 17.78% 	 Test Accuracy: 17.33%
[ 2/10] 	 Time 73.15s 	 Training Accuracy: 37.78% 	 Test Accuracy: 30.67%
[ 3/10] 	 Time 81.08s 	 Training Accuracy: 51.63% 	 Test Accuracy: 44.00%
[ 4/10] 	 Time 83.41s 	 Training Accuracy: 61.48% 	 Test Accuracy: 54.00%
[ 5/10] 	 Time 82.48s 	 Training Accuracy: 67.93% 	 Test Accuracy: 60.67%
[ 6/10] 	 Time 87.14s 	 Training Accuracy: 74.00% 	 Test Accuracy: 64.00%
[ 7/10] 	 Time 86.69s 	 Training Accuracy: 77.19% 	 Test Accuracy: 69.33%
[ 8/10] 	 Time 85.98s 	 Training Accuracy: 79.56% 	 Test Accuracy: 73.33%
[ 9/10] 	 Time 84.37s 	 Training Accuracy: 82.00% 	 Test Accuracy: 74.67%
[10/10] 	 Time 85.09s 	 Training Accuracy: 83.41% 	 Test Accuracy: 80.00%

Now we train the SimpleChains model:

julia
train(simple_chains_model)
[ 1/10] 	 Time 14.56s 	 Training Accuracy: 32.96% 	 Test Accuracy: 22.00%
[ 2/10] 	 Time 12.89s 	 Training Accuracy: 41.26% 	 Test Accuracy: 34.00%
[ 3/10] 	 Time 12.85s 	 Training Accuracy: 54.30% 	 Test Accuracy: 49.33%
[ 4/10] 	 Time 12.86s 	 Training Accuracy: 60.00% 	 Test Accuracy: 56.00%
[ 5/10] 	 Time 12.86s 	 Training Accuracy: 67.11% 	 Test Accuracy: 62.67%
[ 6/10] 	 Time 12.78s 	 Training Accuracy: 72.07% 	 Test Accuracy: 67.33%
[ 7/10] 	 Time 12.77s 	 Training Accuracy: 76.96% 	 Test Accuracy: 70.67%
[ 8/10] 	 Time 12.81s 	 Training Accuracy: 80.52% 	 Test Accuracy: 78.00%
[ 9/10] 	 Time 12.77s 	 Training Accuracy: 82.59% 	 Test Accuracy: 76.00%
[10/10] 	 Time 12.78s 	 Training Accuracy: 85.85% 	 Test Accuracy: 83.33%

On my local machine, we see a 3-4x speedup when using SimpleChains.jl. In the run shown above, each SimpleChains.jl epoch took roughly 13 s, against 73 to 87 s for Lux. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup can differ between builds and may even turn into a regression.
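
The speedup for this particular build can be computed from the epoch times printed in the two training logs above. This is only a summary of those logged numbers, not a benchmark, and the variable names are made up for the illustration:

```julia
using Statistics: mean

# Per-epoch wall-clock times in seconds, copied from the two training logs.
lux_times = [76.85, 73.15, 81.08, 83.41, 82.48, 87.14, 86.69, 85.98, 84.37, 85.09]
sc_times = [14.56, 12.89, 12.85, 12.86, 12.86, 12.78, 12.77, 12.81, 12.77, 12.78]

# Ratio of mean epoch times: how many times faster the SimpleChains run was.
speedup = mean(lux_times) / mean(sc_times)
println("Mean epoch time: Lux ", round(mean(lux_times); digits=2), " s, ",
        "SimpleChains ", round(mean(sc_times); digits=2), " s")
println("Speedup: ", round(speedup; digits=2), "x")
```

A single run on a shared build server is noisy, so treat the resulting ratio as indicative only.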

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.