Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

# Turn off scalar indexing of GPU arrays for the rest of the tutorial.
CUDA.allowscalar(false)

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` tuple of `DataLoader`s from the dataset type `dset`.

The first `n_train` samples of the `:train` split and the first `n_eval` samples of the
`:test` split are used. Images are reshaped to `28 × 28 × 1 × N` and labels are one-hot
encoded against the classes `0:9`. The batch size of each loader is capped at the number
of samples it holds, and only the training loader shuffles.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Both splits go through the same steps; only the split name, the sample count and
    # the shuffle flag differ.
    function build_loader(split::Symbol, n_samples::Int, do_shuffle::Bool)
        imgs, labels = dset(split)[1:n_samples]
        x = reshape(imgs, 28, 28, 1, n_samples)
        y = onehotbatch(labels, 0:9)
        return DataLoader((x, y); batchsize=min(batchsize, n_samples), shuffle=do_shuffle)
    end

    train_loader = build_loader(:train, n_train, true)
    test_loader = build_loader(:test, n_eval, false)
    return (train_loader, test_loader)
end

"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return a tuple holding one `load_dataset` result per dataset, in the order
`(MNIST, FashionMNIST)`. All datasets share the same `n_train`, `n_eval` and `batchsize`.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return map((MNIST, FashionMNIST)) do dset
        load_dataset(dset, n_train, n_eval, batchsize)
    end
end
load_datasets (generic function with 4 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Combine two Lux layers into one `@compact` layer tagged with `dispatch=:HyperNet`.

The returned layer takes a tuple `(x, y)`. The output of `weight_generator(x)` is flattened
and laid out according to the `ComponentArray` axes of `core_network`'s parameters; the
result is used as the parameters with which `core_network` is applied to `y`.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Only the axes of the core network's parameters are kept from this initialisation;
    # the values themselves are discarded.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Produce a flat weight vector and give it the core network's parameter layout.
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply override Lux.initialparameters so that it skips the initialization of the core_network parameters and only initializes the weight_generator.

julia
# Parameter initialisation for layers built by `HyperNet`: only the `weight_generator`
# sub-layer gets an entry in the returned NamedTuple, so no parameters are created for
# `core_network`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network flattens its input and applies `Dense(784, 256, relu)` followed by
`Dense(256, 10)`. The weight generator is an `Embedding(2 => 32)` followed by two `Dense`
layers whose final output size equals `Lux.parameterlength(core_network)`.
"""
function create_model()
    # The core network is an MLP here, but it could be any Lux layer.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))

    # The generator emits one value per parameter of the core network.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
# Objective passed to `Training.single_train_step!` inside `train`: a cross-entropy loss
# constructed with `logits=Val(true)`.
# NOTE(review): presumably this means the model output is treated as raw (unnormalised)
# scores, which would match the core network ending in a plain `Dense(256, 10)` — confirm
# against the Lux loss-function documentation.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

The states `st` are switched to test mode first. For every batch `(x, y)` the model is
called on `(data_idx, x)`, and the `onecold` of its output is compared with the `onecold`
of `y`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        # The model returns an (output, state) pair; only the output is needed here.
        y_pred, _ = model((data_idx, x), ps, eval_st)
        truth = onecold(y)
        guess = onecold(y_pred)
        n_correct += sum(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
# Accuracy of `model`, evaluated with the parameters and states currently held by
# `train_state`, on `dataloader`; returned as a percentage rounded to two decimals.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train(; nepochs=50)

Train the model from `create_model()` on the datasets returned by `load_datasets()`.

Within every epoch the datasets are visited in order, and the dataset index is passed to
the model together with each image batch. After each pass over a dataset the elapsed
training time and the training/test accuracy are printed; a `[FINAL]` line per dataset is
printed once all epochs are done.

# Keywords
- `nepochs`: number of epochs to run (default `50`).

# Returns
A `Vector{Float64}` holding the final test accuracy (in percent) of each dataset, in the
same order as the datasets.
"""
function train(; nepochs::Integer=50)
    model = create_model()
    dataloaders = load_datasets()
    # Labels for the report lines; the order mirrors the datasets in `load_datasets`.
    data_names = ("MNIST", "FashionMNIST")

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    # Piping the loaders through the device does not depend on the epoch, so it is done
    # once here instead of on every (epoch, dataset) iteration.
    device_loaders = map(loaders -> loaders .|> dev, dataloaders)

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in eachindex(device_loaders)
        train_dataloader, test_dataloader = device_loaders[data_idx]

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)

        data_name = data_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final report: re-evaluate both splits of every dataset with the trained state.
    test_acc_list = zeros(Float64, length(device_loaders))
    for data_idx in eachindex(device_loaders)
        train_dataloader, test_dataloader = device_loaders[data_idx]
        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)

        data_name = data_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the whole training loop. `train` returns the final test accuracy (in percent) of
# each dataset; the per-epoch progress it prints follows below.
test_acc_list = train()
[  1/ 50]	       MNIST	Time 71.03980s	Training Accuracy: 59.77%	Test Accuracy: 53.12%
[  1/ 50]	FashionMNIST	Time 0.02848s	Training Accuracy: 49.22%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.02817s	Training Accuracy: 70.12%	Test Accuracy: 68.75%
[  2/ 50]	FashionMNIST	Time 0.02862s	Training Accuracy: 60.45%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.02804s	Training Accuracy: 76.95%	Test Accuracy: 65.62%
[  3/ 50]	FashionMNIST	Time 0.02236s	Training Accuracy: 63.38%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.02292s	Training Accuracy: 79.20%	Test Accuracy: 65.62%
[  4/ 50]	FashionMNIST	Time 0.02286s	Training Accuracy: 59.18%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.02302s	Training Accuracy: 82.42%	Test Accuracy: 68.75%
[  5/ 50]	FashionMNIST	Time 0.04113s	Training Accuracy: 73.14%	Test Accuracy: 62.50%
[  6/ 50]	       MNIST	Time 0.02088s	Training Accuracy: 88.09%	Test Accuracy: 68.75%
[  6/ 50]	FashionMNIST	Time 0.02097s	Training Accuracy: 72.17%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.02118s	Training Accuracy: 91.41%	Test Accuracy: 75.00%
[  7/ 50]	FashionMNIST	Time 0.02044s	Training Accuracy: 75.59%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.02076s	Training Accuracy: 92.19%	Test Accuracy: 81.25%
[  8/ 50]	FashionMNIST	Time 0.02037s	Training Accuracy: 78.32%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.02041s	Training Accuracy: 93.75%	Test Accuracy: 81.25%
[  9/ 50]	FashionMNIST	Time 0.02058s	Training Accuracy: 77.64%	Test Accuracy: 68.75%
[ 10/ 50]	       MNIST	Time 0.03276s	Training Accuracy: 96.29%	Test Accuracy: 84.38%
[ 10/ 50]	FashionMNIST	Time 0.02028s	Training Accuracy: 77.34%	Test Accuracy: 68.75%
[ 11/ 50]	       MNIST	Time 0.02036s	Training Accuracy: 97.27%	Test Accuracy: 84.38%
[ 11/ 50]	FashionMNIST	Time 0.02073s	Training Accuracy: 79.39%	Test Accuracy: 68.75%
[ 12/ 50]	       MNIST	Time 0.02104s	Training Accuracy: 97.56%	Test Accuracy: 84.38%
[ 12/ 50]	FashionMNIST	Time 0.02038s	Training Accuracy: 79.69%	Test Accuracy: 71.88%
[ 13/ 50]	       MNIST	Time 0.02077s	Training Accuracy: 98.73%	Test Accuracy: 84.38%
[ 13/ 50]	FashionMNIST	Time 0.02030s	Training Accuracy: 82.71%	Test Accuracy: 75.00%
[ 14/ 50]	       MNIST	Time 0.02066s	Training Accuracy: 99.41%	Test Accuracy: 84.38%
[ 14/ 50]	FashionMNIST	Time 0.02658s	Training Accuracy: 82.91%	Test Accuracy: 71.88%
[ 15/ 50]	       MNIST	Time 0.02095s	Training Accuracy: 99.51%	Test Accuracy: 84.38%
[ 15/ 50]	FashionMNIST	Time 0.02054s	Training Accuracy: 82.91%	Test Accuracy: 75.00%
[ 16/ 50]	       MNIST	Time 0.02063s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 16/ 50]	FashionMNIST	Time 0.02108s	Training Accuracy: 83.98%	Test Accuracy: 62.50%
[ 17/ 50]	       MNIST	Time 0.02091s	Training Accuracy: 99.71%	Test Accuracy: 84.38%
[ 17/ 50]	FashionMNIST	Time 0.02103s	Training Accuracy: 86.43%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.02066s	Training Accuracy: 99.90%	Test Accuracy: 84.38%
[ 18/ 50]	FashionMNIST	Time 0.02035s	Training Accuracy: 86.62%	Test Accuracy: 65.62%
[ 19/ 50]	       MNIST	Time 0.02030s	Training Accuracy: 99.90%	Test Accuracy: 84.38%
[ 19/ 50]	FashionMNIST	Time 0.02081s	Training Accuracy: 88.67%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.02087s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 20/ 50]	FashionMNIST	Time 0.02054s	Training Accuracy: 89.36%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.02052s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 21/ 50]	FashionMNIST	Time 0.02101s	Training Accuracy: 89.65%	Test Accuracy: 84.38%
[ 22/ 50]	       MNIST	Time 0.02074s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 22/ 50]	FashionMNIST	Time 0.02047s	Training Accuracy: 90.43%	Test Accuracy: 78.12%
[ 23/ 50]	       MNIST	Time 0.02776s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 23/ 50]	FashionMNIST	Time 0.02046s	Training Accuracy: 90.72%	Test Accuracy: 78.12%
[ 24/ 50]	       MNIST	Time 0.02053s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 24/ 50]	FashionMNIST	Time 0.02026s	Training Accuracy: 91.31%	Test Accuracy: 78.12%
[ 25/ 50]	       MNIST	Time 0.02207s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 25/ 50]	FashionMNIST	Time 0.02159s	Training Accuracy: 91.89%	Test Accuracy: 84.38%
[ 26/ 50]	       MNIST	Time 0.02081s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 26/ 50]	FashionMNIST	Time 0.02090s	Training Accuracy: 91.80%	Test Accuracy: 81.25%
[ 27/ 50]	       MNIST	Time 0.02054s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 27/ 50]	FashionMNIST	Time 0.02756s	Training Accuracy: 92.58%	Test Accuracy: 81.25%
[ 28/ 50]	       MNIST	Time 0.02042s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 28/ 50]	FashionMNIST	Time 0.02061s	Training Accuracy: 92.77%	Test Accuracy: 75.00%
[ 29/ 50]	       MNIST	Time 0.02075s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 29/ 50]	FashionMNIST	Time 0.02078s	Training Accuracy: 92.38%	Test Accuracy: 75.00%
[ 30/ 50]	       MNIST	Time 0.02057s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 30/ 50]	FashionMNIST	Time 0.02072s	Training Accuracy: 92.77%	Test Accuracy: 81.25%
[ 31/ 50]	       MNIST	Time 0.02089s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 31/ 50]	FashionMNIST	Time 0.02029s	Training Accuracy: 93.26%	Test Accuracy: 84.38%
[ 32/ 50]	       MNIST	Time 0.02640s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 32/ 50]	FashionMNIST	Time 0.02032s	Training Accuracy: 94.04%	Test Accuracy: 81.25%
[ 33/ 50]	       MNIST	Time 0.02031s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 33/ 50]	FashionMNIST	Time 0.02044s	Training Accuracy: 94.53%	Test Accuracy: 81.25%
[ 34/ 50]	       MNIST	Time 0.02108s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 34/ 50]	FashionMNIST	Time 0.02044s	Training Accuracy: 94.82%	Test Accuracy: 81.25%
[ 35/ 50]	       MNIST	Time 0.02080s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 35/ 50]	FashionMNIST	Time 0.02075s	Training Accuracy: 95.31%	Test Accuracy: 78.12%
[ 36/ 50]	       MNIST	Time 0.02122s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 36/ 50]	FashionMNIST	Time 0.02006s	Training Accuracy: 95.12%	Test Accuracy: 81.25%
[ 37/ 50]	       MNIST	Time 0.02076s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 37/ 50]	FashionMNIST	Time 0.02073s	Training Accuracy: 95.80%	Test Accuracy: 84.38%
[ 38/ 50]	       MNIST	Time 0.02059s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 38/ 50]	FashionMNIST	Time 0.02081s	Training Accuracy: 95.12%	Test Accuracy: 81.25%
[ 39/ 50]	       MNIST	Time 0.02079s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 39/ 50]	FashionMNIST	Time 0.02320s	Training Accuracy: 95.70%	Test Accuracy: 84.38%
[ 40/ 50]	       MNIST	Time 0.02078s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 40/ 50]	FashionMNIST	Time 0.02696s	Training Accuracy: 96.09%	Test Accuracy: 84.38%
[ 41/ 50]	       MNIST	Time 0.02039s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 41/ 50]	FashionMNIST	Time 0.02076s	Training Accuracy: 96.00%	Test Accuracy: 84.38%
[ 42/ 50]	       MNIST	Time 0.02078s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 42/ 50]	FashionMNIST	Time 0.02069s	Training Accuracy: 96.19%	Test Accuracy: 81.25%
[ 43/ 50]	       MNIST	Time 0.02081s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 43/ 50]	FashionMNIST	Time 0.02025s	Training Accuracy: 96.09%	Test Accuracy: 84.38%
[ 44/ 50]	       MNIST	Time 0.02177s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 44/ 50]	FashionMNIST	Time 0.03603s	Training Accuracy: 96.68%	Test Accuracy: 84.38%
[ 45/ 50]	       MNIST	Time 0.02670s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 45/ 50]	FashionMNIST	Time 0.01978s	Training Accuracy: 97.27%	Test Accuracy: 84.38%
[ 46/ 50]	       MNIST	Time 0.01968s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 46/ 50]	FashionMNIST	Time 0.02102s	Training Accuracy: 97.46%	Test Accuracy: 84.38%
[ 47/ 50]	       MNIST	Time 0.01969s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 47/ 50]	FashionMNIST	Time 0.01991s	Training Accuracy: 97.27%	Test Accuracy: 84.38%
[ 48/ 50]	       MNIST	Time 0.01988s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 48/ 50]	FashionMNIST	Time 0.01967s	Training Accuracy: 97.27%	Test Accuracy: 84.38%
[ 49/ 50]	       MNIST	Time 0.01967s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 49/ 50]	FashionMNIST	Time 0.02651s	Training Accuracy: 97.56%	Test Accuracy: 84.38%
[ 50/ 50]	       MNIST	Time 0.01965s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 50/ 50]	FashionMNIST	Time 0.01965s	Training Accuracy: 97.46%	Test Accuracy: 84.38%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[FINAL]	FashionMNIST	Training Accuracy: 97.46%	Test Accuracy: 84.38%

Appendix

julia
using InteractiveUtils
# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# Backend version reports are optional: each one is printed only when `MLDataDevices` is
# defined in this scope, the backend package itself is defined, and
# `MLDataDevices.functional` returns true for the backend's device type.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()  # blank line separating this report from the previous output
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()  # blank line separating this report from the previous output
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)

This page was generated using Literate.jl.