Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load the first `n` samples of a split (the whole split when `n === nothing`) and
    # return them as a 28×28×1×N feature array plus one-hot labels over the classes 0:9.
    function prepare_split(split::Symbol, n::Union{Nothing,Int})
        raw = dset(split)
        stop = n === nothing ? length(raw) : n
        (; features, targets) = raw[1:stop]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    x_train, y_train = prepare_split(:train, n_train)
    x_test, y_test = prepare_split(:test, n_eval)

    # Training batches are shuffled while evaluation order stays fixed. Incomplete
    # trailing batches are dropped in both loaders, and the batch size is capped at the
    # number of available samples so a non-empty split yields at least one batch.
    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end

# Build `(train_loader, test_loader)` pairs for MNIST and FashionMNIST, in that order.
#
# When the `CI` environment variable is truthy, only 1024 training and 32 evaluation
# samples per dataset are used to keep runs short; otherwise the full splits are loaded.
# `CI` is parsed once and leniently: surrounding whitespace and letter case are ignored,
# and an unset, empty or unrecognized value counts as "not CI" instead of throwing (as
# `parse(Bool, ...)` would for e.g. "True" or "").
function load_datasets(batchsize=32)
    ci_value = lowercase(strip(get(ENV, "CI", "false")))
    on_ci = something(tryparse(Bool, ci_value), false)
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
# Build a hypernetwork layer whose input is a pair `(x, y)`.
#
# `weight_generator` is applied to `x` to produce a parameter vector. That vector is given
# the parameter layout of `core_network`, and `core_network` is then applied to `y` using
# those generated parameters. The result is a `CompactLuxLayer` tagged with
# `dispatch=:HyperNet`.
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Initialize `core_network` once only to capture the axes (nested layout) of its
    # parameters; the sampled values themselves are not kept.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    # `dispatch=:HyperNet` gives the layer the type `CompactLuxLayer{:HyperNet}`, which
    # lets methods such as `Lux.initialparameters` be specialized for it.
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights: flatten the generator output and view it with the core
        # network's parameter axes.
        # NOTE(review): assumes `weight_generator(x)` has exactly as many elements as
        # `core_network` has parameters — confirm for new generator architectures.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to skip initializing the `core_network` parameters, since those are generated by the `weight_generator` during the forward pass.

julia
# A HyperNet's trainable parameters are only those of its `weight_generator`. The
# `core_network` receives its parameters from the generator at call time, so nothing is
# initialized for it here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
# Assemble the HyperNet used in this tutorial.
function create_model()
    # Target CNN: three stride-2 3×3 convolutions with ReLU, global mean pooling, and a
    # 10-way linear head. Its parameters are produced by the generator below rather than
    # being trained directly.
    target = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # Generator: embeds the dataset index (one of 2) and maps it to a vector with exactly
    # as many entries as the target network has parameters.
    n_target_params = Lux.parameterlength(target)
    generator = Chain(Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_target_params))

    return HyperNet(generator, target)
end

Define Utility Functions

