Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

`n_train` / `n_eval` select how many leading observations of the `:train` / `:test`
split are used; `nothing` means the whole split. Features are reshaped to
`28 × 28 × 1 × N` and targets are one-hot encoded over the classes `0:9`. The batch size
of each loader is capped at the number of observations in it, only the training loader
is shuffled, and both loaders are created with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Fetch the first `n` observations of `split` (all of them when `n` is `nothing`)
    # and convert them to image tensors and one-hot labels.
    function split_arrays(split::Symbol, n::Union{Nothing,Int})
        data = dset(split)
        n_obs = n === nothing ? length(data) : n
        (; features, targets) = data[1:n_obs]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Wrap `(x, y)` in a loader whose batch size never exceeds the number of samples.
    function make_loader(x, y; shuffle::Bool)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle, partial=false
        )
    end

    x_train, y_train = split_arrays(:train, n_train)
    x_test, y_test = split_arrays(:test, n_eval)

    train_loader = make_loader(x_train, y_train; shuffle=true)
    test_loader = make_loader(x_test, y_test; shuffle=false)
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize=32)

Return the `(train, test)` loader pairs for MNIST and FashionMNIST, in that order.

When the environment variable `CI` parses to `true`, only 1024 training and 32
evaluation observations are used per dataset; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Create a layer in which `weight_generator` produces the parameters that `core_network`
is evaluated with.

The layer is called with a 2-tuple `(x, y)`: `x` is fed to `weight_generator`, whose
output is flattened and reinterpreted with the parameter layout of `core_network`;
`y` is then passed through `core_network` using those generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # The core network is initialized once here only to record the layout (axes) of
    # its parameters; the values themselves are not stored.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate a flat weight vector and give it the core network's structure
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that only the weight_generator parameters are initialized. The core_network parameters are skipped because they are generated by weight_generator on every call.

julia
"""
    Lux.initialparameters(rng, hn::CompactLuxLayer{:HyperNet})

Initialize the parameters of a `HyperNet` layer. Only the `weight_generator` sub-layer
receives parameters; nothing is initialized for `core_network`.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network is a small convolutional classifier with 10 outputs. The weight
generator embeds one of 2 indices and maps it through two dense layers to a vector
with one entry per core-network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit exactly as many values as the core network has parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of observations in `dataloader` that `model` classifies correctly.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode.
Predictions and one-hot targets are moved to the CPU before being decoded and compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st_test = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        scores = first(model((data_idx, x), ps, st_test))
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(scores))
        n_correct += count(target_class .== predicted_class)
        n_seen += length(target_class)
    end
    return n_correct / n_seen
end

Training

julia
"""
    _accuracy_percent(model, train_state, dataloader, data_idx)

Return the accuracy of `model` on `dataloader` as a percentage rounded to two decimals,
evaluated with the parameters and states currently stored in `train_state`.
"""
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

"""
    train(; nepochs=50)

Train the hypernetwork on MNIST and FashionMNIST, alternating between the two datasets
within every epoch, and print training/test accuracy after each pass.

# Keywords
- `nepochs`: number of epochs to run (default `50`).

