MNIST Classification with SimpleChains
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we will demonstrate how to train a model with SimpleChains.jl while using the same API as Lux.jl. We will use the MNIST tutorial from SimpleChains.jl as a reference.
Package Imports
using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf
using MLDatasets: MNIST
using SimpleChains: SimpleChains
Loading MNIST
function loadmnist(batchsize, train_split)
# Load MNIST
N = 2000
dataset = MNIST(; split=:train)
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
# Process images into (H, W, C, BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
# Don't shuffle the test data
DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
Define the Model
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
) # Total: 47_154 parameters,
# plus 0 states.
We now need to convert the lux_model to SimpleChains.jl. We do this by creating a ToSimpleChainsAdaptor with the input dimensions and applying it to the model.
adaptor = ToSimpleChainsAdaptor((28, 28, 1))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer(
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = Lux.FlattenLayer{Static.StaticInt{3}}(static(3)),
layer_6 = Chain(
layer_1 = Dense(256 => 128, relu), # 32_896 parameters
layer_2 = Dense(128 => 84, relu), # 10_836 parameters
layer_3 = Dense(84 => 10), # 850 parameters
),
),
) # Total: 47_154 parameters,
# plus 0 states.
Helper Functions
const loss = CrossEntropyLoss(; logits=Val(true))
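As a quick sanity check (a small sketch, not part of the original tutorial), the logit-valued CrossEntropyLoss applies softmax internally, so it can be fed the raw network outputs directly:

```julia
using Lux, OneHotArrays

loss_fn = CrossEntropyLoss(; logits=Val(true))

ŷ = Float32[2; 0; -1;;]      # raw logits for a single 3-class sample (3×1 matrix)
y = onehotbatch([1], 1:3)    # the true class is the first one

# For a single sample this is -log(softmax(ŷ)[1])
l = loss_fn(ŷ, y)            # ≈ 0.17
```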
function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(Array(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy (generic function with 1 method)
Define the Training Loop
function train(model; rng=Xoshiro(0), kwargs...)
train_dataloader, test_dataloader = loadmnist(128, 0.9)
ps, st = Lux.setup(rng, model)
train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))
### Warmup the model
x_proto = randn(rng, Float32, 28, 28, 1, 1)
y_proto = onehotbatch([1], 0:9)
Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state)
### Lets train the model
nepochs = 10
tr_acc, te_acc = 0.0, 0.0
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
gs, _, _, train_state = Training.single_train_step!(
AutoZygote(), loss, (x, y), train_state)
end
ttime = time() - stime
tr_acc = accuracy(
model, train_state.parameters, train_state.states, train_dataloader) * 100
te_acc = accuracy(
model, train_state.parameters, train_state.states, test_dataloader) * 100
@printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \
%.2f%%\n" epoch nepochs ttime tr_acc te_acc
end
return tr_acc, te_acc
end
train (generic function with 1 method)
Finally Training the Model
First we will train the Lux model
tr_acc, te_acc = train(lux_model)
[ 1/10] Time 99.34s Training Accuracy: 23.11% Test Accuracy: 19.50%
[ 2/10] Time 110.42s Training Accuracy: 47.56% Test Accuracy: 43.50%
[ 3/10] Time 116.23s Training Accuracy: 61.11% Test Accuracy: 60.00%
[ 4/10] Time 113.21s Training Accuracy: 69.83% Test Accuracy: 66.00%
[ 5/10] Time 108.16s Training Accuracy: 75.33% Test Accuracy: 71.00%
[ 6/10] Time 97.78s Training Accuracy: 78.56% Test Accuracy: 77.50%
[ 7/10] Time 100.73s Training Accuracy: 81.28% Test Accuracy: 78.50%
[ 8/10] Time 101.86s Training Accuracy: 83.39% Test Accuracy: 79.50%
[ 9/10] Time 95.63s Training Accuracy: 84.67% Test Accuracy: 82.00%
[10/10] Time 104.95s Training Accuracy: 86.72% Test Accuracy: 83.50%
Now we will train the SimpleChains model
train(simple_chains_model)
[ 1/10] Time 18.80s Training Accuracy: 43.67% Test Accuracy: 39.00%
[ 2/10] Time 17.49s Training Accuracy: 52.39% Test Accuracy: 46.00%
[ 3/10] Time 17.49s Training Accuracy: 61.72% Test Accuracy: 57.00%
[ 4/10] Time 17.49s Training Accuracy: 68.56% Test Accuracy: 64.00%
[ 5/10] Time 17.49s Training Accuracy: 74.89% Test Accuracy: 72.00%
[ 6/10] Time 17.49s Training Accuracy: 78.67% Test Accuracy: 75.50%
[ 7/10] Time 17.50s Training Accuracy: 79.89% Test Accuracy: 76.50%
[ 8/10] Time 17.49s Training Accuracy: 83.61% Test Accuracy: 83.00%
[ 9/10] Time 17.49s Training Accuracy: 84.44% Test Accuracy: 83.50%
[10/10] Time 17.49s Training Accuracy: 87.50% Test Accuracy: 85.50%
On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server this documentation is built on is not ideal for CPU benchmarking, so the speedup may be less significant, and there may even be regressions.
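To close the loop, here is a hedged sketch (not from the original tutorial) of running a single forward pass through the converted model. The weights are freshly initialized, so the prediction is meaningless, but it shows the inference API; the names smodel and predicted_digit are introduced here for illustration:

```julia
using Lux, SimpleChains, OneHotArrays, Random

# Same architecture as above, repeated for self-containment
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
smodel = ToSimpleChainsAdaptor((28, 28, 1))(lux_model)

rng = Xoshiro(0)
ps, st = Lux.setup(rng, smodel)

x = randn(rng, Float32, 28, 28, 1, 1)        # a dummy 28×28 "image", batch size 1
logits, _ = smodel(x, ps, st)                # SimpleChainsLayer follows the Lux call API
predicted_digit = onecold(vec(logits)) - 1   # map the 1-based argmax back to 0:9
```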
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.