Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

# Turn off scalar indexing for CUDA arrays for the rest of the script.
# NOTE(review): per the CUDA.jl docs this makes element-wise GPU indexing raise an error
# instead of running slowly — confirm against the installed CUDA.jl version.
CUDA.allowscalar(false)

Loading Datasets

julia
"""
    load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int)

Build a `(train_loader, test_loader)` pair of `DataLoader`s from the dataset type `dset`.

The first `n_train` samples of `dset(:train)` and the first `n_eval` samples of
`dset(:test)` are used. Images are reshaped to `28 × 28 × 1 × n` and labels are one-hot
encoded over the classes `0:9`. Only the training loader shuffles, and each loader's batch
size is capped at the number of samples in its split.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Take the first `n` samples of `split` and bring them into model-ready form.
    function take_split(split::Symbol, n::Int)
        features, targets = dset(split)[1:n]
        return reshape(features, 28, 28, 1, n), onehotbatch(targets, 0:9)
    end

    train_data = take_split(:train, n_train)
    eval_data = take_split(:test, n_eval)

    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)
    eval_loader = DataLoader(eval_data; batchsize=min(batchsize, n_eval), shuffle=false)
    return (train_loader, eval_loader)
end

"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Call `load_dataset` once for `MNIST` and once for `FashionMNIST` (in that order) with the
same sizes, returning a 2-tuple whose entries are the per-dataset results.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    datasets = (MNIST, FashionMNIST)
    return broadcast(load_dataset, datasets, n_train, n_eval, batchsize)
end
load_datasets (generic function with 4 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer in which `weight_generator` produces the parameters that
`core_network` is evaluated with.

The returned layer takes a tuple `(x, y)`: `x` is passed to `weight_generator`, its output
is flattened and wrapped in a `ComponentArray` using the axes of `core_network`'s
parameters, and `core_network` is then applied to `y` with those parameters.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Only the axes (the layout) of the core network's parameters are kept; the values
    # drawn here from the default RNG are discarded.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Flatten the generator output and give it the core network's parameter layout.
        # NOTE(review): assumes the output holds exactly one full parameter set — confirm.
        generated_ps = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, generated_ps)
    end
end
HyperNet (generic function with 1 method)

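Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that the core_network parameters are not initialized: they are produced by the weight_generator on every forward pass, so only the weight_generator needs trainable parameters.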

julia
# Parameter initialization for layers built with `dispatch=:HyperNet`.
# Only the `weight_generator` sub-layer is initialized; no entry is created for
# `core_network`, so the returned NamedTuple has the single key `weight_generator`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial: a `784 → 256 → 10` classifier whose
parameters are emitted by an embedding-based weight generator with two embedding entries.
"""
function create_model()
    # The core network is an MLP here, but it could be any Lux layer.
    classifier = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))

    # The generator's last layer must emit one value per classifier parameter.
    n_generated = Lux.parameterlength(classifier)
    generator = Chain(Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_generated))

    return HyperNet(generator, classifier)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
# Loss passed to every training step in `train`. `logits=Val(true)` configures the
# cross-entropy to take the model's raw outputs.
# NOTE(review): presumably the normalization then happens inside the loss — confirm.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

The model is called as `model((data_idx, x), ps, st)` with `st` switched to test mode. For
each batch, the predicted and target classes are obtained with `onecold` and compared
element-wise.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    eval_st = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (inputs, targets) in dataloader
        truth = onecold(targets)
        # The model returns its output first; anything after it is ignored here.
        output = first(model((data_idx, inputs), ps, eval_st))
        guess = onecold(output)
        n_correct += sum(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
"""
    train()

Train the hypernetwork on both datasets and return the final test accuracies.

Each of the 50 epochs makes one pass over the MNIST training loader (`data_idx == 1`) and
then one over the FashionMNIST training loader (`data_idx == 2`), printing the elapsed
time and the train/test accuracy after each pass. After training, the accuracies are
evaluated once more and printed with a `[FINAL]` tag.

Returns a 2-element vector of test accuracies in percent, rounded to two digits, ordered
`[MNIST, FashionMNIST]`.
"""
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Accuracy of the training state `ts` on `loader`, in percent rounded to two digits.
    # The state is passed explicitly because `train_state` is rebound on every step.
    function percent_accuracy(ts, loader, idx)
        fraction = accuracy(model, ts.parameters, ts.states, loader, idx)
        return round(fraction * 100; digits=2)
    end

    # Display name of the dataset stored at position `idx` of `dataloaders`.
    dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

    # Training: alternate between the two datasets within every epoch.
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_loader, test_loader = dataloaders[data_idx] .|> dev

        t_start = time()
        for (x, y) in train_loader
            # Only the updated training state is kept from the step's return value.
            _, _, _, train_state = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        elapsed = time() - t_start

        train_acc = percent_accuracy(train_state, train_loader, data_idx)
        test_acc = percent_accuracy(train_state, test_loader, data_idx)
        data_name = dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name elapsed train_acc test_acc
    end

    println()

    # Final evaluation with the trained parameters, one dataset at a time.
    final_test_acc = zeros(2)
    for data_idx in 1:2
        train_loader, test_loader = dataloaders[data_idx] .|> dev

        train_acc = percent_accuracy(train_state, train_loader, data_idx)
        test_acc = percent_accuracy(train_state, test_loader, data_idx)
        data_name = dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        final_test_acc[data_idx] = test_acc
    end
    return final_test_acc
end

# Run the full training loop; `test_acc_list` holds the final test accuracies in percent,
# ordered [MNIST, FashionMNIST].
test_acc_list = train()
[  1/ 50]	       MNIST	Time 78.46570s	Training Accuracy: 60.06%	Test Accuracy: 46.88%
[  1/ 50]	FashionMNIST	Time 0.03264s	Training Accuracy: 53.52%	Test Accuracy: 53.12%
[  2/ 50]	       MNIST	Time 0.02956s	Training Accuracy: 69.82%	Test Accuracy: 56.25%
[  2/ 50]	FashionMNIST	Time 0.03135s	Training Accuracy: 55.57%	Test Accuracy: 59.38%
[  3/ 50]	       MNIST	Time 0.03733s	Training Accuracy: 76.07%	Test Accuracy: 65.62%
[  3/ 50]	FashionMNIST	Time 0.03408s	Training Accuracy: 65.33%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.02706s	Training Accuracy: 76.95%	Test Accuracy: 65.62%
[  4/ 50]	FashionMNIST	Time 0.02389s	Training Accuracy: 68.07%	Test Accuracy: 65.62%
[  5/ 50]	       MNIST	Time 0.02540s	Training Accuracy: 82.52%	Test Accuracy: 68.75%
[  5/ 50]	FashionMNIST	Time 0.02518s	Training Accuracy: 67.48%	Test Accuracy: 62.50%
[  6/ 50]	       MNIST	Time 0.02163s	Training Accuracy: 87.60%	Test Accuracy: 75.00%
[  6/ 50]	FashionMNIST	Time 0.02164s	Training Accuracy: 74.32%	Test Accuracy: 65.62%
[  7/ 50]	       MNIST	Time 0.02190s	Training Accuracy: 90.14%	Test Accuracy: 75.00%
[  7/ 50]	FashionMNIST	Time 0.02069s	Training Accuracy: 75.78%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.02029s	Training Accuracy: 93.75%	Test Accuracy: 81.25%
[  8/ 50]	FashionMNIST	Time 0.02106s	Training Accuracy: 75.88%	Test Accuracy: 53.12%
[  9/ 50]	       MNIST	Time 0.02097s	Training Accuracy: 93.26%	Test Accuracy: 84.38%
[  9/ 50]	FashionMNIST	Time 0.04402s	Training Accuracy: 76.95%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.02088s	Training Accuracy: 95.61%	Test Accuracy: 81.25%
[ 10/ 50]	FashionMNIST	Time 0.02069s	Training Accuracy: 80.86%	Test Accuracy: 65.62%
[ 11/ 50]	       MNIST	Time 0.02230s	Training Accuracy: 97.36%	Test Accuracy: 81.25%
[ 11/ 50]	FashionMNIST	Time 0.02114s	Training Accuracy: 82.71%	Test Accuracy: 75.00%
[ 12/ 50]	       MNIST	Time 0.02115s	Training Accuracy: 98.54%	Test Accuracy: 84.38%
[ 12/ 50]	FashionMNIST	Time 0.02571s	Training Accuracy: 84.28%	Test Accuracy: 75.00%
[ 13/ 50]	       MNIST	Time 0.02090s	Training Accuracy: 98.73%	Test Accuracy: 84.38%
[ 13/ 50]	FashionMNIST	Time 0.02114s	Training Accuracy: 84.96%	Test Accuracy: 78.12%
[ 14/ 50]	       MNIST	Time 0.03184s	Training Accuracy: 99.22%	Test Accuracy: 84.38%
[ 14/ 50]	FashionMNIST	Time 0.02204s	Training Accuracy: 85.35%	Test Accuracy: 81.25%
[ 15/ 50]	       MNIST	Time 0.02328s	Training Accuracy: 99.51%	Test Accuracy: 84.38%
[ 15/ 50]	FashionMNIST	Time 0.02878s	Training Accuracy: 85.94%	Test Accuracy: 75.00%
[ 16/ 50]	       MNIST	Time 0.02129s	Training Accuracy: 99.71%	Test Accuracy: 84.38%
[ 16/ 50]	FashionMNIST	Time 0.02389s	Training Accuracy: 86.62%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.02113s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 17/ 50]	FashionMNIST	Time 0.02242s	Training Accuracy: 88.09%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.02085s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 18/ 50]	FashionMNIST	Time 0.02884s	Training Accuracy: 88.57%	Test Accuracy: 78.12%
[ 19/ 50]	       MNIST	Time 0.02230s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 19/ 50]	FashionMNIST	Time 0.02097s	Training Accuracy: 90.14%	Test Accuracy: 81.25%
[ 20/ 50]	       MNIST	Time 0.02478s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 20/ 50]	FashionMNIST	Time 0.02424s	Training Accuracy: 90.92%	Test Accuracy: 81.25%
[ 21/ 50]	       MNIST	Time 0.04484s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 21/ 50]	FashionMNIST	Time 0.01982s	Training Accuracy: 91.70%	Test Accuracy: 81.25%
[ 22/ 50]	       MNIST	Time 0.01972s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 22/ 50]	FashionMNIST	Time 0.01968s	Training Accuracy: 92.29%	Test Accuracy: 78.12%
[ 23/ 50]	       MNIST	Time 0.01994s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 23/ 50]	FashionMNIST	Time 0.02597s	Training Accuracy: 92.77%	Test Accuracy: 81.25%
[ 24/ 50]	       MNIST	Time 0.02466s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 24/ 50]	FashionMNIST	Time 0.03392s	Training Accuracy: 93.16%	Test Accuracy: 81.25%
[ 25/ 50]	       MNIST	Time 0.02114s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 25/ 50]	FashionMNIST	Time 0.02454s	Training Accuracy: 93.85%	Test Accuracy: 81.25%
[ 26/ 50]	       MNIST	Time 0.02240s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 26/ 50]	FashionMNIST	Time 0.03070s	Training Accuracy: 94.34%	Test Accuracy: 81.25%
[ 27/ 50]	       MNIST	Time 0.04950s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 27/ 50]	FashionMNIST	Time 0.03720s	Training Accuracy: 94.43%	Test Accuracy: 81.25%
[ 28/ 50]	       MNIST	Time 0.04596s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 28/ 50]	FashionMNIST	Time 0.02596s	Training Accuracy: 94.82%	Test Accuracy: 81.25%
[ 29/ 50]	       MNIST	Time 0.02988s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 29/ 50]	FashionMNIST	Time 0.02997s	Training Accuracy: 95.02%	Test Accuracy: 78.12%
[ 30/ 50]	       MNIST	Time 0.02919s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 30/ 50]	FashionMNIST	Time 0.03933s	Training Accuracy: 95.31%	Test Accuracy: 78.12%
[ 31/ 50]	       MNIST	Time 0.02382s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 31/ 50]	FashionMNIST	Time 0.02843s	Training Accuracy: 95.61%	Test Accuracy: 78.12%
[ 32/ 50]	       MNIST	Time 0.02070s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 32/ 50]	FashionMNIST	Time 0.02059s	Training Accuracy: 95.70%	Test Accuracy: 81.25%
[ 33/ 50]	       MNIST	Time 0.02155s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 33/ 50]	FashionMNIST	Time 0.02189s	Training Accuracy: 96.19%	Test Accuracy: 81.25%
[ 34/ 50]	       MNIST	Time 0.02188s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 34/ 50]	FashionMNIST	Time 0.02427s	Training Accuracy: 96.68%	Test Accuracy: 81.25%
[ 35/ 50]	       MNIST	Time 0.02018s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 35/ 50]	FashionMNIST	Time 0.02131s	Training Accuracy: 96.88%	Test Accuracy: 81.25%
[ 36/ 50]	       MNIST	Time 0.03078s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 36/ 50]	FashionMNIST	Time 0.02035s	Training Accuracy: 96.88%	Test Accuracy: 81.25%
[ 37/ 50]	       MNIST	Time 0.02239s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 37/ 50]	FashionMNIST	Time 0.02143s	Training Accuracy: 96.88%	Test Accuracy: 78.12%
[ 38/ 50]	       MNIST	Time 0.02121s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 38/ 50]	FashionMNIST	Time 0.02295s	Training Accuracy: 97.17%	Test Accuracy: 81.25%
[ 39/ 50]	       MNIST	Time 0.02132s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 39/ 50]	FashionMNIST	Time 0.02120s	Training Accuracy: 97.27%	Test Accuracy: 81.25%
[ 40/ 50]	       MNIST	Time 0.02339s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 40/ 50]	FashionMNIST	Time 0.02092s	Training Accuracy: 97.46%	Test Accuracy: 81.25%
[ 41/ 50]	       MNIST	Time 0.02237s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 41/ 50]	FashionMNIST	Time 0.02104s	Training Accuracy: 97.56%	Test Accuracy: 81.25%
[ 42/ 50]	       MNIST	Time 0.02215s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 42/ 50]	FashionMNIST	Time 0.02132s	Training Accuracy: 97.56%	Test Accuracy: 81.25%
[ 43/ 50]	       MNIST	Time 0.02194s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 43/ 50]	FashionMNIST	Time 0.02158s	Training Accuracy: 97.75%	Test Accuracy: 78.12%
[ 44/ 50]	       MNIST	Time 0.02189s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 44/ 50]	FashionMNIST	Time 0.02944s	Training Accuracy: 97.85%	Test Accuracy: 78.12%
[ 45/ 50]	       MNIST	Time 0.02178s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 45/ 50]	FashionMNIST	Time 0.02185s	Training Accuracy: 98.24%	Test Accuracy: 81.25%
[ 46/ 50]	       MNIST	Time 0.02149s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 46/ 50]	FashionMNIST	Time 0.02176s	Training Accuracy: 98.24%	Test Accuracy: 81.25%
[ 47/ 50]	       MNIST	Time 0.02047s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 47/ 50]	FashionMNIST	Time 0.02223s	Training Accuracy: 98.14%	Test Accuracy: 81.25%
[ 48/ 50]	       MNIST	Time 0.01997s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 48/ 50]	FashionMNIST	Time 0.02178s	Training Accuracy: 98.44%	Test Accuracy: 81.25%
[ 49/ 50]	       MNIST	Time 0.02526s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 49/ 50]	FashionMNIST	Time 0.02135s	Training Accuracy: 98.44%	Test Accuracy: 81.25%
[ 50/ 50]	       MNIST	Time 0.02100s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 50/ 50]	FashionMNIST	Time 0.02085s	Training Accuracy: 98.73%	Test Accuracy: 81.25%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[FINAL]	FashionMNIST	Training Accuracy: 98.73%	Test Accuracy: 81.25%

Appendix

julia
using InteractiveUtils

# Print the Julia version and platform details.
InteractiveUtils.versioninfo()

# Print GPU backend details, but only for backends that are loaded in this session and
# that MLDataDevices reports as functional. Each condition short-circuits, so a name is
# never referenced unless it is defined.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)

This page was generated using Literate.jl.