Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the MLDatasets dataset type `dset`.

`n_train` / `n_eval` select how many leading samples are taken from the `:train` /
`:test` splits; `nothing` means "use the whole split". Images are reshaped to
`28×28×1×N` arrays and labels are one-hot encoded over `0:9`. Both loaders drop
incomplete batches and cap the batch size at the number of available samples; only
the training loader shuffles.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Select the first `n` samples of `split` (the whole split when `n === nothing`) and
    # return them as a 28×28×1×N image array plus one-hot labels over the digits 0:9.
    function take_split(split, n)
        data = dset(split)
        subset = data[1:something(n, length(data))]
        images = reshape(subset.features, 28, 28, 1, :)
        labels = onehotbatch(subset.targets, 0:9)
        return images, labels
    end

    # Wrap `(x, y)` in a loader that drops incomplete batches. The batch size is capped
    # at the sample count so a non-empty split always yields at least one batch.
    function make_loader(x, y; shuffle)
        nobs = size(x, 4)
        return DataLoader((x, y); batchsize=min(batchsize, nobs), shuffle, partial=false)
    end

    x_train, y_train = take_split(:train, n_train)
    x_test, y_test = take_split(:test, n_eval)
    return (
        make_loader(x_train, y_train; shuffle=true),
        make_loader(x_test, y_test; shuffle=false),
    )
end

"""
    load_datasets(batchsize=32)

Return `(mnist_loaders, fashion_mnist_loaders)`, each a `(train, test)` loader pair
produced by `load_dataset`. When the `CI` environment variable parses as `true`, only
the first 1024 training and 32 evaluation samples of each dataset are used (presumably
to keep CI runs fast); otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer that is called with a tuple input `(x, y)`.

`weight_generator` is applied to `x`. Its output is flattened with `vec` and
reinterpreted as a `ComponentArray` laid out like `core_network`'s parameters.
`core_network` is then evaluated on `y` with those generated parameters. The flattened
generator output is therefore expected to have exactly as many elements as
`core_network` has parameters.

The layer is built with `@compact(...; dispatch=:HyperNet)`, so methods can be
specialized on `CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Parameter layout (axes) of the core network, obtained from a throwaway
    # initialization: only the axes are kept, the initialized values are discarded.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        # (flat generator output -> structured parameters for `core_network`)
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

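Defining methods on a `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing so unless you are familiar with its internals. Here, we define `Lux.initialparameters` so that it skips initializing the `core_network` parameters: those parameters are generated by the `weight_generator` on every forward pass, so only the `weight_generator` needs trainable parameters.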

julia
# Parameter initialization for HyperNet layers: only the weight generator owns trainable
# parameters. The core network's parameters are produced by the generator on every
# forward pass, so they are deliberately not initialized here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (weight_generator=generator_ps,)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Return a `HyperNet` made of two parts:

- The core network is a small CNN: three stride-2 `3×3` convolutions (1 → 16 → 32 → 64
  channels), global mean pooling, and a `Dense(64, 10)` head.
- The weight generator maps an integer index (an `Embedding` with 2 entries) to a vector
  with one entry per core-network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's final layer must emit one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class (via `onecold` on
the model output) matches the `onecold` of the one-hot target.

Each batch `(x, y)` is evaluated as `model((data_idx, x), ps, st)`, with `st` switched
to test mode first. Returns `NaN` if `dataloader` yields no samples (`0 / 0`).
"""
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        ŷ, _ = model((data_idx, x), ps, st)
        # Bring targets and predictions to host memory before decoding class indices.
        target_class = y |> cdev |> onecold
        predicted_class = ŷ |> cdev |> onecold
        total_correct += count(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

Training

julia
# Accuracy of `model` on `dataloader` as a percentage rounded to 2 decimals, using the
# parameters and states currently held in `train_state`. `train_state` is passed in
# rather than captured, because `train` reassigns it inside its training loop.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

# Display name for a dataset index: 1 => "MNIST", anything else => "FashionMNIST".
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train()

Train the HyperNet jointly on MNIST (data index 1) and FashionMNIST (data index 2) with
Reactant for 50 epochs. Each epoch trains on MNIST and then on FashionMNIST, and prints
the training time plus train/test accuracy for each dataset.

Return the final test accuracies in percent (rounded to 2 decimals) as
`[mnist_test_acc, fashion_mnist_test_acc]`.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, using the first MNIST training batch as the
    # example input. Because the data index is a traced `ConcreteRNumber`, the same
    # compiled function is reused for both datasets below.
    x_example = first(first(dataloaders[1][1]))
    example_idx = ConcreteRNumber(1)
    model_compiled = @compile model((example_idx, x_example), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training. `test_acc_list` holds the final test accuracies (in percent)
# for MNIST and FashionMNIST, in that order.
test_acc_list = train()
[  1/ 50]	       MNIST	Time 67.84203s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.14110s	Training Accuracy: 32.62%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.09295s	Training Accuracy: 36.72%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.09253s	Training Accuracy: 45.61%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.09559s	Training Accuracy: 41.02%	Test Accuracy: 28.12%
[  3/ 50]	FashionMNIST	Time 0.08765s	Training Accuracy: 57.03%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.08834s	Training Accuracy: 52.25%	Test Accuracy: 40.62%
[  4/ 50]	FashionMNIST	Time 0.08415s	Training Accuracy: 63.96%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.09107s	Training Accuracy: 57.32%	Test Accuracy: 37.50%
[  5/ 50]	FashionMNIST	Time 0.08962s	Training Accuracy: 70.51%	Test Accuracy: 53.12%
[  6/ 50]	       MNIST	Time 0.09344s	Training Accuracy: 63.67%	Test Accuracy: 37.50%
[  6/ 50]	FashionMNIST	Time 0.09053s	Training Accuracy: 75.49%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.08993s	Training Accuracy: 71.00%	Test Accuracy: 34.38%
[  7/ 50]	FashionMNIST	Time 0.08954s	Training Accuracy: 76.66%	Test Accuracy: 53.12%
[  8/ 50]	       MNIST	Time 0.08844s	Training Accuracy: 75.68%	Test Accuracy: 43.75%
[  8/ 50]	FashionMNIST	Time 0.09535s	Training Accuracy: 81.84%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.09048s	Training Accuracy: 80.08%	Test Accuracy: 43.75%
[  9/ 50]	FashionMNIST	Time 0.08935s	Training Accuracy: 83.89%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.09194s	Training Accuracy: 84.96%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.08935s	Training Accuracy: 88.57%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.10289s	Training Accuracy: 89.45%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.08998s	Training Accuracy: 90.04%	Test Accuracy: 68.75%
[ 12/ 50]	       MNIST	Time 0.09116s	Training Accuracy: 91.11%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.08926s	Training Accuracy: 91.11%	Test Accuracy: 68.75%
[ 13/ 50]	       MNIST	Time 0.08872s	Training Accuracy: 93.46%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.09405s	Training Accuracy: 91.89%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.09314s	Training Accuracy: 95.51%	Test Accuracy: 53.12%
[ 14/ 50]	FashionMNIST	Time 0.08822s	Training Accuracy: 94.53%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.08968s	Training Accuracy: 96.58%	Test Accuracy: 56.25%
[ 15/ 50]	FashionMNIST	Time 0.09238s	Training Accuracy: 93.85%	Test Accuracy: 71.88%
[ 16/ 50]	       MNIST	Time 0.08626s	Training Accuracy: 98.73%	Test Accuracy: 62.50%
[ 16/ 50]	FashionMNIST	Time 0.08900s	Training Accuracy: 96.29%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.09404s	Training Accuracy: 99.12%	Test Accuracy: 59.38%
[ 17/ 50]	FashionMNIST	Time 0.09011s	Training Accuracy: 96.97%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.08848s	Training Accuracy: 99.51%	Test Accuracy: 56.25%
[ 18/ 50]	FashionMNIST	Time 0.09391s	Training Accuracy: 96.97%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.08784s	Training Accuracy: 99.80%	Test Accuracy: 59.38%
[ 19/ 50]	FashionMNIST	Time 0.09838s	Training Accuracy: 99.32%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.08595s	Training Accuracy: 99.80%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.08543s	Training Accuracy: 98.73%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.08655s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.08995s	Training Accuracy: 98.73%	Test Accuracy: 65.62%
[ 22/ 50]	       MNIST	Time 0.09554s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 22/ 50]	FashionMNIST	Time 0.08640s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.08696s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 23/ 50]	FashionMNIST	Time 0.08693s	Training Accuracy: 99.71%	Test Accuracy: 65.62%
[ 24/ 50]	       MNIST	Time 0.08712s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.09485s	Training Accuracy: 99.51%	Test Accuracy: 65.62%
[ 25/ 50]	       MNIST	Time 0.08627s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.08631s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.08828s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.08661s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	       MNIST	Time 0.09187s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.08911s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	       MNIST	Time 0.08842s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.08539s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.08768s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.09321s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.08648s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.09249s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	       MNIST	Time 0.08474s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	FashionMNIST	Time 0.08740s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	       MNIST	Time 0.08773s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.09306s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.09040s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	FashionMNIST	Time 0.08583s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	       MNIST	Time 0.08764s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.08921s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.08914s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	FashionMNIST	Time 0.09129s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	       MNIST	Time 0.08716s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	FashionMNIST	Time 0.08515s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.08539s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	FashionMNIST	Time 0.08605s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	       MNIST	Time 0.09263s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.08562s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.10237s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	FashionMNIST	Time 0.08805s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	       MNIST	Time 0.08875s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	FashionMNIST	Time 0.09226s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	       MNIST	Time 0.08483s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	FashionMNIST	Time 0.08819s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.09029s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	FashionMNIST	Time 0.08805s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.09410s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 43/ 50]	FashionMNIST	Time 0.08416s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.08948s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	FashionMNIST	Time 0.08772s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	       MNIST	Time 0.08788s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	FashionMNIST	Time 0.09373s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	       MNIST	Time 0.08712s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	FashionMNIST	Time 0.09230s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	       MNIST	Time 0.08816s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	FashionMNIST	Time 0.08582s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.08714s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	FashionMNIST	Time 0.08770s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.09564s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 49/ 50]	FashionMNIST	Time 0.08578s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.08683s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	FashionMNIST	Time 0.08947s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
# Print Julia version and platform information.
using InteractiveUtils
InteractiveUtils.versioninfo()

# If a GPU backend module is loaded and functional, also print its toolkit information.
# NOTE(review): `MLDataDevices`, `CUDA` and `AMDGPU` are not imported in this file, so
# these `@isdefined` guards are presumably false here unless loaded elsewhere — confirm.
# The `@isdefined(...) &&` short-circuit keeps undefined names from being evaluated.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.