Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

The `:train` and `:test` splits are cut to their first `n_train` / `n_eval` samples
(`nothing` keeps the whole split). Features are reshaped to `28 × 28 × 1 × N` and targets
are one-hot encoded over the classes `0:9`. The training loader shuffles, the test loader
does not. Both loaders drop a trailing partial batch and cap the batch size at the number
of samples in the split.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load one split and convert it to (images, one-hot labels).
    function prepare_split(split::Symbol, limit)
        data = dset(split)
        stop = limit === nothing ? length(data) : limit
        samples = data[1:stop]
        images = reshape(samples.features, 28, 28, 1, :)
        labels = onehotbatch(samples.targets, 0:9)
        return images, labels
    end

    # Wrap (images, labels) in a DataLoader; the 4th image dimension indexes samples.
    function make_loader(split_data, do_shuffle::Bool)
        nsamples = size(first(split_data), 4)
        return DataLoader(
            split_data;
            batchsize=min(batchsize, nsamples),
            shuffle=do_shuffle,
            partial=false,
        )
    end

    train_data = prepare_split(:train, n_train)
    test_data = prepare_split(:test, n_eval)

    return (make_loader(train_data, true), make_loader(test_data, false))
end

"""
    load_datasets(batchsize=32)

Return the `load_dataset` result for `MNIST` and `FashionMNIST`, in that order.

When the environment variable `CI` parses to `true`, only 1024 training and 32 evaluation
samples are used per dataset; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    running_in_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = running_in_ci ? 1024 : nothing
    n_eval = running_in_ci ? 32 : nothing
    return broadcast(load_dataset, (MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a compact layer that takes a tuple of two inputs: the first is fed to
`weight_generator`, whose flattened output is used as the parameters of `core_network`,
and the second is the input that `core_network` is evaluated on.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (layout) of the core network's parameters are kept; the values
    # initialised here are discarded.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))

    return @compact(;
        ca_axes, weight_generator, core_network, dispatch=:HyperNet
    ) do (generator_input, core_input)
        # Generate the weights and give the flat vector the core network's layout
        flat_weights = vec(weight_generator(generator_input))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(core_input, generated_ps)
    end
end

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips the initialization of the core_network parameters, since those are produced by the weight_generator at every call.

julia
# Parameter initialisation for the `:HyperNet` compact layer: only the weight generator
# gets parameters; nothing is initialised for `core_network`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial: a small convolutional classifier with 10
outputs as the core network, and an embedding + MLP weight generator whose final layer
emits exactly as many values as the core network has parameters.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must output one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

# Arguments
- `model`: callable invoked as `model((data_idx, x), ps, st)`; its first return value is
  taken as the prediction.
- `ps`, `st`: parameters and states passed to `model`; `st` is switched to test mode.
- `dataloader`: iterable of `(x, y)` batches, where `y` holds one-hot targets.
- `data_idx`: dataset index forwarded to `model` together with each batch.

# Returns
`total_correct / total` as a floating-point number (`NaN` if `dataloader` yields no
batches, since that evaluates `0 / 0`).
"""
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        ŷ, _ = model((data_idx, x), ps, st)
        # Move both arrays to the CPU before decoding class labels
        target_class = y |> cdev |> onecold
        predicted_class = ŷ |> cdev |> onecold
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

Training

