SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial, we will demonstrate how to use the same API as Lux.jl to train a model with SimpleChains.jl as the backend. We will use the tutorial from SimpleChains.jl as a reference.
using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics
import MLDatasets: MNIST
import SimpleChains: static
function loadmnist(batchsize, train_split)
# Load MNIST
N = 2000
dataset = MNIST(; split=:train)
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
# Process images into (H,W,C,BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
# Don't shuffle the test data
DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
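As a quick sanity check (a hypothetical snippet, not part of the original tutorial), we can pull a single batch from the DataLoader and confirm the shapes: with batchsize = 128, the images come out as (28, 28, 1, 128) and the one-hot labels as (10, 128).
sanity_loader, _ = loadmnist(128, 0.9)
x_batch, y_batch = first(sanity_loader)
size(x_batch)  # (28, 28, 1, 128)
size(y_batch)  # (10, 128)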
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
Chain(
layer_1 = Conv((5, 5), 1 => 6, relu), # 156 parameters
layer_2 = MaxPool((2, 2)),
layer_3 = Conv((5, 5), 6 => 16, relu), # 2_416 parameters
layer_4 = MaxPool((2, 2)),
layer_5 = FlattenLayer(),
layer_6 = Dense(256 => 128, relu), # 32_896 parameters
layer_7 = Dense(128 => 84, relu), # 10_836 parameters
layer_8 = Dense(84 => 10), # 850 parameters
) # Total: 47_154 parameters,
# plus 0 states.
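The 256 in Dense(256 => 128) comes from tracing shapes through the convolutional stack: a 28×28 input shrinks to 24×24 after the first 5×5 convolution, 12×12 after pooling, 8×8 after the second convolution, and 4×4 after the second pooling, so flattening the 4×4×16 feature maps yields 256 features. A hypothetical snippet verifying this by running just the convolutional part on a dummy input:
feature_extractor = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3))
ps_f, st_f = Lux.setup(Xoshiro(0), feature_extractor)
size(first(feature_extractor(randn(Float32, 28, 28, 1, 1), ps_f, st_f)))  # (256, 1)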
We now need to convert the lux_model to a SimpleChains.jl model. We do this by constructing a ToSimpleChainsAdaptor with the (static) input dimensions and then applying it to the Lux model.
adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)))
simple_chains_model = adaptor(lux_model)
SimpleChainsLayer() # 47_154 parameters
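The adapted model is still an ordinary Lux layer, so it can be set up and called exactly like lux_model. A minimal sketch, reusing the x_batch from the earlier sanity-check snippet:
ps_sc, st_sc = Lux.setup(Xoshiro(0), simple_chains_model)
y_pred, _ = simple_chains_model(x_batch, ps_sc, st_sc)
size(y_pred)  # (10, 128)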
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))
function loss(x, y, model, ps, st)
y_pred, st = model(x, ps, st)
return logitcrossentropy(y_pred, y), st
end
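Before wiring the loss into training, it helps to sanity-check logitcrossentropy on a trivial input (a hypothetical example): with all-zero logits the predicted distribution is uniform over the 10 classes, so the loss should equal log(10) ≈ 2.3026.
y_dummy = onehotbatch([0, 1, 2], 0:9)       # 3 one-hot labels
logits_dummy = zeros(Float32, 10, 3)        # all-zero logits => uniform predictions
logitcrossentropy(logits_dummy, y_dummy)    # ≈ 2.3026 == log(10)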
function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(Array(first(model(x, ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy (generic function with 1 method)
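As a quick usage example (hypothetical, reusing sanity_loader from the earlier snippet), a randomly initialized model should score close to chance level, i.e. roughly 10% on this 10-class problem:
ps0, st0 = Lux.setup(Xoshiro(0), lux_model)
accuracy(lux_model, ps0, st0, sanity_loader)  # ≈ 0.1 for an untrained model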
function train(model; rng=Xoshiro(0), kwargs...)
ps, st = Lux.setup(rng, model)
train_dataloader, test_dataloader = loadmnist(128, 0.9)
opt = Adam(3.0f-4)
st_opt = Optimisers.setup(opt, ps)
### Warm up the model (run the forward and backward pass once to trigger compilation)
img = train_dataloader.data[1][:, :, :, 1:1]
lab = train_dataloader.data[2][:, 1:1]
loss(img, lab, model, ps, st)
(l, _), back = pullback(p -> loss(img, lab, model, p, st), ps)
back((one(l), nothing))
### Let's train the model
nepochs = 9
for epoch in 1:nepochs
stime = time()
for (x, y) in train_dataloader
(l, st), back = pullback(loss, x, y, model, ps, st)
### The cotangent passed to `back` needs one entry per returned value: `one(l)` for the loss and `nothing` for the state
gs = back((one(l), nothing))[4]
st_opt, ps = Optimisers.update(st_opt, ps, gs)
end
ttime = time() - stime
println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " *
"$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " *
"Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%")
end
end
train (generic function with 1 method)
First we will train the Lux model
train(lux_model)
[1/9] Time 58.99s Training Accuracy: 24.11% Test Accuracy: 24.0%
[2/9] Time 44.99s Training Accuracy: 46.89% Test Accuracy: 47.5%
[3/9] Time 46.05s Training Accuracy: 68.06% Test Accuracy: 67.5%
[4/9] Time 49.25s Training Accuracy: 74.33% Test Accuracy: 72.5%
[5/9] Time 49.25s Training Accuracy: 80.61% Test Accuracy: 79.0%
[6/9] Time 48.62s Training Accuracy: 82.83% Test Accuracy: 82.5%
[7/9] Time 54.08s Training Accuracy: 84.72% Test Accuracy: 83.0%
[8/9] Time 44.46s Training Accuracy: 85.61% Test Accuracy: 84.0%
[9/9] Time 44.06s Training Accuracy: 85.83% Test Accuracy: 84.5%
Now we will train the SimpleChains model
train(simple_chains_model)
[1/9] Time 16.52s Training Accuracy: 45.61% Test Accuracy: 41.0%
[2/9] Time 15.78s Training Accuracy: 62.28% Test Accuracy: 57.5%
[3/9] Time 15.79s Training Accuracy: 73.28% Test Accuracy: 73.5%
[4/9] Time 15.77s Training Accuracy: 79.83% Test Accuracy: 78.5%
[5/9] Time 15.79s Training Accuracy: 82.94% Test Accuracy: 82.5%
[6/9] Time 15.9s Training Accuracy: 83.61% Test Accuracy: 84.5%
[7/9] Time 15.86s Training Accuracy: 85.67% Test Accuracy: 85.5%
[8/9] Time 15.82s Training Accuracy: 86.44% Test Accuracy: 86.0%
[9/9] Time 15.91s Training Accuracy: 87.67% Test Accuracy: 87.5%
On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server on which this documentation is built is not ideal for CPU benchmarking, so the speedup may be less pronounced, and regressions are even possible.
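To reproduce the comparison locally, one option (a sketch; BenchmarkTools is not loaded by this tutorial) is to time a single forward pass of each model on the x_batch from the earlier snippet:
using BenchmarkTools
ps_lux, st_lux = Lux.setup(Xoshiro(0), lux_model)
ps_simple, st_simple = Lux.setup(Xoshiro(0), simple_chains_model)
@btime first($lux_model($x_batch, $ps_lux, $st_lux))
@btime first($simple_chains_model($x_batch, $ps_simple, $st_simple))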
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-4/julialang/lux-dot-jl/docs/Project.toml
JULIA_AMDGPU_LOGGING_ENABLED = true
JULIA_DEBUG = Literate
JULIA_CPU_THREADS = 2
JULIA_NUM_THREADS = 48
JULIA_LOAD_PATH = @:@v#.#:@stdlib
JULIA_CUDA_HARD_MEMORY_LIMIT = 25%
This page was generated using Literate.jl.