Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Zygote

# Make scalar (element-by-element) indexing of GPU arrays an error rather than a silent,
# slow fallback, so accidental non-vectorized GPU code is caught early.
CUDA.allowscalar(false)
Precompiling LuxComponentArraysExt...
   1622.1 ms  ✓ Lux → LuxComponentArraysExt
  1 dependency successfully precompiled in 2 seconds. 113 already precompiled.
Precompiling LuxLibCUDAExt...
   5438.1 ms  ✓ LuxLib → LuxLibCUDAExt
  1 dependency successfully precompiled in 6 seconds. 171 already precompiled.
Precompiling LuxLibcuDNNExt...
   5541.7 ms  ✓ LuxLib → LuxLibcuDNNExt
  1 dependency successfully precompiled in 6 seconds. 176 already precompiled.
Precompiling LuxMLUtilsExt...
   2282.3 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 167 already precompiled.
Precompiling LuxZygoteExt...
   2900.0 ms  ✓ Lux → LuxZygoteExt
  1 dependency successfully precompiled in 3 seconds. 166 already precompiled.

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the dataset type `dset` (built as `dset(:train)` and
`dset(:test)`) and return a `(train_loader, test_loader)` pair of `DataLoader`s.

`n_train` / `n_eval` keep only the first `n` samples of the respective split; `nothing`
keeps the whole split. Images are reshaped to `28×28×1×N` and labels are one-hot encoded
over the classes `0:9`. The batch size is capped at the number of samples in each split,
and only the training loader is shuffled.
"""
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    # Materialize the first `n` samples of `split` (all of them when `n === nothing`)
    # as an image array plus a one-hot label matrix.
    function split_arrays(split::Symbol, n::Union{Nothing, Int})
        dataset = dset(split)
        nsamples = n === nothing ? length(dataset) : n
        subset = dataset[1:nsamples]
        images = reshape(subset.features, 28, 28, 1, :)
        labels = onehotbatch(subset.targets, 0:9)
        return images, labels
    end

    x_train, y_train = split_arrays(:train, n_train)
    x_test, y_test = split_arrays(:test, n_eval)

    # Never request a batch larger than the split itself.
    train_loader = DataLoader(
        (x_train, y_train); batchsize=min(batchsize, size(x_train, 4)), shuffle=true)
    test_loader = DataLoader(
        (x_test, y_test); batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize=256)

Return a tuple `(mnist_loaders, fashion_mnist_loaders)`, each a `(train, test)` pair from
`load_dataset`. When the `CI` environment variable parses as `true`, only the first 1024
training and 32 evaluation samples of each dataset are used.
"""
function load_datasets(batchsize=256)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    return map(dset -> load_dataset(dset, n_train, n_eval, batchsize),
        (MNIST, FashionMNIST))
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a `@compact` layer (dispatch tag `:HyperNet`) that is called with a tuple `(x, y)`.
`weight_generator(x)` is flattened with `vec` and reinterpreted, through the recorded
`ComponentArray` axes, as the parameters of `core_network`. The layer's output is
`core_network(y, ps_new)` evaluated with those generated parameters.

The generator's output presumably must contain exactly as many values as `core_network`
has parameters; no length check is done here (TODO confirm how `ComponentArray` reacts
to a mismatch).
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Initialize a throwaway set of core-network parameters only to record their
    # ComponentArray axes (layout); the values are discarded. Note that this draws
    # from the global default RNG.
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    # The keyword names become the layer's fields (e.g. `hn.layers.weight_generator`),
    # so they must not be renamed independently of code that accesses them.
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to skip initializing the `core_network` parameters, since those parameters are produced by the weight generator on every forward pass rather than trained directly.

julia
"""
    Lux.initialparameters(rng, hn::CompactLuxLayer{:HyperNet})

Initialize only the weight generator's parameters. The core network receives its
parameters from the generator on every forward pass, so it owns no trainable parameters.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (weight_generator=generator_ps,)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial. The core network is a small MLP classifier for
28×28 inputs with 10 outputs. The weight generator maps a dataset index (one of 2
embedding entries) to a flat vector with one value per core-network parameter.
"""
function create_model()
    # Any Lux layer could serve as the core network; an MLP keeps the example small.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))

    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
