Not Run on CI
This tutorial is not run on CI to reduce the computational burden. If you encounter any issues, please open an issue on the Lux.jl repository.
ImageNet Classification using Distributed Data Parallel Training
This tutorial implements training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset.
For distributed data-parallel training, we need to launch this script using mpiexecjl.
Set up MPI.jl first. If your system has a functional NCCL installation, it will be used for all CUDA communications; otherwise, MPI is used for all communications.
bash
mpiexecjl -np 4 julia --startup=no --project=examples/ImageNet -t auto \
examples/ImageNet/main.jl \
--model-name="ViT" \
--model-kind="tiny" \
--train-batchsize=256 \
--val-batchsize=256 \
--optimizer-kind="sgd" \
--learning-rate=0.01 \
--base-path="/home/avik-pal/data/ImageNet/"
For single-node training, we can simply launch the script directly with julia:
bash
julia --startup=no --project=examples/ImageNet -t auto examples/ImageNet/main.jl \
--model-name="ViT" \
--model-kind="tiny" \
--train-batchsize=256 \
--val-batchsize=256 \
--optimizer-kind="sgd" \
--learning-rate=0.01 \
--base-path="/home/avik-pal/data/ImageNet/"
Package Imports
julia
using Boltz, Lux, MLDataDevices
# import Metalhead # Install and load this package to use the Metalhead models with Lux
using Dates, Random
using DataAugmentation,
FileIO, MLUtils, OneHotArrays, Optimisers, ParameterSchedulers, Setfield
using Comonicon, Format
using JLD2
using Zygote
using LuxCUDA
# using AMDGPU # Install and load AMDGPU to train models on AMD GPUs with ROCm
using MPI: MPI
# Enables distributed training in Lux. NCCL is needed for CUDA GPUs
using NCCL: NCCL
const gdev = gpu_device()
const cdev = cpu_device()
Setup Distributed Training
We will use NCCL for NVIDIA GPUs and fall back to MPI for everything else.
julia
const distributed_backend = try
if gdev isa CUDADevice
DistributedUtils.initialize(NCCLBackend)
DistributedUtils.get_distributed_backend(NCCLBackend)
else
DistributedUtils.initialize(MPIBackend)
DistributedUtils.get_distributed_backend(MPIBackend)
end
catch err
@error "Could not initialize distributed training. Error: $err"
nothing
end
const local_rank =
distributed_backend === nothing ? 0 : DistributedUtils.local_rank(distributed_backend)
const total_workers = if distributed_backend === nothing
1
else
DistributedUtils.total_workers(distributed_backend)
end
const is_distributed = total_workers > 1
const should_log = !is_distributed || local_rank == 0
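For reference, here is what these constants evaluate to for the two launch modes shown at the top of the page (a sketch; the actual rank is assigned by MPI).
julia
# Single-process launch:
(local_rank, total_workers, is_distributed, should_log)  # == (0, 1, false, true)
# 4-process `mpiexecjl` launch: total_workers == 4, is_distributed == true,
# local_rank takes a distinct value in 0:3 on each rank, and should_log is true
# only on rank 0, so only one process prints logs and saves checkpoints.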
Data Loading for ImageNet
julia
# We need the data to be in a specific format. See the README.md in
# examples/ImageNet for more details; an illustrative layout is sketched below.
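# The expected layout is one sub-directory per synset under train/ and val/,
# each containing the JPEG images for that class. Illustrative paths (the
# synset and file names below are examples, not requirements of this script):
#
#   <base-path>/train/n01440764/n01440764_10026.JPEG
#   <base-path>/val/n01440764/ILSVRC2012_val_00000293.JPEG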
const IMAGENET_CORRUPTED_FILES = [
"n01739381_1309.JPEG",
"n02077923_14822.JPEG",
"n02447366_23489.JPEG",
"n02492035_15739.JPEG",
"n02747177_10752.JPEG",
"n03018349_4028.JPEG",
"n03062245_4620.JPEG",
"n03347037_9675.JPEG",
"n03467068_12171.JPEG",
"n03529860_11437.JPEG",
"n03544143_17228.JPEG",
"n03633091_5218.JPEG",
"n03710637_5125.JPEG",
"n03961711_5286.JPEG",
"n04033995_2932.JPEG",
"n04258138_17003.JPEG",
"n04264628_27969.JPEG",
"n04336792_7448.JPEG",
"n04371774_5854.JPEG",
"n04596742_4225.JPEG",
"n07583066_647.JPEG",
"n13037406_4650.JPEG",
"n02105855_2933.JPEG",
"ILSVRC2012_val_00019877.JPEG",
]
function load_imagenet1k(base_path::String, split::Symbol)
@assert split in (:train, :val)
full_path = joinpath(base_path, string(split))
synsets = sort(readdir(full_path))
@assert length(synsets) == 1000 "There should be 1000 subdirectories in $(full_path)."
image_files = String[]
labels = Int[]
for (i, synset) in enumerate(synsets)
filenames = readdir(joinpath(full_path, synset))
filter!(x -> x ∉ IMAGENET_CORRUPTED_FILES, filenames)
paths = joinpath.((full_path,), (synset,), filenames)
append!(image_files, paths)
append!(labels, repeat([i - 1], length(paths)))
end
return image_files, labels
end
default_image_size(::Type{Vision.VisionTransformer}, ::Nothing) = 256
default_image_size(::Type{Vision.VisionTransformer}, size::Int) = size
default_image_size(_, ::Nothing) = 224
default_image_size(_, size::Int) = size
struct MakeColoredImage <: DataAugmentation.Transform end
function DataAugmentation.apply(
::MakeColoredImage, item::DataAugmentation.AbstractArrayItem; randstate=nothing
)
data = itemdata(item)
(ndims(data) == 2 || size(data, 3) == 1) && (data = cat(data, data, data; dims=Val(3)))
return DataAugmentation.setdata(item, data)
end
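# For intuition: applied after ImageToTensor, this stacks a single-channel
# (grayscale) array into three identical channels so that every image reaches
# the model as an H × W × 3 tensor. A quick check could look like the following
# (assuming DataAugmentation.ArrayItem, which wraps a raw array):
#
#   item = DataAugmentation.ArrayItem(rand(Float32, 8, 8))
#   size(itemdata(DataAugmentation.apply(MakeColoredImage(), item)))  # (8, 8, 3)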
# A lazy, map-style dataset: `getindex` loads an image from disk, applies the
# augmentation pipeline, and returns the image tensor together with a one-hot
# encoded label. Implementing `length` and `getindex` makes it compatible with
# `MLUtils.DataLoader`.
struct FileDataset
files
labels
augment
end
Base.length(dataset::FileDataset) = length(dataset.files)
function Base.getindex(dataset::FileDataset, i::Int)
img = Image(FileIO.load(dataset.files[i]))
aug_img = itemdata(DataAugmentation.apply(dataset.augment, img))
return aug_img, OneHotArrays.onehot(dataset.labels[i], 0:999)
end
function construct_dataloaders(;
base_path::String, train_batchsize, val_batchsize, image_size::Int
)
sensible_println("=> creating dataloaders.")
train_augment =
ScaleFixed((256, 256)) |>
Maybe(FlipX(), 0.5) |>
RandomResizeCrop((image_size, image_size)) |>
PinOrigin() |>
ImageToTensor() |>
MakeColoredImage() |>
ToEltype(Float32) |>
Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0))
train_files, train_labels = load_imagenet1k(base_path, :train)
train_dataset = FileDataset(train_files, train_labels, train_augment)
val_augment =
ScaleFixed((image_size, image_size)) |>
PinOrigin() |>
ImageToTensor() |>
MakeColoredImage() |>
ToEltype(Float32) |>
Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0))
val_files, val_labels = load_imagenet1k(base_path, :val)
val_dataset = FileDataset(val_files, val_labels, val_augment)
if is_distributed
train_dataset = DistributedUtils.DistributedDataContainer(
distributed_backend, train_dataset
)
val_dataset = DistributedUtils.DistributedDataContainer(
distributed_backend, val_dataset
)
end
train_dataloader = DataLoader(
train_dataset;
batchsize=train_batchsize ÷ total_workers,
partial=false,
collate=true,
shuffle=true,
parallel=true,
)
val_dataloader = DataLoader(
val_dataset;
batchsize=val_batchsize ÷ total_workers,
partial=true,
collate=true,
shuffle=false,
parallel=true,
)
return gdev(train_dataloader), gdev(val_dataloader)
end
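Note that --train-batchsize and --val-batchsize are global batch sizes: each rank receives batchsize ÷ total_workers samples per step, so the command-line values should be divisible by the number of workers. For example (illustrative numbers):
julia
# --train-batchsize=256 split across 4 workers
256 ÷ 4  # == 64 images per rank per step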
Model Construction
julia
function construct_model(;
rng::AbstractRNG, model_name::String, model_args, pretrained::Bool=false
)
model = getproperty(Vision, Symbol(model_name))(model_args...; pretrained)
ps, st = gdev(Lux.setup(rng, model))
sensible_println("=> model `$(model_name)` created.")
pretrained && sensible_println("==> using pre-trained model")
sensible_println("==> number of trainable parameters: $(Lux.parameterlength(ps))")
sensible_println("==> number of states: $(Lux.statelength(st))")
if is_distributed
ps = DistributedUtils.synchronize!!(distributed_backend, ps)
st = DistributedUtils.synchronize!!(distributed_backend, st)
sensible_println("==> synced model parameters and states across all ranks")
end
return model, ps, st
end
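As a rough standalone sketch (not part of the training script), this is essentially what construct_model does for the ViT example from the command line at the top of this page, using the same Boltz call that main builds from the CLI flags:
julia
sketch_model = Vision.ViT(:tiny; pretrained=false)
sketch_ps, sketch_st = gdev(Lux.setup(Random.default_rng(), sketch_model))
Lux.parameterlength(sketch_ps), Lux.statelength(sketch_st)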
Optimizer Configuration
julia
function construct_optimizer_and_scheduler(;
kind::String,
learning_rate::AbstractFloat,
nesterov::Bool,
momentum::AbstractFloat,
weight_decay::AbstractFloat,
scheduler_kind::String,
cycle_length::Int,
damp_factor::AbstractFloat,
lr_step_decay::AbstractFloat,
lr_step::Vector{Int},
)
sensible_println("=> creating optimizer.")
kind = Symbol(kind)
optimizer = if kind == :adam
Adam(learning_rate)
elseif kind == :sgd
if nesterov
Nesterov(learning_rate, momentum)
elseif iszero(momentum)
Descent(learning_rate)
else
Momentum(learning_rate, momentum)
end
else
throw(ArgumentError("Unknown value for `optimizer` = $kind. Supported options are: \
`adam` and `sgd`."))
end
optimizer = if iszero(weight_decay)
optimizer
else
OptimiserChain(optimizer, WeightDecay(weight_decay))
end
sensible_println("=> creating scheduler.")
scheduler_kind = Symbol(scheduler_kind)
scheduler = if scheduler_kind == :cosine
l0 = learning_rate
l1 = learning_rate / 100
ComposedSchedule(
CosAnneal(l0, l1, cycle_length), Step(l0, damp_factor, cycle_length)
)
elseif scheduler_kind == :constant
Constant(learning_rate)
elseif scheduler_kind == :step
Step(learning_rate, lr_step_decay, lr_step)
else
throw(ArgumentError("Unknown value for `lr_scheduler` = $(scheduler_kind). \
Supported options are: `constant`, `step` and `cosine`."))
end
optimizer = if is_distributed
DistributedUtils.DistributedOptimizer(distributed_backend, optimizer)
else
optimizer
end
return optimizer, scheduler
end
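For illustration, with kind="sgd", a non-zero momentum and weight decay, and the step scheduler, the function above assembles something equivalent to the following (the numbers are made up; the constructors are the same ones used above):
julia
sketch_opt = OptimiserChain(Momentum(0.01f0, 0.9f0), WeightDecay(1.0f-4))
sketch_scheduler = Step(0.01f0, 0.1f0, [100000, 250000, 500000])
sketch_scheduler(1)  # learning rate that would be applied at the first step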
Utility Functions
julia
const logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
function loss_function(model, ps, st, (img, y))
ŷ, stₙ = model(img, ps, st)
return logitcrossentropy(ŷ, y), stₙ, (; prediction=ŷ)
end
sensible_println(msg) = should_log && println("[$(now())] ", msg)
sensible_print(msg) = should_log && print("[$(now())] ", msg)
function accuracy(ŷ::AbstractMatrix, y::AbstractMatrix, topk=(1,))
pred_labels = partialsortperm.(eachcol(cdev(ŷ)), Ref(1:maximum(topk)); rev=true)
true_labels = onecold(cdev(y))
accuracies = Vector{Float64}(undef, length(topk))
for (i, k) in enumerate(topk)
accuracies[i] = sum(
map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, true_labels)
)
end
accuracies .= accuracies .* 100 ./ size(y, 2)
return accuracies
end
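# Quick sanity check (illustrative): random logits scored against random one-hot
# labels should give roughly 0.1% top-1 and 0.5% top-5 accuracy in expectation.
let ŷ = randn(Float32, 1000, 16), y = onehotbatch(rand(0:999, 16), 0:999)
accuracy(ŷ, y, (1, 5))
end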
function save_checkpoint(state::NamedTuple; is_best::Bool, filename::String)
should_log || return nothing
@assert last(splitext(filename)) == ".jld2" "Filename should have a .jld2 extension."
isdir(dirname(filename)) || mkpath(dirname(filename))
jldsave(filename; state)
sensible_println("=> saved checkpoint `$(filename)`.")
if is_best
symlink_safe(filename, joinpath(dirname(filename), "model_best.jld2"))
sensible_println("=> best model updated to `$(filename)`!")
end
symlink_safe(filename, joinpath(dirname(filename), "model_current.jld2"))
return nothing
end
function symlink_safe(src, dest)
rm(dest; force=true)
symlink(src, dest)
return nothing
end
function load_checkpoint(filename::String)
try ## NOTE(@avik-pal): ispath is failing for symlinks?
return JLD2.load(filename, "state")
catch
sensible_println("$(filename) could not be loaded. This might be because the file \
is absent or is corrupt. Proceeding by returning `nothing`.")
return nothing
end
end
function full_gc_and_reclaim()
GC.gc(true)
MLDataDevices.functional(CUDADevice) && CUDA.reclaim()
MLDataDevices.functional(AMDGPUDevice) && AMDGPU.reclaim()
return nothing
end
@kwdef mutable struct AverageMeter
fmtstr
val::Float64 = 0.0
sum::Float64 = 0.0
count::Int = 0
average::Float64 = 0
end
function AverageMeter(name::String, fmt::String)
return AverageMeter(; fmtstr=FormatExpr("$(name) {1:$(fmt)} ({2:$(fmt)})"))
end
function (meter::AverageMeter)(val, n::Int)
meter.val = val
s = val * n
if is_distributed
v = [s, typeof(val)(n)]
DistributedUtils.allreduce!(distributed_backend, v, +)
s, n = v[1], Int(v[2])
end
meter.sum += s
meter.count += n
meter.average = meter.sum / meter.count
return meter.average
end
function reset_meter!(meter::AverageMeter)
meter.val = 0.0
meter.sum = 0.0
meter.count = 0
meter.average = 0.0
return meter
end
function print_meter(meter::AverageMeter)
return should_log && printfmt(meter.fmtstr, meter.val, meter.average)
end
struct ProgressMeter
batch_fmtstr
meters
end
function ProgressMeter(num_batches::Int, meters, prefix::String="")
fmt = "%" * string(length(string(num_batches))) * "d"
fmt2 = "{1:" * string(length(string(num_batches))) * "d}"
prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : ""
batch_fmtstr = FormatExpr("$prefix[$fmt2/" * cfmt(fmt, num_batches) * "]")
return ProgressMeter(batch_fmtstr, meters)
end
reset_meter!(meter::ProgressMeter) = foreach(reset_meter!, meter.meters)
function print_meter(meter::ProgressMeter, batch::Int)
should_log || return nothing
printfmt(meter.batch_fmtstr, batch)
foreach(meter.meters) do x
print("\t")
print_meter(x)
return nothing
end
println()
return nothing
end
get_loggable_values(meter::ProgressMeter) = getproperty.(meter.meters, :average)
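A small standalone illustration of the meters (the numbers are made up):
julia
loss_meter = AverageMeter("Loss", ".6f")
acc_meter = AverageMeter("Acc@1", "6.4f")
demo_progress = ProgressMeter(10, (loss_meter, acc_meter), "Demo:")
loss_meter(0.52, 32)  # record a batch of 32 samples with loss 0.52
acc_meter(71.875, 32)
print_meter(demo_progress, 1)  # prints something like: Demo: [ 1/10]  Loss 0.520000 (0.520000)  Acc@1 71.8750 (71.8750)
reset_meter!(demo_progress)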
Training and Validation Loops
julia
function validate(val_loader, model, ps, st, step, total_steps)
batch_time = AverageMeter("Batch Time", "6.5f")
data_time = AverageMeter("Data Time", "6.5f")
forward_time = AverageMeter("Forward Pass Time", "6.5f")
losses = AverageMeter("Loss", ".6f")
top1 = AverageMeter("Acc@1", "6.4f")
top5 = AverageMeter("Acc@5", "6.4f")
progress = ProgressMeter(
total_steps, (batch_time, data_time, forward_time, losses, top1, top5), "Val:"
)
st = Lux.testmode(st)
t = time()
for (img, y) in val_loader
t_data, t = time() - t, time()
bsize = size(img, ndims(img))
loss, st, stats = loss_function(model, ps, st, (img, y))
t_forward = time() - t
acc1, acc5 = accuracy(stats.prediction, y, (1, 5))
top1(acc1, bsize)
top5(acc5, bsize)
losses(loss, bsize)
data_time(t_data, bsize)
forward_time(t_forward, bsize)
batch_time(t_data + t_forward, bsize)
t = time()
end
print_meter(progress, step)
return top1.average
end
Entry Point
julia
Comonicon.@main function main(;
seed::Int=0,
model_name::String,
model_kind::String="nokind",
depth::Int=-1,
pretrained::Bool=false,
base_path::String="",
train_batchsize::Int=64,
val_batchsize::Int=64,
image_size::Int=-1,
optimizer_kind::String="sgd",
learning_rate::Float32=0.01f0,
nesterov::Bool=false,
momentum::Float32=0.0f0,
weight_decay::Float32=0.0f0,
scheduler_kind::String="step",
cycle_length::Int=50000,
damp_factor::Float32=1.2f0,
lr_step_decay::Float32=0.1f0,
lr_step::Vector{Int}=[100000, 250000, 500000],
expt_id::String="",
expt_subdir::String=@__DIR__,
resume::String="",
evaluate::Bool=false,
total_steps::Int=800000,
evaluate_every::Int=10000,
print_frequency::Int=100,
)
best_acc1 = 0
rng = Random.default_rng()
Random.seed!(rng, seed)
model_type = getproperty(Vision, Symbol(model_name))
image_size = default_image_size(model_type, image_size == -1 ? nothing : image_size)
depth = depth == -1 ? nothing : depth
model_kind = model_kind == "nokind" ? nothing : Symbol(model_kind)
model_args = if model_kind === nothing && depth === nothing
()
elseif model_kind !== nothing
(model_kind,)
else
(depth,)
end
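# For example, --model-name="ViT" --model-kind="tiny" yields Vision.ViT(:tiny),
# while --model-name="ResNet" --depth=50 yields Vision.ResNet(50) (the latter
# requires Metalhead to be installed; see the commented import above).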
model, ps, st = construct_model(; rng, model_name, model_args, pretrained)
ds_train, ds_val = construct_dataloaders(;
base_path, train_batchsize, val_batchsize, image_size
)
opt, scheduler = construct_optimizer_and_scheduler(;
kind=optimizer_kind,
learning_rate,
nesterov,
momentum,
weight_decay,
scheduler_kind,
cycle_length,
damp_factor,
lr_step_decay,
lr_step,
)
expt_name = "name-$(model_name)_seed-$(seed)_id-$(expt_id)"
ckpt_dir = joinpath(expt_subdir, "checkpoints", expt_name)
rpath = resume == "" ? joinpath(ckpt_dir, "model_current.jld2") : resume
ckpt = load_checkpoint(rpath)
if !isnothing(ckpt)
ps, st = gdev((ckpt.ps, ckpt.st))
initial_step = ckpt.step
sensible_println("=> training started from $(initial_step)")
else
initial_step = 1
end
validate(ds_val, model, ps, st, 0, total_steps)
evaluate && return nothing
full_gc_and_reclaim()
batch_time = AverageMeter("Batch Time", "6.5f")
data_time = AverageMeter("Data Time", "6.5f")
training_time = AverageMeter("Training Time", "6.5f")
losses = AverageMeter("Loss", ".6f")
top1 = AverageMeter("Acc@1", "6.4f")
top5 = AverageMeter("Acc@5", "6.4f")
progress = ProgressMeter(
total_steps, (batch_time, data_time, training_time, losses, top1, top5), "Train:"
)
st = Lux.trainmode(st)
train_state = Training.TrainState(model, ps, st, opt)
if is_distributed
@set! train_state.optimizer_state = DistributedUtils.synchronize!!(
distributed_backend, train_state.optimizer_state
)
end
train_loader = Iterators.cycle(ds_train)
_, train_loader_state = iterate(train_loader)
for step in initial_step:total_steps
t = time()
(img, y), train_loader_state = iterate(train_loader, train_loader_state)
t_data = time() - t
bsize = size(img, ndims(img))
t = time()
_, loss, stats, train_state = Training.single_train_step!(
AutoZygote(), loss_function, (img, y), train_state
)
t_training = time() - t
isnan(loss) && throw(ArgumentError("NaN loss encountered."))
acc1, acc5 = accuracy(stats.prediction, y, (1, 5))
top1(acc1, bsize)
top5(acc5, bsize)
losses(loss, bsize)
data_time(t_data, bsize)
training_time(t_training, bsize)
batch_time(t_data + t_training, bsize)
if step % print_frequency == 1 || step == total_steps
print_meter(progress, step)
reset_meter!(progress)
end
if step % evaluate_every == 0
acc1 = validate(ds_val, model, ps, st, step, total_steps)
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
save_state = (; ps=cdev(ps), st=cdev(st), step)
if should_log
save_checkpoint(
save_state; is_best, filename=joinpath(ckpt_dir, "model_$(step).jld2")
)
end
end
Optimisers.adjust!(train_state.optimizer_state, scheduler(step + 1))
end
return nothing
end
This page was generated using Literate.jl.