MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a SimpleChains.jl model through the same API used for Lux.jl models, using the tutorial from SimpleChains.jl as a reference.
Package Imports
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Loading MNIST
function loadmnist(batchsize, train_split)
# Load MNIST; when ENV["CI"] is "true", keep only the first 1500 training images
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
dataset = MNIST(; split=:train)
if N !== nothing
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
else
imgs = dataset.features
labels_raw = dataset.targets
end
# Process images into (H, W, C, BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
# Don't shuffle the test data
DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
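The preprocessing in loadmnist (adding a channel dimension, one-hot encoding, splitting off the first 90% for training, and shuffled minibatching) can be sketched in plain Julia to make the array shapes concrete. This is an illustrative sketch under stated assumptions, not what MLUtils or OneHotArrays do internally: the random `imgs`/`labels` arrays are hypothetical stand-ins for MNIST, `minibatches` is a hypothetical helper, and the split index uses `round`, which may not match the exact rounding rule in MLUtils.

```julia
using Random

# Hypothetical stand-ins for MNIST: twenty 28×28 images and integer labels in 0:9.
rng = Xoshiro(0)
imgs = rand(rng, Float32, 28, 28, 20)
labels = rand(rng, 0:9, 20)

# (H, W, N) -> (H, W, C, N) with a single channel, as loadmnist does.
x = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# One-hot encode: column j is `true` only in the row matching labels[j].
y = (0:9) .== permutedims(labels)   # 10×N BitMatrix

# Keep the first 90% of observations for training and the rest for testing
# (assumption: round-to-nearest split index).
ntrain = round(Int, 0.9 * size(x, 4))
xtr, xte = x[:, :, :, 1:ntrain], x[:, :, :, (ntrain + 1):end]
ytr, yte = y[:, 1:ntrain], y[:, (ntrain + 1):end]

# Hypothetical helper: shuffled index batches of at most `batchsize` observations.
function minibatches(rng, n, batchsize)
    return collect(Iterators.partition(randperm(rng, n), batchsize))
end

batches = minibatches(rng, ntrain, 8)   # sizes 8, 8, 2; the last batch is partial
```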
Define the Model
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
) # Total: 47_154 parameters,
# plus 0 states.
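The parameter counts printed above follow from simple arithmetic: a convolution has one weight per kernel entry, input channel, and output channel, plus one bias per output channel, and a dense layer has an `in × out` weight matrix plus an `out`-length bias. The helper names below are hypothetical; the sketch only reproduces the numbers in the printout.

```julia
# Conv with kernel size `k` mapping cin => cout channels: weights plus one bias per output channel.
conv_params(k, cin, cout) = prod(k) * cin * cout + cout
# Dense mapping nin => nout features: weight matrix plus bias vector.
dense_params(nin, nout) = nin * nout + nout

counts = [
    conv_params((5, 5), 1, 6),    # layer_1
    conv_params((5, 5), 6, 16),   # layer_3
    dense_params(256, 128),       # layer_6.layer_1
    dense_params(128, 84),        # layer_6.layer_2
    dense_params(84, 10),         # layer_6.layer_3
]
total = sum(counts)   # MaxPool and FlattenLayer contribute no parameters
```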
We now need to convert lux_model to SimpleChains.jl. To do this, we construct a ToSimpleChainsAdaptor with the input dimensions of a single sample, (28, 28, 1), excluding the batch dimension, and apply it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
),
) # Total: 47_154 parameters,
# plus 0 states.
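The per-sample input size passed to the adaptor determines every intermediate shape, including the 256 features expected by the first Dense layer. The trace below assumes Lux's defaults of stride 1 and no padding for Conv, and a stride equal to the window size for MaxPool; `conv_out` and `pool_out` are hypothetical helpers for this arithmetic only.

```julia
# Spatial size after a stride-1, unpadded convolution with kernel size k.
conv_out(n, k) = n - k + 1
# Spatial size after a non-overlapping k×k max-pool (integer division).
pool_out(n, k) = n ÷ k

h = 28                 # input sample: 28×28×1
h = conv_out(h, 5)     # Conv((5, 5), 1 => 6)   -> 24×24×6
h = pool_out(h, 2)     # MaxPool((2, 2))        -> 12×12×6
h = conv_out(h, 5)     # Conv((5, 5), 6 => 16)  -> 8×8×16
h = pool_out(h, 2)     # MaxPool((2, 2))        -> 4×4×16
flat = h * h * 16      # FlattenLayer(3)        -> 256 features for Dense(256 => 128)
```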
Helper Functions
const loss = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(Array(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy (generic function with 1 method)
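To make the two helpers concrete, here is a plain-Julia sketch of the quantities they compute: a cross-entropy on raw logits (which is why `logits=Val(true)` is passed, so no softmax layer is needed at the end of the model), and accuracy as the fraction of columns whose largest entry matches the target, which is what `onecold` with its default labels gives. This assumes mean reduction over the batch and no label smoothing; all function names here are hypothetical and this is not the Lux or OneHotArrays implementation.

```julia
# Numerically stable log-softmax over each column (one column per sample).
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Mean cross-entropy between raw logits ŷ and one-hot targets y (classes × batch).
logit_crossentropy(ŷ, y) = -sum(y .* logsoftmax_cols(ŷ)) / size(y, 2)

# Row index of the largest entry in each column.
argmax_cols(z) = [argmax(c) for c in eachcol(z)]

# Fraction of samples whose predicted class equals the target class.
accuracy_sketch(ŷ, y) = sum(argmax_cols(ŷ) .== argmax_cols(y)) / size(y, 2)

# Usage: 3 classes, 2 samples. Targets are class 1 and class 3;
# predictions are class 1 (correct) and class 2 (wrong), so accuracy is 0.5.
ŷ = Float32[2 0; 1 3; 0 1]
y = Bool[1 0; 0 0; 0 1]
acc = accuracy_sketch(ŷ, y)
```

With all-zero logits every class gets probability 1/k, so the loss equals log(k), a handy sanity check for any cross-entropy implementation.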
Define the Training Loop
function train(model; rng=Xoshiro(0), kwargs...)
train_dataloader, test_dataloader = loadmnist(128, 0.9)
ps, st = Lux.setup(rng, model)
train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))
### Warm up the model: one gradient computation on a dummy batch compiles the code before the timed epochs
x_proto = randn(rng, Float32, 28, 28, 1, 1)
y_proto = onehotbatch([1], 0:9)
Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)
### Let's train the model
nepochs = 10
tr_acc, te_acc = 0.0, 0.0
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
gs, _, _, train_state = Training.single_train_step!(
AutoZygote(), loss, (x, y), train_state)
end
ttime = time() - stime
tr_acc = accuracy(
model, train_state.parameters, train_state.states, train_dataloader) * 100
te_acc = accuracy(
model, train_state.parameters, train_state.states, test_dataloader) * 100
@printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
%.2f%%\n" epoch nepochs ttime tr_acc te_acc
end
return tr_acc, te_acc
end
train (generic function with 1 method)
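`Adam(3.0f-4)` selects the Adam rule from Optimisers.jl with learning rate η = 3e-4. The sketch below shows the textbook Adam update, assuming Optimisers.jl's default β = (0.9, 0.999) and ε = 1e-8; `adam_step!` is a hypothetical helper, and the library's actual implementation may differ in numerical details. The toy quadratic and its larger η = 0.1 are only there so convergence is visible within 1000 steps.

```julia
# One Adam step on parameters θ given gradient g at step t (1-based); m and v are updated in place.
function adam_step!(θ, m, v, g, t; η=3.0f-4, β1=0.9, β2=0.999, ϵ=1e-8)
    @. m = β1 * m + (1 - β1) * g        # running mean of gradients
    @. v = β2 * v + (1 - β2) * g^2      # running mean of squared gradients
    mhat = m ./ (1 - β1^t)              # bias corrections for the zero initialization
    vhat = v ./ (1 - β2^t)
    @. θ -= η * mhat / (sqrt(vhat) + ϵ)
    return θ
end

# Toy problem: minimize sum((θ .- target).^2), whose gradient is 2(θ .- target).
target = [1.0, -2.0, 3.0]
θ = zeros(3); m = zeros(3); v = zeros(3)
for t in 1:1000
    g = 2 .* (θ .- target)
    adam_step!(θ, m, v, g, t; η=0.1)
end
```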
Finally Training the Model
First, we train the Lux model:
tr_acc, te_acc = train(lux_model)
[ 1/10] Time 76.85s Training Accuracy: 17.78% Test Accuracy: 17.33%
[ 2/10] Time 73.15s Training Accuracy: 37.78% Test Accuracy: 30.67%
[ 3/10] Time 81.08s Training Accuracy: 51.63% Test Accuracy: 44.00%
[ 4/10] Time 83.41s Training Accuracy: 61.48% Test Accuracy: 54.00%
[ 5/10] Time 82.48s Training Accuracy: 67.93% Test Accuracy: 60.67%
[ 6/10] Time 87.14s Training Accuracy: 74.00% Test Accuracy: 64.00%
[ 7/10] Time 86.69s Training Accuracy: 77.19% Test Accuracy: 69.33%
[ 8/10] Time 85.98s Training Accuracy: 79.56% Test Accuracy: 73.33%
[ 9/10] Time 84.37s Training Accuracy: 82.00% Test Accuracy: 74.67%
[10/10] Time 85.09s Training Accuracy: 83.41% Test Accuracy: 80.00%
Next, we train the SimpleChains model:
train(simple_chains_model)
[ 1/10] Time 14.56s Training Accuracy: 32.96% Test Accuracy: 22.00%
[ 2/10] Time 12.89s Training Accuracy: 41.26% Test Accuracy: 34.00%
[ 3/10] Time 12.85s Training Accuracy: 54.30% Test Accuracy: 49.33%
[ 4/10] Time 12.86s Training Accuracy: 60.00% Test Accuracy: 56.00%
[ 5/10] Time 12.86s Training Accuracy: 67.11% Test Accuracy: 62.67%
[ 6/10] Time 12.78s Training Accuracy: 72.07% Test Accuracy: 67.33%
[ 7/10] Time 12.77s Training Accuracy: 76.96% Test Accuracy: 70.67%
[ 8/10] Time 12.81s Training Accuracy: 80.52% Test Accuracy: 78.00%
[ 9/10] Time 12.77s Training Accuracy: 82.59% Test Accuracy: 76.00%
[10/10] Time 12.78s Training Accuracy: 85.85% Test Accuracy: 83.33%
On my local machine, SimpleChains.jl gives a 3-4x speedup. The server that builds this documentation is not an ideal environment for CPU benchmarking, so the speedup reported above may differ from what you see locally, and there may even be regressions.
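The speedup for this particular run can be read off the two training logs. The snippet below only reproduces that arithmetic from the logged per-epoch times, so it says nothing about other machines; for this run, the total training time ratio comes out at roughly 6.4x.

```julia
using Statistics

# Per-epoch wall-clock times in seconds, copied from the two training logs above.
lux_times = [76.85, 73.15, 81.08, 83.41, 82.48, 87.14, 86.69, 85.98, 84.37, 85.09]
sc_times  = [14.56, 12.89, 12.85, 12.86, 12.86, 12.78, 12.77, 12.81, 12.77, 12.78]

# Ratio of total training time across all 10 epochs.
total_speedup = sum(lux_times) / sum(sc_times)
# Median of per-epoch ratios, less sensitive to any single outlier epoch.
median_speedup = median(lux_times ./ sc_times)
```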
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.