# Cross-entropy on raw model outputs (logits), used as the training objective.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class (`onecold` of the
model output for input `(data_idx, x)`) equals `onecold` of the one-hot target `y`.
The model is evaluated with `st` switched to test mode. An empty `dataloader` yields
`NaN` (0 / 0).
"""
function accuracy(model, ps, st, dataloader, data_idx)
    st_eval = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for batch in dataloader
        x, y = batch
        labels = onecold(y)
        ŷ, _ = model((data_idx, x), ps, st_eval)
        n_correct += sum(labels .== onecold(ŷ))
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
"""
    _dataset_name(data_idx)

Return the display name for dataset index `data_idx`: `"MNIST"` for 1, otherwise
`"FashionMNIST"`.
"""
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    _train_test_accuracy(model, train_state, train_dataloader, test_dataloader, data_idx)

Evaluate `accuracy` with the current parameters/states of `train_state` on both loaders.
Return `(train_acc, test_acc)` as percentages rounded to 2 decimal digits.
"""
function _train_test_accuracy(
        model, train_state, train_dataloader, test_dataloader, data_idx)
    percent(dataloader) = round(
        accuracy(model, train_state.parameters,
            train_state.states, dataloader, data_idx) * 100;
        digits=2)
    return percent(train_dataloader), percent(test_dataloader)
end