# Returns
A two-element `Vector{Float64}` holding the final test accuracy in percent for MNIST
(index 1) and FashionMNIST (index 2).
"""
function train(; nepochs::Integer=50)
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, using the first MNIST training batch and a
    # traced dataset index as example inputs.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    dataset_names = ("MNIST", "FashionMNIST")

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the training loop; `train` returns the final test accuracies, one entry per dataset.
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1758241995.694755 3706750 service.cc:158] XLA service 0x22560740 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758241995.694828 3706750 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1758241995.695621 3706750 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1758241995.695680 3706750 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1758241995.695728 3706750 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1758241995.708141 3706750 cuda_dnn.cc:463] Loaded cuDNN version 91200
[  1/ 50]	       MNIST	Time 44.64796s	Training Accuracy: 35.45%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.03934s	Training Accuracy: 31.74%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.04115s	Training Accuracy: 35.74%	Test Accuracy: 31.25%
[  2/ 50]	FashionMNIST	Time 0.07768s	Training Accuracy: 47.07%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.03701s	Training Accuracy: 40.72%	Test Accuracy: 37.50%
[  3/ 50]	FashionMNIST	Time 0.03356s	Training Accuracy: 55.57%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.03282s	Training Accuracy: 48.73%	Test Accuracy: 34.38%
[  4/ 50]	FashionMNIST	Time 0.03413s	Training Accuracy: 62.21%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.04897s	Training Accuracy: 56.15%	Test Accuracy: 46.88%
[  5/ 50]	FashionMNIST	Time 0.03753s	Training Accuracy: 68.65%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.04907s	Training Accuracy: 61.62%	Test Accuracy: 40.62%
[  6/ 50]	FashionMNIST	Time 0.03733s	Training Accuracy: 74.61%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.04755s	Training Accuracy: 67.87%	Test Accuracy: 46.88%
[  7/ 50]	FashionMNIST	Time 0.03761s	Training Accuracy: 76.76%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.04804s	Training Accuracy: 72.17%	Test Accuracy: 43.75%
[  8/ 50]	FashionMNIST	Time 0.03595s	Training Accuracy: 81.54%	Test Accuracy: 59.38%
[  9/ 50]	       MNIST	Time 0.03919s	Training Accuracy: 77.83%	Test Accuracy: 50.00%
[  9/ 50]	FashionMNIST	Time 0.05106s	Training Accuracy: 84.28%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.03953s	Training Accuracy: 82.62%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.04714s	Training Accuracy: 88.38%	Test Accuracy: 65.62%
[ 11/ 50]	       MNIST	Time 0.04005s	Training Accuracy: 87.89%	Test Accuracy: 56.25%
[ 11/ 50]	FashionMNIST	Time 0.03406s	Training Accuracy: 89.84%	Test Accuracy: 59.38%
[ 12/ 50]	       MNIST	Time 0.03417s	Training Accuracy: 85.74%	Test Accuracy: 62.50%
[ 12/ 50]	FashionMNIST	Time 0.03705s	Training Accuracy: 91.70%	Test Accuracy: 71.88%
[ 13/ 50]	       MNIST	Time 0.04598s	Training Accuracy: 88.09%	Test Accuracy: 50.00%
[ 13/ 50]	FashionMNIST	Time 0.03665s	Training Accuracy: 93.26%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.03967s	Training Accuracy: 92.48%	Test Accuracy: 59.38%
[ 14/ 50]	FashionMNIST	Time 0.03070s	Training Accuracy: 94.04%	Test Accuracy: 68.75%
[ 15/ 50]	       MNIST	Time 0.03510s	Training Accuracy: 95.51%	Test Accuracy: 62.50%
[ 15/ 50]	FashionMNIST	Time 0.03412s	Training Accuracy: 94.53%	Test Accuracy: 71.88%
[ 16/ 50]	       MNIST	Time 0.03197s	Training Accuracy: 96.39%	Test Accuracy: 53.12%
[ 16/ 50]	FashionMNIST	Time 0.04597s	Training Accuracy: 95.90%	Test Accuracy: 71.88%
[ 17/ 50]	       MNIST	Time 0.03163s	Training Accuracy: 98.63%	Test Accuracy: 62.50%
[ 17/ 50]	FashionMNIST	Time 0.03896s	Training Accuracy: 97.46%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.03537s	Training Accuracy: 99.22%	Test Accuracy: 56.25%
[ 18/ 50]	FashionMNIST	Time 0.03598s	Training Accuracy: 96.29%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03230s	Training Accuracy: 99.61%	Test Accuracy: 62.50%
[ 19/ 50]	FashionMNIST	Time 0.03326s	Training Accuracy: 98.14%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.04105s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 20/ 50]	FashionMNIST	Time 0.03296s	Training Accuracy: 98.93%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.04249s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.03280s	Training Accuracy: 98.93%	Test Accuracy: 65.62%
[ 22/ 50]	       MNIST	Time 0.03191s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 22/ 50]	FashionMNIST	Time 0.03313s	Training Accuracy: 99.32%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.03378s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 23/ 50]	FashionMNIST	Time 0.04375s	Training Accuracy: 99.80%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.03504s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 24/ 50]	FashionMNIST	Time 0.04022s	Training Accuracy: 99.80%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.03795s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.03355s	Training Accuracy: 99.90%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.03699s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 26/ 50]	FashionMNIST	Time 0.03499s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 27/ 50]	       MNIST	Time 0.04573s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.03637s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.04304s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.03420s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.03358s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 29/ 50]	FashionMNIST	Time 0.03345s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 30/ 50]	       MNIST	Time 0.03359s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 30/ 50]	FashionMNIST	Time 0.03352s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	       MNIST	Time 0.03435s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 31/ 50]	FashionMNIST	Time 0.04218s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	       MNIST	Time 0.03553s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.04625s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 33/ 50]	       MNIST	Time 0.03553s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 33/ 50]	FashionMNIST	Time 0.03530s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	       MNIST	Time 0.03570s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.04480s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.04923s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	FashionMNIST	Time 0.03280s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 36/ 50]	       MNIST	Time 0.04464s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.03503s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.03507s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	FashionMNIST	Time 0.03757s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	       MNIST	Time 0.03757s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 38/ 50]	FashionMNIST	Time 0.04606s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.03421s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	FashionMNIST	Time 0.07067s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.04400s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	FashionMNIST	Time 0.04168s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.03754s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.03731s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.04663s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.03683s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.04755s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.03857s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.03612s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 44/ 50]	FashionMNIST	Time 0.03841s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.03639s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.04628s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.03519s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 46/ 50]	FashionMNIST	Time 0.04898s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	       MNIST	Time 0.04647s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	FashionMNIST	Time 0.03154s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.03435s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	FashionMNIST	Time 0.03427s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.04323s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 49/ 50]	FashionMNIST	Time 0.03301s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.04271s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 50/ 50]	FashionMNIST	Time 0.03275s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print GPU toolchain details only when MLDataDevices and the respective backend
# package are defined in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.