julia
# Fraction of samples in `dataloader` whose predicted class (argmax of the model output)
# matches the one-hot label. The model is run in test mode with `data_idx` selecting the
# dataset; labels and outputs are moved to the CPU before comparison.
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (inputs, labels) in dataloader
        truth = onecold(to_cpu(labels))
        logits = first(model((data_idx, inputs), ps, eval_st))
        guess = onecold(to_cpu(logits))
        n_correct += count(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end

Training

julia
# Display name for the dataset at index `idx` (1 => MNIST, anything else => FashionMNIST).
_dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy of `model` on `loader` for dataset `idx`, as a percentage rounded to two
# decimals, using the parameters and states currently held in `state`.
function _accuracy_percent(model, state, loader, idx)
    fraction = accuracy(model, state.parameters, state.states, loader, idx)
    return round(fraction * 100; digits=2)
end

# Train the HyperNet alternately on MNIST and FashionMNIST for 50 epochs, printing
# per-epoch timing and accuracies, then a final summary. Returns the final test
# accuracies (percent) as `[MNIST, FashionMNIST]`.
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference model once, from a sample batch and a traced dataset index,
    # so the same compiled function can be reused for both datasets.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for idx in 1:2
            train_loader, test_loader = dev.(dataloaders[idx])

            # Wrapping the index lets Reactant trace it; a plain Int would be embedded
            # as a constant in the compiled IR.
            traced_idx = ConcreteRNumber(idx)

            t0 = time()
            for (x, y) in train_loader
                (_, _, _, train_state) = Training.single_train_step!(
                    AutoEnzyme(),
                    CrossEntropyLoss(; logits=Val(true)),
                    ((traced_idx, x), y),
                    train_state;
                    return_gradients=Val(false),
                )
            end
            elapsed = time() - t0

            train_acc = _accuracy_percent(
                model_compiled, train_state, train_loader, traced_idx
            )
            test_acc = _accuracy_percent(
                model_compiled, train_state, test_loader, traced_idx
            )
            name = _dataset_name(idx)

            @printf("[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                     Accuracy: %3.2f%%\n", epoch, nepochs, name, elapsed, train_acc, test_acc)
        end
    end

    println()

    test_acc_list = [0.0, 0.0]
    for idx in 1:2
        train_loader, test_loader = dev.(dataloaders[idx])
        traced_idx = ConcreteRNumber(idx)

        train_acc = _accuracy_percent(model_compiled, train_state, train_loader, traced_idx)
        test_acc = _accuracy_percent(model_compiled, train_state, test_loader, traced_idx)
        name = _dataset_name(idx)

        @printf("[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n", name, train_acc, test_acc)
        test_acc_list[idx] = test_acc
    end
    return test_acc_list
end

# Run the full training; the result holds the final test accuracies (in percent) for
# MNIST (index 1) and FashionMNIST (index 2).
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760850060.273683  116446 service.cc:158] XLA service 0x424dbbd0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760850060.273753  116446 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760850060.274729  116446 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760850060.274796  116446 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760850060.274863  116446 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760850060.285008  116446 cuda_dnn.cc:463] Loaded cuDNN version 91200
[  1/ 50]	       MNIST	Time 39.95069s	Training Accuracy: 35.06%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04406s	Training Accuracy: 31.93%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.03920s	Training Accuracy: 37.50%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.03601s	Training Accuracy: 46.58%	Test Accuracy: 43.75%
[  3/ 50]	       MNIST	Time 0.03824s	Training Accuracy: 41.31%	Test Accuracy: 43.75%
[  3/ 50]	FashionMNIST	Time 0.03510s	Training Accuracy: 53.03%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.03051s	Training Accuracy: 54.20%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.06470s	Training Accuracy: 61.62%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.02971s	Training Accuracy: 57.13%	Test Accuracy: 37.50%
[  5/ 50]	FashionMNIST	Time 0.03132s	Training Accuracy: 65.43%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.03102s	Training Accuracy: 63.48%	Test Accuracy: 43.75%
[  6/ 50]	FashionMNIST	Time 0.04012s	Training Accuracy: 71.58%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.03097s	Training Accuracy: 71.88%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.04005s	Training Accuracy: 76.27%	Test Accuracy: 59.38%
[  8/ 50]	       MNIST	Time 0.03079s	Training Accuracy: 76.07%	Test Accuracy: 43.75%
[  8/ 50]	FashionMNIST	Time 0.03124s	Training Accuracy: 80.47%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.03996s	Training Accuracy: 79.30%	Test Accuracy: 50.00%
[  9/ 50]	FashionMNIST	Time 0.03093s	Training Accuracy: 83.20%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.03953s	Training Accuracy: 83.59%	Test Accuracy: 46.88%
[ 10/ 50]	FashionMNIST	Time 0.03084s	Training Accuracy: 86.91%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.03095s	Training Accuracy: 86.23%	Test Accuracy: 43.75%
[ 11/ 50]	FashionMNIST	Time 0.04041s	Training Accuracy: 89.06%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.03080s	Training Accuracy: 89.36%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.03970s	Training Accuracy: 89.65%	Test Accuracy: 62.50%
[ 13/ 50]	       MNIST	Time 0.03094s	Training Accuracy: 92.09%	Test Accuracy: 53.12%
[ 13/ 50]	FashionMNIST	Time 0.03138s	Training Accuracy: 91.41%	Test Accuracy: 59.38%
[ 14/ 50]	       MNIST	Time 0.04120s	Training Accuracy: 95.12%	Test Accuracy: 56.25%
[ 14/ 50]	FashionMNIST	Time 0.03188s	Training Accuracy: 94.14%	Test Accuracy: 62.50%
[ 15/ 50]	       MNIST	Time 0.04004s	Training Accuracy: 96.39%	Test Accuracy: 59.38%
[ 15/ 50]	FashionMNIST	Time 0.03147s	Training Accuracy: 94.34%	Test Accuracy: 62.50%
[ 16/ 50]	       MNIST	Time 0.03112s	Training Accuracy: 97.46%	Test Accuracy: 62.50%
[ 16/ 50]	FashionMNIST	Time 0.03101s	Training Accuracy: 95.31%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03126s	Training Accuracy: 98.93%	Test Accuracy: 62.50%
[ 17/ 50]	FashionMNIST	Time 0.04034s	Training Accuracy: 96.58%	Test Accuracy: 65.62%
[ 18/ 50]	       MNIST	Time 0.03110s	Training Accuracy: 99.61%	Test Accuracy: 62.50%
[ 18/ 50]	FashionMNIST	Time 0.03131s	Training Accuracy: 97.36%	Test Accuracy: 65.62%
[ 19/ 50]	       MNIST	Time 0.03121s	Training Accuracy: 99.41%	Test Accuracy: 62.50%
[ 19/ 50]	FashionMNIST	Time 0.03079s	Training Accuracy: 98.24%	Test Accuracy: 75.00%
[ 20/ 50]	       MNIST	Time 0.03977s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.03098s	Training Accuracy: 99.02%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.02982s	Training Accuracy: 99.80%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.03070s	Training Accuracy: 99.12%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.03060s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 22/ 50]	FashionMNIST	Time 0.03983s	Training Accuracy: 99.61%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.03058s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 23/ 50]	FashionMNIST	Time 0.03084s	Training Accuracy: 99.80%	Test Accuracy: 75.00%
[ 24/ 50]	       MNIST	Time 0.03065s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.03107s	Training Accuracy: 99.80%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.03961s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.03050s	Training Accuracy: 99.90%	Test Accuracy: 75.00%
[ 26/ 50]	       MNIST	Time 0.03066s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 26/ 50]	FashionMNIST	Time 0.03064s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 27/ 50]	       MNIST	Time 0.03653s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.03873s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 28/ 50]	       MNIST	Time 0.03125s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.03958s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.03118s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 29/ 50]	FashionMNIST	Time 0.03122s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 30/ 50]	       MNIST	Time 0.04025s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 30/ 50]	FashionMNIST	Time 0.03098s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 31/ 50]	       MNIST	Time 0.03964s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 31/ 50]	FashionMNIST	Time 0.03121s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 32/ 50]	       MNIST	Time 0.03125s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.04117s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 33/ 50]	       MNIST	Time 0.03108s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 33/ 50]	FashionMNIST	Time 0.04014s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 34/ 50]	       MNIST	Time 0.03126s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.03101s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 35/ 50]	       MNIST	Time 0.04072s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	FashionMNIST	Time 0.03172s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 36/ 50]	       MNIST	Time 0.04191s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.03122s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 37/ 50]	       MNIST	Time 0.03088s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	FashionMNIST	Time 0.04083s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 38/ 50]	       MNIST	Time 0.03215s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.03974s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 39/ 50]	       MNIST	Time 0.03159s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	FashionMNIST	Time 0.03140s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.03117s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	FashionMNIST	Time 0.03145s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 41/ 50]	       MNIST	Time 0.04137s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.03099s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 42/ 50]	       MNIST	Time 0.03069s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	FashionMNIST	Time 0.03092s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 43/ 50]	       MNIST	Time 0.03042s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 43/ 50]	FashionMNIST	Time 0.04024s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 44/ 50]	       MNIST	Time 0.03076s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	FashionMNIST	Time 0.04017s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.03076s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	FashionMNIST	Time 0.03070s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.03103s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	FashionMNIST	Time 0.03015s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.03944s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	FashionMNIST	Time 0.03054s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 48/ 50]	       MNIST	Time 0.03893s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	FashionMNIST	Time 0.03086s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 49/ 50]	       MNIST	Time 0.03066s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 49/ 50]	FashionMNIST	Time 0.03177s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 50/ 50]	       MNIST	Time 0.03105s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.04099s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
using InteractiveUtils
# Report the Julia environment, plus GPU backend details when MLDataDevices, the backend
# package and a functional device of that kind are all available.
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.