"""
    train(; nepochs::Int=50)

Train a HyperNet that alternates, within every epoch, between MNIST (index 1) and
FashionMNIST (index 2). Per-epoch timing and accuracies are printed, followed by a final
summary. Return a 2-element `Vector{Float64}` with the final test accuracies (in percent)
for MNIST and FashionMNIST, respectively.
"""
function train(; nepochs::Int=50)
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    # Wrap each (train, test) loader pair for the target device once. The wrapped loaders
    # can be iterated repeatedly, so there is no need to re-wrap them every epoch.
    device_dataloaders = map(loaders -> loaders .|> dev, dataloaders)

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = device_dataloaders[data_idx]

        stime = time()
        for (x, y) in train_dataloader
            # The dataset index conditions the weight generator; `x` feeds the core net.
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc, test_acc = _train_test_accuracy(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = device_dataloaders[data_idx]
        train_acc, test_acc = _train_test_accuracy(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training procedure; the result holds the final test accuracies (in percent,
# rounded to 2 digits) for MNIST and FashionMNIST, in that order.
test_acc_list = train()
[  1/ 50]	       MNIST	Time 94.45098s	Training Accuracy: 59.18%	Test Accuracy: 46.88%
[  1/ 50]	FashionMNIST	Time 0.03458s	Training Accuracy: 50.98%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.03528s	Training Accuracy: 67.38%	Test Accuracy: 65.62%
[  2/ 50]	FashionMNIST	Time 0.03218s	Training Accuracy: 55.08%	Test Accuracy: 37.50%
[  3/ 50]	       MNIST	Time 0.03225s	Training Accuracy: 73.24%	Test Accuracy: 62.50%
[  3/ 50]	FashionMNIST	Time 0.03322s	Training Accuracy: 67.77%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.03267s	Training Accuracy: 79.20%	Test Accuracy: 65.62%
[  4/ 50]	FashionMNIST	Time 0.03358s	Training Accuracy: 63.77%	Test Accuracy: 50.00%
[  5/ 50]	       MNIST	Time 0.06424s	Training Accuracy: 84.47%	Test Accuracy: 62.50%
[  5/ 50]	FashionMNIST	Time 0.03170s	Training Accuracy: 71.78%	Test Accuracy: 50.00%
[  6/ 50]	       MNIST	Time 0.02446s	Training Accuracy: 88.09%	Test Accuracy: 65.62%
[  6/ 50]	FashionMNIST	Time 0.02443s	Training Accuracy: 76.76%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.02429s	Training Accuracy: 89.16%	Test Accuracy: 75.00%
[  7/ 50]	FashionMNIST	Time 0.02518s	Training Accuracy: 78.42%	Test Accuracy: 65.62%
[  8/ 50]	       MNIST	Time 0.02348s	Training Accuracy: 91.99%	Test Accuracy: 81.25%
[  8/ 50]	FashionMNIST	Time 0.03206s	Training Accuracy: 79.79%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.02287s	Training Accuracy: 95.02%	Test Accuracy: 81.25%
[  9/ 50]	FashionMNIST	Time 0.02440s	Training Accuracy: 80.18%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.02371s	Training Accuracy: 96.48%	Test Accuracy: 81.25%
[ 10/ 50]	FashionMNIST	Time 0.02476s	Training Accuracy: 81.05%	Test Accuracy: 68.75%
[ 11/ 50]	       MNIST	Time 0.02461s	Training Accuracy: 97.85%	Test Accuracy: 81.25%
[ 11/ 50]	FashionMNIST	Time 0.02367s	Training Accuracy: 83.20%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.02466s	Training Accuracy: 98.73%	Test Accuracy: 81.25%
[ 12/ 50]	FashionMNIST	Time 0.02490s	Training Accuracy: 84.18%	Test Accuracy: 62.50%
[ 13/ 50]	       MNIST	Time 0.02354s	Training Accuracy: 98.54%	Test Accuracy: 84.38%
[ 13/ 50]	FashionMNIST	Time 0.02442s	Training Accuracy: 83.69%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.02491s	Training Accuracy: 98.54%	Test Accuracy: 81.25%
[ 14/ 50]	FashionMNIST	Time 0.02635s	Training Accuracy: 85.25%	Test Accuracy: 59.38%
[ 15/ 50]	       MNIST	Time 0.02493s	Training Accuracy: 99.41%	Test Accuracy: 84.38%
[ 15/ 50]	FashionMNIST	Time 0.02355s	Training Accuracy: 81.05%	Test Accuracy: 65.62%
[ 16/ 50]	       MNIST	Time 0.02565s	Training Accuracy: 99.71%	Test Accuracy: 84.38%
[ 16/ 50]	FashionMNIST	Time 0.02708s	Training Accuracy: 82.71%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.02731s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 17/ 50]	FashionMNIST	Time 0.02515s	Training Accuracy: 83.89%	Test Accuracy: 71.88%
[ 18/ 50]	       MNIST	Time 0.02647s	Training Accuracy: 99.71%	Test Accuracy: 84.38%
[ 18/ 50]	FashionMNIST	Time 0.02631s	Training Accuracy: 83.11%	Test Accuracy: 62.50%
[ 19/ 50]	       MNIST	Time 0.02501s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 19/ 50]	FashionMNIST	Time 0.02496s	Training Accuracy: 86.13%	Test Accuracy: 62.50%
[ 20/ 50]	       MNIST	Time 0.02613s	Training Accuracy: 99.90%	Test Accuracy: 84.38%
[ 20/ 50]	FashionMNIST	Time 0.02391s	Training Accuracy: 84.57%	Test Accuracy: 59.38%
[ 21/ 50]	       MNIST	Time 0.02554s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 21/ 50]	FashionMNIST	Time 0.02692s	Training Accuracy: 85.74%	Test Accuracy: 62.50%
[ 22/ 50]	       MNIST	Time 0.02382s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 22/ 50]	FashionMNIST	Time 0.02456s	Training Accuracy: 89.26%	Test Accuracy: 59.38%
[ 23/ 50]	       MNIST	Time 0.02372s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 23/ 50]	FashionMNIST	Time 0.02369s	Training Accuracy: 89.26%	Test Accuracy: 59.38%
[ 24/ 50]	       MNIST	Time 0.02994s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 24/ 50]	FashionMNIST	Time 0.02286s	Training Accuracy: 88.96%	Test Accuracy: 62.50%
[ 25/ 50]	       MNIST	Time 0.02266s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 25/ 50]	FashionMNIST	Time 0.02504s	Training Accuracy: 91.60%	Test Accuracy: 62.50%
[ 26/ 50]	       MNIST	Time 0.02349s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 26/ 50]	FashionMNIST	Time 0.02268s	Training Accuracy: 91.99%	Test Accuracy: 65.62%
[ 27/ 50]	       MNIST	Time 0.02337s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 27/ 50]	FashionMNIST	Time 0.02308s	Training Accuracy: 92.97%	Test Accuracy: 62.50%
[ 28/ 50]	       MNIST	Time 0.03522s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 28/ 50]	FashionMNIST	Time 0.02333s	Training Accuracy: 93.07%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.02246s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 29/ 50]	FashionMNIST	Time 0.02220s	Training Accuracy: 93.26%	Test Accuracy: 62.50%
[ 30/ 50]	       MNIST	Time 0.02752s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 30/ 50]	FashionMNIST	Time 0.02433s	Training Accuracy: 93.65%	Test Accuracy: 62.50%
[ 31/ 50]	       MNIST	Time 0.02347s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 31/ 50]	FashionMNIST	Time 0.02389s	Training Accuracy: 93.65%	Test Accuracy: 62.50%
[ 32/ 50]	       MNIST	Time 0.02469s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 32/ 50]	FashionMNIST	Time 0.03397s	Training Accuracy: 93.95%	Test Accuracy: 62.50%
[ 33/ 50]	       MNIST	Time 0.02350s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 33/ 50]	FashionMNIST	Time 0.02493s	Training Accuracy: 94.63%	Test Accuracy: 62.50%
[ 34/ 50]	       MNIST	Time 0.02372s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 34/ 50]	FashionMNIST	Time 0.02424s	Training Accuracy: 94.73%	Test Accuracy: 62.50%
[ 35/ 50]	       MNIST	Time 0.02444s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 35/ 50]	FashionMNIST	Time 0.02449s	Training Accuracy: 94.73%	Test Accuracy: 59.38%
[ 36/ 50]	       MNIST	Time 0.02493s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 36/ 50]	FashionMNIST	Time 0.02314s	Training Accuracy: 95.02%	Test Accuracy: 62.50%
[ 37/ 50]	       MNIST	Time 0.02339s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 37/ 50]	FashionMNIST	Time 0.02375s	Training Accuracy: 95.12%	Test Accuracy: 65.62%
[ 38/ 50]	       MNIST	Time 0.02321s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 38/ 50]	FashionMNIST	Time 0.02251s	Training Accuracy: 95.90%	Test Accuracy: 62.50%
[ 39/ 50]	       MNIST	Time 0.02317s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 39/ 50]	FashionMNIST	Time 0.02318s	Training Accuracy: 95.80%	Test Accuracy: 65.62%
[ 40/ 50]	       MNIST	Time 0.02267s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 40/ 50]	FashionMNIST	Time 0.02352s	Training Accuracy: 96.39%	Test Accuracy: 65.62%
[ 41/ 50]	       MNIST	Time 0.02336s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 41/ 50]	FashionMNIST	Time 0.02320s	Training Accuracy: 95.80%	Test Accuracy: 65.62%
[ 42/ 50]	       MNIST	Time 0.02326s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 42/ 50]	FashionMNIST	Time 0.02928s	Training Accuracy: 95.80%	Test Accuracy: 65.62%
[ 43/ 50]	       MNIST	Time 0.02673s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 43/ 50]	FashionMNIST	Time 0.02489s	Training Accuracy: 95.61%	Test Accuracy: 65.62%
[ 44/ 50]	       MNIST	Time 0.02546s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 44/ 50]	FashionMNIST	Time 0.02335s	Training Accuracy: 96.58%	Test Accuracy: 65.62%
[ 45/ 50]	       MNIST	Time 0.03225s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 45/ 50]	FashionMNIST	Time 0.02560s	Training Accuracy: 96.29%	Test Accuracy: 65.62%
[ 46/ 50]	       MNIST	Time 0.02537s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 46/ 50]	FashionMNIST	Time 0.02629s	Training Accuracy: 96.39%	Test Accuracy: 65.62%
[ 47/ 50]	       MNIST	Time 0.02619s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 47/ 50]	FashionMNIST	Time 0.02527s	Training Accuracy: 96.78%	Test Accuracy: 65.62%
[ 48/ 50]	       MNIST	Time 0.03111s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 48/ 50]	FashionMNIST	Time 0.02471s	Training Accuracy: 96.88%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.02520s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 49/ 50]	FashionMNIST	Time 0.02442s	Training Accuracy: 97.07%	Test Accuracy: 65.62%
[ 50/ 50]	       MNIST	Time 0.02592s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 50/ 50]	FashionMNIST	Time 0.02452s	Training Accuracy: 97.46%	Test Accuracy: 65.62%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[FINAL]	FashionMNIST	Training Accuracy: 97.46%	Test Accuracy: 65.62%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Additionally report GPU toolkit details, but only for backends whose packages are
# loaded and whose devices are functional on this machine.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.6.1
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.3
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.982 GiB / 4.750 GiB available)

This page was generated using Literate.jl.