MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a model with SimpleChains.jl while using the same API as Lux.jl, following the SimpleChains.jl tutorial as a reference.

Package Imports

julia
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Precompiling Lux...
   1228.2 ms  ✓ StructArrays
    575.9 ms  ✓ StructArrays → StructArraysAdaptExt
    869.8 ms  ✓ StructArrays → StructArraysSparseArraysExt
    845.5 ms  ✓ StructArrays → StructArraysStaticArraysExt
    598.8 ms  ✓ StructArrays → StructArraysLinearAlgebraExt
    666.8 ms  ✓ Accessors → AccessorsStructArraysExt
   1911.2 ms  ✓ StructArrays → StructArraysGPUArraysCoreExt
    698.5 ms  ✓ BangBang → BangBangStructArraysExt
   4979.2 ms  ✓ GPUArrays
   1931.9 ms  ✓ MLDataDevices → MLDataDevicesGPUArraysExt
   2118.2 ms  ✓ WeightInitializers → WeightInitializersGPUArraysExt
   6271.5 ms  ✓ ChainRules
   1004.1 ms  ✓ ArrayInterface → ArrayInterfaceChainRulesExt
   1120.8 ms  ✓ MLDataDevices → MLDataDevicesChainRulesExt
  36797.7 ms  ✓ Zygote
   1865.1 ms  ✓ MLDataDevices → MLDataDevicesZygoteExt
  10502.9 ms  ✓ Lux
   3539.6 ms  ✓ Lux → LuxMLUtilsExt
   3991.3 ms  ✓ Lux → LuxZygoteExt
  19 dependencies successfully precompiled in 64 seconds. 222 already precompiled.
Precompiling MLDatasets...
    954.6 ms  ✓ GZip
    863.2 ms  ✓ ZipFile
   1001.8 ms  ✓ ConcurrentUtilities
    756.5 ms  ✓ ExceptionUnwrapping
   1189.8 ms  ✓ WeakRefStrings
   1579.8 ms  ✓ FilePathsBase → FilePathsBaseTestExt
   2522.2 ms  ✓ ColorVectorSpace
   1891.0 ms  ✓ HDF5_jll
   1937.5 ms  ✓ AtomsBase
   2774.6 ms  ✓ Pickle
    924.9 ms  ✓ ColorVectorSpace → SpecialFunctionsExt
   8616.6 ms  ✓ HDF5
   2674.5 ms  ✓ Chemfiles
  20736.1 ms  ✓ HTTP
  20961.1 ms  ✓ CSV
   3198.4 ms  ✓ MAT
   4535.2 ms  ✓ ColorSchemes
   2367.7 ms  ✓ FileIO → HTTPExt
   3773.5 ms  ✓ DataDeps
   1876.0 ms  ✓ NPZ
  20699.5 ms  ✓ ImageCore
   2385.8 ms  ✓ ImageBase
   2323.6 ms  ✓ ImageShow
  10469.4 ms  ✓ MLDatasets
  24 dependencies successfully precompiled in 55 seconds. 202 already precompiled.
Precompiling LuxSimpleChainsExt...
   3257.6 ms  ✓ Lux → LuxSimpleChainsExt
  1 dependency successfully precompiled in 4 seconds. 244 already precompiled.

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = 2000
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
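
To sanity check the data pipeline, we can load the data and inspect the first batch. This is a small sketch (not part of the original tutorial) using the same arguments as the training loop below (batch size 128, 90% train split):

julia
train_dataloader, test_dataloader = loadmnist(128, 0.9)
x, y = first(train_dataloader)
size(x), size(y)   # expected: ((28, 28, 1, 128), (10, 128))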

Define the Model

julia
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
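
As an optional sanity check (a small sketch, not part of the original tutorial), the parameter count reported above can be confirmed programmatically:

julia
Lux.parameterlength(lux_model)   # 47_154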

We now need to convert the lux_model to a SimpleChains.jl model. We do this by constructing a ToSimpleChainsAdaptor with the input dimensions and applying it to the Lux model.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
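
The adapted model keeps the standard Lux calling convention, so it can be set up and evaluated just like the original model. A minimal sketch, assuming the definitions above:

julia
rng = Xoshiro(0)
ps, st = Lux.setup(rng, simple_chains_model)
x = randn(rng, Float32, 28, 28, 1, 1)   # a single dummy 28x28 grayscale image
y, _ = simple_chains_model(x, ps, st)
size(y)   # expected: (10, 1)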

Helper Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
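
Since logits=Val(true), the loss expects raw (unnormalized) model outputs. A quick illustration (a sketch, not part of the original tutorial):

julia
ŷ = randn(Float32, 10, 4)            # raw logits for 4 samples
y = onehotbatch([0, 1, 2, 3], 0:9)   # one-hot targets
loss(ŷ, y)                           # mean cross-entropy computed from the logits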

Define the Training Loop

julia
function train(model; rng=Xoshiro(0), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    ### Warmup the model
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            gs, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, (x, y), train_state)
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 1 method)

Finally Training the Model

First, we will train the Lux model.

julia
tr_acc, te_acc = train(lux_model)
[ 1/10] 	 Time 84.94s 	 Training Accuracy: 24.28% 	 Test Accuracy: 21.00%
[ 2/10] 	 Time 76.85s 	 Training Accuracy: 47.89% 	 Test Accuracy: 48.00%
[ 3/10] 	 Time 75.80s 	 Training Accuracy: 60.11% 	 Test Accuracy: 59.00%
[ 4/10] 	 Time 82.42s 	 Training Accuracy: 69.72% 	 Test Accuracy: 64.50%
[ 5/10] 	 Time 88.89s 	 Training Accuracy: 74.89% 	 Test Accuracy: 74.50%
[ 6/10] 	 Time 86.14s 	 Training Accuracy: 78.39% 	 Test Accuracy: 76.50%
[ 7/10] 	 Time 82.69s 	 Training Accuracy: 80.94% 	 Test Accuracy: 79.00%
[ 8/10] 	 Time 81.81s 	 Training Accuracy: 83.83% 	 Test Accuracy: 81.00%
[ 9/10] 	 Time 82.87s 	 Training Accuracy: 84.83% 	 Test Accuracy: 83.00%
[10/10] 	 Time 93.35s 	 Training Accuracy: 86.94% 	 Test Accuracy: 85.50%

Now, we will train the SimpleChains model.

julia
train(simple_chains_model)
[ 1/10] 	 Time 18.58s 	 Training Accuracy: 44.83% 	 Test Accuracy: 44.50%
[ 2/10] 	 Time 17.20s 	 Training Accuracy: 54.72% 	 Test Accuracy: 48.50%
[ 3/10] 	 Time 17.22s 	 Training Accuracy: 61.67% 	 Test Accuracy: 59.00%
[ 4/10] 	 Time 17.19s 	 Training Accuracy: 69.67% 	 Test Accuracy: 64.00%
[ 5/10] 	 Time 17.21s 	 Training Accuracy: 73.72% 	 Test Accuracy: 71.00%
[ 6/10] 	 Time 17.21s 	 Training Accuracy: 77.44% 	 Test Accuracy: 72.00%
[ 7/10] 	 Time 17.21s 	 Training Accuracy: 81.33% 	 Test Accuracy: 79.50%
[ 8/10] 	 Time 17.20s 	 Training Accuracy: 82.11% 	 Test Accuracy: 80.50%
[ 9/10] 	 Time 17.22s 	 Training Accuracy: 85.83% 	 Test Accuracy: 84.00%
[10/10] 	 Time 17.22s 	 Training Accuracy: 87.17% 	 Test Accuracy: 86.50%

On a local machine we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup may be less significant, and there may even be regressions.
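
To reproduce the comparison on your own machine, a rough but simple approach is to time the full training run for both models (a sketch using the functions defined above):

julia
@time train(lux_model)
@time train(simple_chains_model)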

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.