
Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
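
A quick sanity check of the returned loaders (a minimal sketch; the batch shapes shown assume the default batchsize of 32):

julia
mnist_loaders, fashion_loaders = load_datasets()
x, y = first(mnist_loaders[1])   # first training batch of MNIST
size(x), size(y)                 # ((28, 28, 1, 32), (10, 32))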

Implement a HyperNet Layer

julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # x is the task index fed to the weight generator; y is the input to the core network
        # Generate the core network's weights and reassemble them into its parameter structure
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend it unless you are familiar with the internals. In this case, we simply override Lux.initialparameters to skip initializing the core_network parameters, since those are always produced by the weight_generator.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
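
With this override in place, Lux.setup only materializes the weight generator's parameters. A minimal sketch (using a hypothetical toy core network) to confirm that:

julia
toy_core = Dense(4 => 2)
toy_gen = Chain(Embedding(2 => 8), Dense(8 => Lux.parameterlength(toy_core)))
toy_hn = HyperNet(toy_gen, toy_core)
ps, st = Lux.setup(Random.default_rng(), toy_hn)
keys(ps)  # (:weight_generator,) -- the core network's parameters are generated on the fly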

Create and Initialize the HyperNet

julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
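
Before compiling anything, a plain-CPU forward pass on dummy data is a useful sanity check (a sketch; data index 1 corresponds to MNIST):

julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 28, 28, 1, 2)   # dummy batch of 2 images
ŷ, _ = model((1, x), ps, st)       # task index 1 selects the MNIST embedding
size(ŷ)                            # (10, 2)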

Define Utility Functions

julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        ŷ, _ = model((data_idx, x), ps, st)
        target_class = y |> cdev |> onecold
        predicted_class = ŷ |> cdev |> onecold
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
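
onecold reduces each column of logits (or a one-hot column) to the index of its largest entry, which is what makes the comparison above work. A tiny illustrative sketch with hypothetical logits:

julia
logits = Float32[0.1 0.9; 0.8 0.05; 0.1 0.05]  # 3 classes × 2 samples
onecold(logits)                                # [2, 1]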

Training

julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
[  1/ 50]	       MNIST	Time 47.92338s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.09140s	Training Accuracy: 32.52%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.08455s	Training Accuracy: 36.33%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.10560s	Training Accuracy: 46.19%	Test Accuracy: 46.88%
[  3/ 50]	       MNIST	Time 0.08393s	Training Accuracy: 42.68%	Test Accuracy: 28.12%
[  3/ 50]	FashionMNIST	Time 0.08217s	Training Accuracy: 56.64%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.08624s	Training Accuracy: 51.37%	Test Accuracy: 37.50%
[  4/ 50]	FashionMNIST	Time 0.08552s	Training Accuracy: 65.23%	Test Accuracy: 62.50%
[  5/ 50]	       MNIST	Time 0.08488s	Training Accuracy: 56.64%	Test Accuracy: 40.62%
[  5/ 50]	FashionMNIST	Time 0.08496s	Training Accuracy: 70.51%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.30614s	Training Accuracy: 61.62%	Test Accuracy: 37.50%
[  6/ 50]	FashionMNIST	Time 0.08376s	Training Accuracy: 75.78%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.08380s	Training Accuracy: 67.77%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.08220s	Training Accuracy: 75.59%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.08520s	Training Accuracy: 74.90%	Test Accuracy: 46.88%
[  8/ 50]	FashionMNIST	Time 0.08879s	Training Accuracy: 80.96%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.08283s	Training Accuracy: 81.05%	Test Accuracy: 53.12%
[  9/ 50]	FashionMNIST	Time 0.08549s	Training Accuracy: 84.77%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.08386s	Training Accuracy: 82.52%	Test Accuracy: 53.12%
[ 10/ 50]	FashionMNIST	Time 0.08421s	Training Accuracy: 88.18%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.09461s	Training Accuracy: 86.43%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.08155s	Training Accuracy: 89.84%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.08165s	Training Accuracy: 90.04%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.08282s	Training Accuracy: 92.09%	Test Accuracy: 68.75%
[ 13/ 50]	       MNIST	Time 0.08531s	Training Accuracy: 94.43%	Test Accuracy: 59.38%
[ 13/ 50]	FashionMNIST	Time 0.08557s	Training Accuracy: 93.46%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.08347s	Training Accuracy: 94.53%	Test Accuracy: 56.25%
[ 14/ 50]	FashionMNIST	Time 0.08493s	Training Accuracy: 95.51%	Test Accuracy: 68.75%
[ 15/ 50]	       MNIST	Time 0.08418s	Training Accuracy: 96.29%	Test Accuracy: 71.88%
[ 15/ 50]	FashionMNIST	Time 0.08574s	Training Accuracy: 94.24%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.08408s	Training Accuracy: 98.63%	Test Accuracy: 59.38%
[ 16/ 50]	FashionMNIST	Time 0.08371s	Training Accuracy: 95.51%	Test Accuracy: 71.88%
[ 17/ 50]	       MNIST	Time 0.08331s	Training Accuracy: 99.32%	Test Accuracy: 59.38%
[ 17/ 50]	FashionMNIST	Time 0.09246s	Training Accuracy: 97.36%	Test Accuracy: 71.88%
[ 18/ 50]	       MNIST	Time 0.08352s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 18/ 50]	FashionMNIST	Time 0.08406s	Training Accuracy: 97.46%	Test Accuracy: 71.88%
[ 19/ 50]	       MNIST	Time 0.09680s	Training Accuracy: 99.61%	Test Accuracy: 65.62%
[ 19/ 50]	FashionMNIST	Time 0.08479s	Training Accuracy: 98.63%	Test Accuracy: 65.62%
[ 20/ 50]	       MNIST	Time 0.08239s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.08401s	Training Accuracy: 98.63%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.08486s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 21/ 50]	FashionMNIST	Time 0.08380s	Training Accuracy: 98.93%	Test Accuracy: 71.88%
[ 22/ 50]	       MNIST	Time 0.08486s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 22/ 50]	FashionMNIST	Time 0.08348s	Training Accuracy: 99.51%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.08450s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 23/ 50]	FashionMNIST	Time 0.08156s	Training Accuracy: 99.71%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.08334s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 24/ 50]	FashionMNIST	Time 0.08429s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 25/ 50]	       MNIST	Time 0.08632s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 25/ 50]	FashionMNIST	Time 0.09528s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.08381s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.08569s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 27/ 50]	       MNIST	Time 0.09184s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	FashionMNIST	Time 0.08370s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.08263s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	FashionMNIST	Time 0.09316s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 29/ 50]	       MNIST	Time 0.08406s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.08369s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.08162s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.08411s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.08809s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	FashionMNIST	Time 0.08735s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	       MNIST	Time 0.08484s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	FashionMNIST	Time 0.08535s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.08372s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	FashionMNIST	Time 0.08255s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	       MNIST	Time 0.08368s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	FashionMNIST	Time 0.08276s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.09185s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	FashionMNIST	Time 0.08488s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 36/ 50]	       MNIST	Time 0.08576s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.09237s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.08322s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	FashionMNIST	Time 0.08299s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	       MNIST	Time 0.09251s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.08243s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.08324s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	FashionMNIST	Time 0.08391s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	       MNIST	Time 0.08307s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	FashionMNIST	Time 0.08314s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	       MNIST	Time 0.08922s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.08223s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.08369s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.08459s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.08440s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.08607s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.09017s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	FashionMNIST	Time 0.09493s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	       MNIST	Time 0.08458s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.08421s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.09339s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.08915s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	       MNIST	Time 0.08599s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	FashionMNIST	Time 0.09286s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.08382s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	FashionMNIST	Time 0.08387s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.08379s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.08348s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.08243s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.08299s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.