Not Run on CI

This tutorial is not run on CI to reduce the computational burden. If you encounter any issues, please open an issue on the Lux.jl repository.

ConvMixer on CIFAR-10

Package Imports

julia
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics

Set some global XLA flags to improve GPU performance (here, enabling cuBLASLt for matrix multiplications):

julia
XLA_FLAGS = get(ENV, "XLA_FLAGS", "")
ENV["XLA_FLAGS"] = "$(XLA_FLAGS) --xla_gpu_enable_cublaslt=true"

Load Common Packages

julia
using ConcreteStructs,
    DataAugmentation,
    ImageShow,
    Lux,
    MLDatasets,
    MLUtils,
    OneHotArrays,
    Printf,
    ProgressTables,
    Random,
    BFloat16s
using Reactant

Data Loading Functionality

julia
@concrete struct TensorDataset
    dataset
    transform
end

Base.length(ds::TensorDataset) = length(ds.dataset)

function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer},AbstractRange})
    # Convert the raw CIFAR-10 slices to images, then apply the augmentation
    # pipeline to each image and stack the results into a WHCN batch.
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    y = onehotbatch(ds.dataset.targets[idxs], 0:9)  # one-hot encode the labels
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
    cifar10_mean = T.((0.4914, 0.4822, 0.4465))
    cifar10_std = T.((0.2471, 0.2435, 0.2616))

    train_transform =
        RandomResizeCrop((32, 32)) |>
        Maybe(FlipX{2}()) |>
        ImageToTensor() |>
        Normalize(cifar10_mean, cifar10_std) |>
        ToEltype(T)

    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)

    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

    return trainloader, testloader
end
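
A quick sanity check of the dataloaders (illustrative, not part of the original script; it assumes MLDatasets has already downloaded CIFAR-10):

julia
trainloader, testloader = get_cifar10_dataloaders(Float32, 512)
x, y = first(trainloader)
size(x)  # (32, 32, 3, 512): WHCN layout expected by Lux's Conv layers
size(y)  # (10, 512): one-hot encoded labels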

Utility Functions

julia
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    cdev = cpu_device()
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
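
onecold inverts the one-hot encoding by taking the column-wise argmax, so targets and predictions are compared as integer class indices. A small illustration (not part of the training script):

julia
using OneHotArrays
y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(y)                       # [1, 4, 10]: column-wise argmax indices
onecold(y, 0:9)                  # [0, 3, 9]: mapped back to the original labels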

Training Loop

julia
function train_model(
    model,
    opt,
    scheduler=nothing;
    batchsize::Int=512,
    seed::Int=1234,
    epochs::Int=25,
    bfloat16::Bool=false,
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    prec = bfloat16 ? bf16 : f32
    prec_jl = bfloat16 ? BFloat16 : Float32
    prec_str = bfloat16 ? "BFloat16" : "Float32"
    @printf "[Info] Using %s precision\n" prec_str

    dev = reactant_device(; force=true)

    trainloader, testloader =
        get_cifar10_dataloaders(prec_jl, batchsize; partial=false) |> dev

    ps, st = prec(Lux.setup(rng, model)) |> dev

    train_state = Training.TrainState(model, ps, st, opt)

    x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> dev
    @printf "[Info] Compiling model with Reactant.jl\n"
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model(x_ra, ps, Lux.testmode(st))
    end
    @printf "[Info] Model compiled!\n"

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    pt = ProgressTable(;
        header=[
            "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
        ],
        widths=[24, 24, 24, 24, 24],
        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
        color=[:normal, :normal, :blue, :blue, :normal],
        border=true,
        alignment=[:center, :center, :center, :center, :center],
    )

    @printf "[Info] Training model\n"
    initialize(pt)

    for epoch in 1:epochs
        stime = time()
        lr = 0
        for (i, (x, y)) in enumerate(trainloader)
            if scheduler !== nothing
                lr = scheduler((epoch - 1) + (i + 1) / length(trainloader))
                train_state = Optimisers.adjust!(train_state, lr)
            end
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), loss_fn, (x, y), train_state; return_gradients=Val(false)
            )
            isnan(loss) && error("NaN loss encountered!")
        end
        ttime = time() - stime

        train_acc =
            accuracy(
                model_compiled,
                train_state.parameters,
                Lux.testmode(train_state.states),
                trainloader,
            ) * 100
        test_acc =
            accuracy(
                model_compiled,
                train_state.parameters,
                Lux.testmode(train_state.states),
                testloader,
            ) * 100

        scheduler === nothing && (lr = NaN32)
        next(pt, [epoch, lr, train_acc, test_acc, ttime])
    end

    finalize(pt)
    return @printf "[Info] Finished training\n"
end
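
The prec adapter used above (f32 or bf16) recursively converts the floating-point leaves of the parameter and state NamedTuples. A minimal sketch of what this does on a standalone layer (illustrative only):

julia
using Lux, Random, BFloat16s
ps, _ = Lux.setup(Random.default_rng(), Dense(2 => 3))  # Float32 by default
ps_bf16 = bf16(ps)      # convert floating-point arrays to BFloat16
eltype(ps_bf16.weight)  # BFloat16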

Model Definition

julia
function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
    return Chain(
        Conv((patch_size, patch_size), 3 => dim, relu; stride=patch_size),
        BatchNorm(dim),
        [
            Chain(
                SkipConnection(
                    Chain(
                        Conv(
                            (kernel_size, kernel_size),
                            dim => dim,
                            relu;
                            groups=dim,
                            pad=SamePad(),
                        ),
                        BatchNorm(dim),
                    ),
                    +,
                ),
                Conv((1, 1), dim => dim, relu),
                BatchNorm(dim),
            ) for _ in 1:depth
        ]...,
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(dim => 10),
    )
end
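
To get a feel for the architecture, one can instantiate it and count its parameters and states (illustrative; the exact numbers depend on dim and depth):

julia
model = ConvMixer(; dim=256, depth=8, kernel_size=5, patch_size=2)
Lux.parameterlength(model)  # total number of trainable parameters
Lux.statelength(model)      # non-trainable state, e.g. BatchNorm running statistics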

Entry Point

julia
Comonicon.@main function main(;
    batchsize::Int=512,
    hidden_dim::Int=256,
    depth::Int=8,
    patch_size::Int=2,
    kernel_size::Int=5,
    weight_decay::Float64=0.0001,
    clip_norm::Bool=false,
    seed::Int=1234,
    epochs::Int=25,
    lr_max::Float64=0.05,
    bfloat16::Bool=false,
)
    model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

    opt = AdamW(; eta=lr_max, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

    lr_schedule = linear_interpolation(
        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
    )

    return train_model(model, opt, lr_schedule; batchsize, seed, epochs, bfloat16)
end
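
The learning-rate schedule is a piecewise-linear warmup/decay ramp that the training loop evaluates at fractional epochs. A sketch of the values it produces for the defaults above (epochs=25, lr_max=0.05):

julia
using Interpolations
epochs, lr_max = 25, 0.05
lr_schedule = linear_interpolation(
    [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)
lr_schedule(0.0)   # 0.0: warmup starts from zero
lr_schedule(10.0)  # 0.05: peak learning rate at epoch 10
lr_schedule(20.0)  # 0.0025: decayed to lr_max / 20 at epoch 20

Since main is annotated with Comonicon.@main, running this file as a script exposes the keyword arguments as command-line flags (with underscores mapped to dashes, e.g. --hidden-dim); see the Comonicon documentation for details.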

This page was generated using Literate.jl.