MNIST Classification with SimpleChains

SimpleChains.jl is an excellent framework for training small neural networks on the CPU. In this tutorial we demonstrate how to train a SimpleChains.jl model using the familiar Lux.jl API, following the MNIST tutorial from the SimpleChains.jl documentation as a reference.

Package Imports

julia
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Precompiling OneHotArrays...
   1863.3 ms  ✓ OneHotArrays
  1 dependency successfully precompiled in 2 seconds. 50 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
   1831.5 ms  ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
  1 dependency successfully precompiled in 2 seconds. 57 already precompiled.
Precompiling MLDatasets...
    350.5 ms  ✓ Glob
    376.5 ms  ✓ WorkerUtilities
    403.2 ms  ✓ BufferedStreams
    321.8 ms  ✓ SimpleBufferStream
    301.5 ms  ✓ PackageExtensionCompat
    567.5 ms  ✓ URIs
    338.2 ms  ✓ BitFlags
    572.1 ms  ✓ ZipFile
    709.7 ms  ✓ ConcurrentUtilities
    773.2 ms  ✓ StructTypes
    478.2 ms  ✓ LoggingExtras
   1002.2 ms  ✓ MbedTLS
    324.8 ms  ✓ InternedStrings
    595.5 ms  ✓ MPIPreferences
    466.2 ms  ✓ ExceptionUnwrapping
    580.6 ms  ✓ OpenSSL_jll
    538.9 ms  ✓ Chemfiles_jll
    580.2 ms  ✓ libaec_jll
    469.8 ms  ✓ MicrosoftMPI_jll
    571.5 ms  ✓ Libiconv_jll
   3479.6 ms  ✓ ColorSchemes
   1007.5 ms  ✓ FilePathsBase
    743.5 ms  ✓ WeakRefStrings
    440.0 ms  ✓ StridedViews
   1509.7 ms  ✓ NPZ
  10927.2 ms  ✓ JSON3
  21465.0 ms  ✓ Unitful
  18953.3 ms  ✓ ImageCore
   1140.5 ms  ✓ MPItrampoline_jll
   1400.6 ms  ✓ MPICH_jll
   1281.1 ms  ✓ OpenMPI_jll
    523.0 ms  ✓ StringEncodings
    516.6 ms  ✓ FilePathsBase → FilePathsBaseMmapExt
    639.2 ms  ✓ Unitful → ConstructionBaseUnitfulExt
   1213.8 ms  ✓ FilePathsBase → FilePathsBaseTestExt
   1925.4 ms  ✓ OpenSSL
    555.4 ms  ✓ Unitful → InverseFunctionsUnitfulExt
    644.7 ms  ✓ Accessors → AccessorsUnitfulExt
   2380.2 ms  ✓ PeriodicTable
   2746.1 ms  ✓ UnitfulAtomic
   2084.7 ms  ✓ ImageBase
   1819.9 ms  ✓ HDF5_jll
   2356.0 ms  ✓ Pickle
   2317.4 ms  ✓ AtomsBase
   1970.3 ms  ✓ ImageShow
   7872.8 ms  ✓ HDF5
   2576.1 ms  ✓ Chemfiles
  19449.9 ms  ✓ CSV
   2655.4 ms  ✓ MAT
  19902.6 ms  ✓ HTTP
   2187.1 ms  ✓ FileIO → HTTPExt
   3516.7 ms  ✓ DataDeps
   9521.6 ms  ✓ MLDatasets
  53 dependencies successfully precompiled in 63 seconds. 148 already precompiled.
Precompiling SimpleChains...
    447.6 ms  ✓ StaticArrayInterface → StaticArrayInterfaceOffsetArraysExt
    734.1 ms  ✓ HostCPUFeatures
   8203.6 ms  ✓ VectorizationBase
   1026.5 ms  ✓ SLEEFPirates
   1232.8 ms  ✓ VectorizedRNG
    740.8 ms  ✓ VectorizedRNG → VectorizedRNGStaticArraysExt
  27822.7 ms  ✓ LoopVectorization
   1145.7 ms  ✓ LoopVectorization → SpecialFunctionsExt
   1370.6 ms  ✓ LoopVectorization → ForwardDiffExt
   6112.8 ms  ✓ SimpleChains
  10 dependencies successfully precompiled in 45 seconds. 75 already precompiled.
Precompiling LuxLibSLEEFPiratesExt...
   3359.0 ms  ✓ LuxLib → LuxLibSLEEFPiratesExt
  1 dependency successfully precompiled in 4 seconds. 112 already precompiled.
Precompiling LuxLibLoopVectorizationExt...
   5601.8 ms  ✓ LuxLib → LuxLibLoopVectorizationExt
  1 dependency successfully precompiled in 6 seconds. 120 already precompiled.
Precompiling LuxSimpleChainsExt...
   2896.0 ms  ✓ Lux → LuxSimpleChainsExt
  1 dependency successfully precompiled in 3 seconds. 140 already precompiled.

Loading MNIST

julia
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = 2000
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
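The labels are one-hot encoded via `onehotbatch(labels_raw, 0:9)`. As a minimal sketch of what that encoding produces (written in plain Julia rather than OneHotArrays, so the helper name here is illustrative):

```julia
# Minimal one-hot encoding sketch: each label in 0:9 becomes a length-10
# column with a single 1, matching what onehotbatch(labels, 0:9) yields.
# (OneHotArrays returns a lazy OneHotMatrix rather than a dense matrix.)
function onehot_dense(labels, classes)
    mat = zeros(Float32, length(classes), length(labels))
    for (j, label) in enumerate(labels)
        i = findfirst(==(label), classes)
        mat[i, j] = 1.0f0
    end
    return mat
end

y = onehot_dense([0, 3, 9], 0:9)
size(y)     # (10, 3)
y[4, 2]     # 1.0f0 — label 3 maps to row 4
```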

Define the Model

julia
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),      # 850 parameters
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.
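The parameter counts in the summary follow from the usual formulas: a convolution has `kh * kw * cin * cout` weights plus `cout` biases, and a dense layer has `nin * nout` weights plus `nout` biases. As a quick sanity check in plain Julia (the helper names are illustrative):

```julia
# Conv params: kernel_h * kernel_w * in_channels * out_channels + out_channels (bias)
conv_params(k, cin, cout) = k[1] * k[2] * cin * cout + cout
# Dense params: in * out weights + out biases
dense_params(nin, nout) = nin * nout + nout

total = conv_params((5, 5), 1, 6) +    # 156
        conv_params((5, 5), 6, 16) +   # 2_416
        dense_params(256, 128) +       # 32_896
        dense_params(128, 84) +        # 10_836
        dense_params(84, 10)           # 850
total == 47_154    # true, matching the model summary
```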

We now need to convert the lux_model to a SimpleChains.jl model. We do this by constructing a ToSimpleChainsAdaptor with the input dimensions and applying it to the model.

julia
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
)         # Total: 47_154 parameters,
          #        plus 0 states.

Helper Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
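The accuracy function relies on onecold, which inverts the one-hot encoding by taking the index of the largest entry in each column. A minimal stand-in (the name `onecold_sketch` is illustrative; OneHotArrays' onecold can additionally map indices back into the original label set):

```julia
# onecold picks the argmax along the class dimension; for a (classes, batch)
# matrix this returns one predicted index per column.
onecold_sketch(y::AbstractMatrix) = map(argmax, eachcol(y))

logits = Float32[0.1 2.0; 0.9 0.2; 0.3 0.1]   # 3 classes, 2 samples
onecold_sketch(logits)   # [2, 1]
```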

Define the Training Loop

julia
function train(model; rng=Xoshiro(0), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    ### Warmup the model
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

    ### Lets train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            gs, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, (x, y), train_state)
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 1 method)
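Training.single_train_step! computes the gradients for one batch and applies the optimizer update in a single call. Conceptually, each step moves the parameters against the gradient of the loss; a sketch of that idea for plain gradient descent on a toy loss (hand-written gradient, no Zygote — Adam additionally tracks per-parameter moment estimates, but the update skeleton is the same):

```julia
# Conceptual single training step: evaluate the gradient of the loss at the
# current parameters, then step against it scaled by the learning rate.
loss_fn(w) = sum(abs2, w .- 1.0)     # toy quadratic loss, minimum at w = 1
grad_fn(w) = 2.0 .* (w .- 1.0)       # its analytic gradient

function sgd_step!(w, lr)
    w .-= lr .* grad_fn(w)
    return w
end

w = zeros(2)
for _ in 1:100
    sgd_step!(w, 0.1)
end
w   # ≈ [1.0, 1.0]
```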

Finally Training the Model

First we will train the Lux model

julia
tr_acc, te_acc = train(lux_model)
[ 1/10] 	 Time 99.34s 	 Training Accuracy: 23.11% 	 Test Accuracy: 19.50%
[ 2/10] 	 Time 110.42s 	 Training Accuracy: 47.56% 	 Test Accuracy: 43.50%
[ 3/10] 	 Time 116.23s 	 Training Accuracy: 61.11% 	 Test Accuracy: 60.00%
[ 4/10] 	 Time 113.21s 	 Training Accuracy: 69.83% 	 Test Accuracy: 66.00%
[ 5/10] 	 Time 108.16s 	 Training Accuracy: 75.33% 	 Test Accuracy: 71.00%
[ 6/10] 	 Time 97.78s 	 Training Accuracy: 78.56% 	 Test Accuracy: 77.50%
[ 7/10] 	 Time 100.73s 	 Training Accuracy: 81.28% 	 Test Accuracy: 78.50%
[ 8/10] 	 Time 101.86s 	 Training Accuracy: 83.39% 	 Test Accuracy: 79.50%
[ 9/10] 	 Time 95.63s 	 Training Accuracy: 84.67% 	 Test Accuracy: 82.00%
[10/10] 	 Time 104.95s 	 Training Accuracy: 86.72% 	 Test Accuracy: 83.50%

Now we will train the SimpleChains model

julia
train(simple_chains_model)
[ 1/10] 	 Time 18.80s 	 Training Accuracy: 43.67% 	 Test Accuracy: 39.00%
[ 2/10] 	 Time 17.49s 	 Training Accuracy: 52.39% 	 Test Accuracy: 46.00%
[ 3/10] 	 Time 17.49s 	 Training Accuracy: 61.72% 	 Test Accuracy: 57.00%
[ 4/10] 	 Time 17.49s 	 Training Accuracy: 68.56% 	 Test Accuracy: 64.00%
[ 5/10] 	 Time 17.49s 	 Training Accuracy: 74.89% 	 Test Accuracy: 72.00%
[ 6/10] 	 Time 17.49s 	 Training Accuracy: 78.67% 	 Test Accuracy: 75.50%
[ 7/10] 	 Time 17.50s 	 Training Accuracy: 79.89% 	 Test Accuracy: 76.50%
[ 8/10] 	 Time 17.49s 	 Training Accuracy: 83.61% 	 Test Accuracy: 83.00%
[ 9/10] 	 Time 17.49s 	 Training Accuracy: 84.44% 	 Test Accuracy: 83.50%
[10/10] 	 Time 17.49s 	 Training Accuracy: 87.50% 	 Test Accuracy: 85.50%

On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup here may be less significant, and there may even be regressions.

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.