MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we demonstrate how to train a SimpleChains.jl model using the same API as Lux.jl, following the corresponding SimpleChains.jl tutorial as a reference.
Package Imports
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Loading MNIST
function loadmnist(batchsize, train_split)
    # Load MNIST
    N = 2000
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
Define the Model
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
    layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
    layer_2 = MaxPool((2, 2)),
    layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
    layer_4 = MaxPool((2, 2)),
    layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
    layer_6 = Chain(
        layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
        layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
        layer_3 = Dense(84 => 10),  # 850 parameters
    ),
) # Total: 47_154 parameters,
  # plus 0 states.

We now need to convert lux_model to a SimpleChains.jl model. We do this by constructing a ToSimpleChainsAdaptor with the input dimensions and applying it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
    Chain(
        layer_1 = Conv((5, 5), 1 => 6, relu),  # 156 parameters
        layer_2 = MaxPool((2, 2)),
        layer_3 = Conv((5, 5), 6 => 16, relu),  # 2_416 parameters
        layer_4 = MaxPool((2, 2)),
        layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
        layer_6 = Chain(
            layer_1 = Dense(256 => 128, relu),  # 32_896 parameters
            layer_2 = Dense(128 => 84, relu),  # 10_836 parameters
            layer_3 = Dense(84 => 10),  # 850 parameters
        ),
    ),
) # Total: 47_154 parameters,
  # plus 0 states.

Helper Functions
const loss = CrossEntropyLoss(; logits=Val(true))
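As a quick sanity check, the logits-based cross entropy can be called directly on raw model outputs and one-hot targets. This is a hypothetical standalone snippet (not part of the original tutorial); the variable names `celoss`, `ŷ`, and `y` are illustrative.

```julia
using Lux, OneHotArrays, Random

# Mirror of the tutorial's loss definition, under a hypothetical name
celoss = CrossEntropyLoss(; logits=Val(true))

ŷ = randn(Xoshiro(0), Float32, 10, 4)  # raw logits for a batch of 4 samples
y = onehotbatch([0, 3, 7, 9], 0:9)     # one-hot targets over classes 0:9
celoss(ŷ, y)                           # scalar loss; logsoftmax is applied internally
```

Because `logits=Val(true)`, the loss expects unnormalized outputs, which is why the final Dense(84 => 10) layer above has no activation.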
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(Array(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
Define the Training Loop
function train(model; rng=Xoshiro(0), kwargs...)
    train_dataloader, test_dataloader = loadmnist(128, 0.9)
    ps, st = Lux.setup(rng, model)

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    ### Warmup the model
    x_proto = randn(rng, Float32, 28, 28, 1, 1)
    y_proto = onehotbatch([1], 0:9)
    Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)

    ### Let's train the model
    nepochs = 10
    tr_acc, te_acc = 0.0, 0.0
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            gs, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, (x, y), train_state)
        end
        ttime = time() - stime

        tr_acc = accuracy(
            model, train_state.parameters, train_state.states, train_dataloader) * 100
        te_acc = accuracy(
            model, train_state.parameters, train_state.states, test_dataloader) * 100

        @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" epoch nepochs ttime tr_acc te_acc
    end

    return tr_acc, te_acc
end
train (generic function with 1 method)
Finally Training the Model
First we will train the Lux model
tr_acc, te_acc = train(lux_model)
[ 1/10] Time 84.94s Training Accuracy: 24.28% Test Accuracy: 21.00%
[ 2/10] Time 76.85s Training Accuracy: 47.89% Test Accuracy: 48.00%
[ 3/10] Time 75.80s Training Accuracy: 60.11% Test Accuracy: 59.00%
[ 4/10] Time 82.42s Training Accuracy: 69.72% Test Accuracy: 64.50%
[ 5/10] Time 88.89s Training Accuracy: 74.89% Test Accuracy: 74.50%
[ 6/10] Time 86.14s Training Accuracy: 78.39% Test Accuracy: 76.50%
[ 7/10] Time 82.69s Training Accuracy: 80.94% Test Accuracy: 79.00%
[ 8/10] Time 81.81s Training Accuracy: 83.83% Test Accuracy: 81.00%
[ 9/10] Time 82.87s Training Accuracy: 84.83% Test Accuracy: 83.00%
[10/10] Time 93.35s Training Accuracy: 86.94% Test Accuracy: 85.50%

Now we will train the SimpleChains model
train(simple_chains_model)
[ 1/10] Time 18.58s Training Accuracy: 44.83% Test Accuracy: 44.50%
[ 2/10] Time 17.20s Training Accuracy: 54.72% Test Accuracy: 48.50%
[ 3/10] Time 17.22s Training Accuracy: 61.67% Test Accuracy: 59.00%
[ 4/10] Time 17.19s Training Accuracy: 69.67% Test Accuracy: 64.00%
[ 5/10] Time 17.21s Training Accuracy: 73.72% Test Accuracy: 71.00%
[ 6/10] Time 17.21s Training Accuracy: 77.44% Test Accuracy: 72.00%
[ 7/10] Time 17.21s Training Accuracy: 81.33% Test Accuracy: 79.50%
[ 8/10] Time 17.20s Training Accuracy: 82.11% Test Accuracy: 80.50%
[ 9/10] Time 17.22s Training Accuracy: 85.83% Test Accuracy: 84.00%
[10/10] Time 17.22s Training Accuracy: 87.17% Test Accuracy: 86.50%

On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server on which this documentation is built is not well suited to CPU benchmarking, so the speedup here may be less pronounced, and there may even be regressions.
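If you want to measure the backend difference on your own machine, a single gradient step can be timed directly. This is a hypothetical sketch (not part of the original tutorial): it assumes `lux_model`, `simple_chains_model`, and `loss` from above are in scope, and discards the first call to exclude compilation time.

```julia
using Random, Printf, OneHotArrays

x = randn(Xoshiro(0), Float32, 28, 28, 1, 128)
y = onehotbatch(rand(Xoshiro(1), 0:9, 128), 0:9)

for (name, model) in (("Lux", lux_model), ("SimpleChains", simple_chains_model))
    ps, st = Lux.setup(Xoshiro(0), model)
    ts = Training.TrainState(model, ps, st, Adam(3.0f-4))
    Training.compute_gradients(AutoZygote(), loss, (x, y), ts)  # warmup/compile
    t = @elapsed Training.compute_gradients(AutoZygote(), loss, (x, y), ts)
    @printf "%-12s %.4fs per gradient step\n" name t
end
```

Absolute timings will vary with hardware and thread count; only the relative comparison between the two backends is meaningful.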
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.