julia
"""
    train()

Train the HyperNet on MNIST and FashionMNIST with Reactant, alternating between the two
datasets within every epoch, and print timing and accuracy after each pass.

Returns a two-element vector holding the final test accuracy (in percent, rounded to two
decimals) for MNIST and FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    ### Compile the test-mode forward pass once, using the first MNIST training batch
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    ### Accuracy of the compiled model in percent, rounded to two decimals. The train
    ### state is passed in explicitly so that the latest parameters are always used.
    function percent_accuracy(state, loader, idx)
        fraction = accuracy(model_compiled, state.parameters, state.states, loader, idx)
        return round(fraction * 100; digits=2)
    end

    dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            _, _, _, train_state = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)
        data_name = dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    ### Final evaluation of the trained model on both datasets
    test_acc_list = zeros(Float64, 2)
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)
        data_name = dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the whole training procedure; `train` returns the final test accuracies
# (in percent) for MNIST and FashionMNIST.
test_acc_list = train()
[  1/ 50]	       MNIST	Time 68.66474s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.09102s	Training Accuracy: 32.62%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.08948s	Training Accuracy: 36.72%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.09169s	Training Accuracy: 45.61%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.08911s	Training Accuracy: 41.02%	Test Accuracy: 28.12%
[  3/ 50]	FashionMNIST	Time 0.08720s	Training Accuracy: 57.03%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.08925s	Training Accuracy: 52.25%	Test Accuracy: 40.62%
[  4/ 50]	FashionMNIST	Time 0.08731s	Training Accuracy: 63.96%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.08615s	Training Accuracy: 57.32%	Test Accuracy: 37.50%
[  5/ 50]	FashionMNIST	Time 0.08757s	Training Accuracy: 70.51%	Test Accuracy: 53.12%
[  6/ 50]	       MNIST	Time 0.08490s	Training Accuracy: 63.67%	Test Accuracy: 37.50%
[  6/ 50]	FashionMNIST	Time 0.08741s	Training Accuracy: 75.49%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.09414s	Training Accuracy: 71.00%	Test Accuracy: 34.38%
[  7/ 50]	FashionMNIST	Time 0.08686s	Training Accuracy: 76.66%	Test Accuracy: 53.12%
[  8/ 50]	       MNIST	Time 0.08670s	Training Accuracy: 75.68%	Test Accuracy: 43.75%
[  8/ 50]	FashionMNIST	Time 0.08567s	Training Accuracy: 81.84%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.08507s	Training Accuracy: 80.08%	Test Accuracy: 43.75%
[  9/ 50]	FashionMNIST	Time 0.08905s	Training Accuracy: 83.89%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.08665s	Training Accuracy: 84.96%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.09309s	Training Accuracy: 88.57%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.08814s	Training Accuracy: 89.45%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.08566s	Training Accuracy: 90.04%	Test Accuracy: 68.75%
[ 12/ 50]	       MNIST	Time 0.08749s	Training Accuracy: 91.11%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.09397s	Training Accuracy: 91.11%	Test Accuracy: 68.75%
[ 13/ 50]	       MNIST	Time 0.09211s	Training Accuracy: 93.46%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.08720s	Training Accuracy: 91.89%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.08715s	Training Accuracy: 95.51%	Test Accuracy: 53.12%
[ 14/ 50]	FashionMNIST	Time 0.08619s	Training Accuracy: 94.53%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.08572s	Training Accuracy: 96.58%	Test Accuracy: 56.25%
[ 15/ 50]	FashionMNIST	Time 0.09332s	Training Accuracy: 93.85%	Test Accuracy: 71.88%
[ 16/ 50]	       MNIST	Time 0.08934s	Training Accuracy: 98.73%	Test Accuracy: 62.50%
[ 16/ 50]	FashionMNIST	Time 0.08669s	Training Accuracy: 96.29%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.08797s	Training Accuracy: 99.12%	Test Accuracy: 59.38%
[ 17/ 50]	FashionMNIST	Time 0.08551s	Training Accuracy: 96.97%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.09255s	Training Accuracy: 99.51%	Test Accuracy: 56.25%
[ 18/ 50]	FashionMNIST	Time 0.08512s	Training Accuracy: 96.97%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.08724s	Training Accuracy: 99.80%	Test Accuracy: 59.38%
[ 19/ 50]	FashionMNIST	Time 0.08602s	Training Accuracy: 99.32%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.08387s	Training Accuracy: 99.80%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.09208s	Training Accuracy: 98.73%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.08656s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.08497s	Training Accuracy: 98.73%	Test Accuracy: 65.62%
[ 22/ 50]	       MNIST	Time 0.08437s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 22/ 50]	FashionMNIST	Time 0.08568s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.09146s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 23/ 50]	FashionMNIST	Time 0.08779s	Training Accuracy: 99.71%	Test Accuracy: 65.62%
[ 24/ 50]	       MNIST	Time 0.08777s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.08658s	Training Accuracy: 99.51%	Test Accuracy: 65.62%
[ 25/ 50]	       MNIST	Time 0.08698s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.09203s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.08523s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.08612s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	       MNIST	Time 0.08553s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.08586s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	       MNIST	Time 0.09116s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.08752s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.08628s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.08644s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.08574s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.09200s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	       MNIST	Time 0.08859s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	FashionMNIST	Time 0.09048s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	       MNIST	Time 0.08794s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.08589s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.09320s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	FashionMNIST	Time 0.08834s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	       MNIST	Time 0.09556s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.08478s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.08691s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	FashionMNIST	Time 0.09204s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	       MNIST	Time 0.08601s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	FashionMNIST	Time 0.08650s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.08538s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	FashionMNIST	Time 0.08416s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	       MNIST	Time 0.09231s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.08680s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.08670s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	FashionMNIST	Time 0.08650s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	       MNIST	Time 0.08637s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	FashionMNIST	Time 0.09415s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	       MNIST	Time 0.08683s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	FashionMNIST	Time 0.08654s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.08687s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	FashionMNIST	Time 0.08670s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.09671s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 43/ 50]	FashionMNIST	Time 0.08837s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.08730s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	FashionMNIST	Time 0.09048s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	       MNIST	Time 0.08823s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	FashionMNIST	Time 0.09215s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	       MNIST	Time 0.08551s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	FashionMNIST	Time 0.08634s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	       MNIST	Time 0.08809s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	FashionMNIST	Time 0.08601s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.09253s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	FashionMNIST	Time 0.08888s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.08632s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 49/ 50]	FashionMNIST	Time 0.08487s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.08543s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	FashionMNIST	Time 0.09527s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
# Print the Julia version and platform details of the session that produced this page.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU details are printed only when MLDataDevices and the respective backend package are
# defined in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.