> **Not Run on CI:** This tutorial is not run on CI to reduce the computational burden. If you encounter any issues, please open an issue on the [Lux.jl repository](https://github.com/LuxDL/Lux.jl).
# ConvMixer on CIFAR-10

This tutorial trains a ConvMixer (Trockman and Kolter, "Patches Are All You Need?", 2022) on CIFAR-10 using Lux.jl with the Reactant.jl compiler.

## Package Imports
```julia
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics
```
Set some global XLA flags to improve performance. The `--xla_gpu_enable_cublaslt` flag lets XLA dispatch matrix multiplications through NVIDIA's cuBLASLt library, which is typically faster on recent GPUs:
```julia
XLA_FLAGS = get(ENV, "XLA_FLAGS", "")
ENV["XLA_FLAGS"] = "$(XLA_FLAGS) --xla_gpu_enable_cublaslt=true"
```
## Load Common Packages
```julia
using ConcreteStructs,
    DataAugmentation,
    ImageShow,
    Lux,
    MLDatasets,
    MLUtils,
    OneHotArrays,
    Printf,
    ProgressTables,
    Random,
    BFloat16s
using Reactant
```
## Data Loading Functionality

```julia
# Lazily wraps a dataset together with the augmentation pipeline applied on access
@concrete struct TensorDataset
    dataset
    transform
end

Base.length(ds::TensorDataset) = length(ds.dataset)

function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer},AbstractRange})
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    y = onehotbatch(ds.dataset.targets[idxs], 0:9)
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
    cifar10_mean = T.((0.4914, 0.4822, 0.4465))
    cifar10_std = T.((0.2471, 0.2435, 0.2616))

    # Random crops and horizontal flips are applied only at training time
    train_transform =
        RandomResizeCrop((32, 32)) |>
        Maybe(FlipX{2}()) |>
        ImageToTensor() |>
        Normalize(cifar10_mean, cifar10_std) |>
        ToEltype(T)

    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)

    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

    return trainloader, testloader
end
```
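As a quick sanity check, the dataloaders can be constructed directly and a batch inspected. The shapes in the comments assume `batchsize=512`; note that the first call downloads CIFAR-10 via MLDatasets.jl.

```julia
trainloader, testloader = get_cifar10_dataloaders(Float32, 512)

x, y = first(trainloader)
size(x)  # (32, 32, 3, 512): WHCN layout expected by Lux convolution layers
size(y)  # (10, 512): one-hot labels over the classes 0:9
```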
## Utility Functions

```julia
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    cdev = cpu_device()
    for (x, y) in dataloader
        # Move predictions and targets back to the CPU before decoding the one-hot labels
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
```
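The function works with any Lux model that returns `(logits, state)`. A minimal CPU sketch with a hypothetical toy classifier on random data (so the result sits near chance level):

```julia
using Lux, MLUtils, OneHotArrays, Random

rng = Random.default_rng()
toy_model = Chain(FlattenLayer(), Dense(32 * 32 * 3 => 10))
ps, st = Lux.setup(rng, toy_model)

x = rand(rng, Float32, 32, 32, 3, 16)     # random stand-in for CIFAR-10 images
y = onehotbatch(rand(rng, 0:9, 16), 0:9)  # random one-hot labels
loader = DataLoader((x, y); batchsize=8)

accuracy(toy_model, ps, Lux.testmode(st), loader)  # roughly 0.1 for untrained weights
```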
## Training Loop

```julia
function train_model(
    model,
    opt,
    scheduler=nothing;
    batchsize::Int=512,
    seed::Int=1234,
    epochs::Int=25,
    bfloat16::Bool=false,
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    # bf16/f32 convert the parameters and states to the requested precision
    prec = bfloat16 ? bf16 : f32
    prec_jl = bfloat16 ? BFloat16 : Float32
    prec_str = bfloat16 ? "BFloat16" : "Float32"
    @printf "[Info] Using %s precision\n" prec_str

    dev = reactant_device(; force=true)

    trainloader, testloader =
        get_cifar10_dataloaders(prec_jl, batchsize; partial=false) |> dev

    ps, st = prec(Lux.setup(rng, model)) |> dev

    train_state = Training.TrainState(model, ps, st, opt)

    # Compile once with a dummy batch; the compiled model is reused for evaluation
    x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> dev

    @printf "[Info] Compiling model with Reactant.jl\n"
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model(x_ra, ps, Lux.testmode(st))
    end
    @printf "[Info] Model compiled!\n"

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    pt = ProgressTable(;
        header=[
            "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
        ],
        widths=[24, 24, 24, 24, 24],
        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
        color=[:normal, :normal, :blue, :blue, :normal],
        border=true,
        alignment=[:center, :center, :center, :center, :center],
    )

    @printf "[Info] Training model\n"
    initialize(pt)

    for epoch in 1:epochs
        stime = time()
        lr = 0
        for (i, (x, y)) in enumerate(trainloader)
            if scheduler !== nothing
                # The fractional epoch count drives the learning-rate schedule
                lr = scheduler((epoch - 1) + (i + 1) / length(trainloader))
                train_state = Optimisers.adjust!(train_state, lr)
            end
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), loss_fn, (x, y), train_state; return_gradients=Val(false)
            )
            isnan(loss) && error("NaN loss encountered!")
        end
        ttime = time() - stime

        train_acc =
            accuracy(
                model_compiled,
                train_state.parameters,
                Lux.testmode(train_state.states),
                trainloader,
            ) * 100
        test_acc =
            accuracy(
                model_compiled,
                train_state.parameters,
                Lux.testmode(train_state.states),
                testloader,
            ) * 100

        scheduler === nothing && (lr = NaN32)

        next(pt, [epoch, lr, train_acc, test_acc, ttime])
    end

    finalize(pt)

    return @printf "[Info] Finished training\n"
end
```
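The `scheduler` argument is any callable mapping a fractional epoch count to a learning rate. As a worked example, here is the piecewise-linear schedule constructed in the entry point below, evaluated at a few points (assuming the default `epochs=25` and `lr_max=0.05`):

```julia
using Interpolations

epochs, lr_max = 25, 0.05
sched = linear_interpolation(
    [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)

sched(0.0)   # 0.0: training starts from a zero learning rate
sched(10.0)  # 0.05: peak reached after 2/5 of the epochs
sched(20.0)  # 0.0025: decayed to lr_max / 20 after 4/5 of the epochs
sched(25.0)  # ≈0.0004: close to zero by the final epoch
```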
## Model Definition

```julia
function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
    return Chain(
        # Patch embedding: a strided convolution mapping 3 input channels to `dim`
        Conv((patch_size, patch_size), 3 => dim, relu; stride=patch_size),
        BatchNorm(dim),
        [
            Chain(
                SkipConnection(
                    Chain(
                        # Depthwise convolution (groups=dim) mixes spatial locations
                        Conv(
                            (kernel_size, kernel_size),
                            dim => dim,
                            relu;
                            groups=dim,
                            pad=SamePad(),
                        ),
                        BatchNorm(dim),
                    ),
                    +,
                ),
                # Pointwise 1×1 convolution mixes channels
                Conv((1, 1), dim => dim, relu),
                BatchNorm(dim),
            ) for _ in 1:depth
        ]...,
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(dim => 10),
    )
end
```
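The architecture can be inspected before training. A small sketch, assuming the default `dim=256`, `depth=8` configuration used in the entry point below:

```julia
using Lux, Random

model = ConvMixer(; dim=256, depth=8)
ps, st = Lux.setup(Random.default_rng(), model)

Lux.parameterlength(model)  # total trainable parameters, roughly 0.6 million here
```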
## Entry Point

```julia
Comonicon.@main function main(;
    batchsize::Int=512,
    hidden_dim::Int=256,
    depth::Int=8,
    patch_size::Int=2,
    kernel_size::Int=5,
    weight_decay::Float64=0.0001,
    clip_norm::Bool=false,
    seed::Int=1234,
    epochs::Int=25,
    lr_max::Float64=0.05,
    bfloat16::Bool=false,
)
    model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

    opt = AdamW(; eta=lr_max, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

    # Piecewise-linear schedule: warm up to lr_max, then decay in two stages
    lr_schedule = linear_interpolation(
        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
    )

    return train_model(model, opt, lr_schedule; batchsize, seed, epochs, bfloat16)
end
```
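Since the function is annotated with `Comonicon.@main`, running this file as a script exposes each keyword argument as a command-line flag. When the file is included in a REPL session instead, the generated `main` can be called directly; a short smoke-test run with smaller, assumed hyperparameters:

```julia
main(; batchsize=256, epochs=5)
```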
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*