Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the dataset type `dset` and wrap each one in a
`DataLoader`.

`n_train` / `n_eval` take that many samples from the start of the corresponding split;
`nothing` means "use the whole split". Images are reshaped to `28×28×1×N` and targets are
one-hot encoded over the classes `0:9`. Each loader uses
`batchsize = min(batchsize, number of samples)` and drops a trailing partial batch; only
the training loader shuffles.

Returns `(train_loader, test_loader)`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end

# Take the first `n` samples of `split` from dataset type `D` (every sample when
# `n === nothing`) and return `(images reshaped to 28×28×1×N, one-hot labels over 0:9)`.
function _load_split(::Type{D}, split::Symbol, n::Union{Nothing,Int}) where {D}
    data = D(split)
    upto = n === nothing ? length(data) : n
    (; features, targets) = data[1:upto]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end

"""
    load_datasets(batchsize=32)

Return `(train_loader, test_loader)` pairs for MNIST and FashionMNIST, in that order.

If the `CI` environment variable parses as `true`, only the first 1024 training and 32
test samples of each dataset are used; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    running_on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer. When called on `(x, y)`, it runs `weight_generator` on `x`,
views the flattened output as a `ComponentArray` with the parameter layout of
`core_network`, and evaluates `core_network` on `y` with those generated parameters.

The `ComponentArray` axes are computed once, at construction time, from a throwaway
initialization of `core_network`'s parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the parameter *layout* (axes) is needed here, so initialize with a private,
    # fixed-seed RNG rather than `Random.default_rng()`: building the layer then neither
    # advances nor depends on the global RNG state.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.Xoshiro(0), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights of `core_network` from `x`, then reinterpret the flat
        # vector with the core network's parameter axes.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining methods on a `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. Here, we override `initialparameters` so that it initializes only the `weight_generator` parameters and skips the `core_network` parameters, because those are generated at call time by the weight generator.

julia
# Parameters of a HyperNet consist only of the weight generator's parameters; the core
# network's parameters are produced at call time, so they are not initialized here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the HyperNet used in this tutorial.

The core network is a small 3-layer strided CNN classifier
(`1→16→32→64` channels, global mean pooling, `Dense(64, 10)`). The weight generator maps
a dataset index through `Embedding(2 => 32)` and two `Dense` layers to a vector with one
entry per core-network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's output width must equal the number of core-network parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class (argmax of the
model output) matches the one-hot target. The model is called as
`model((data_idx, x), ps, testmode(st))`, and labels and predictions are moved to the CPU
before comparison.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_state = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (xb, yb) in dataloader
        labels = onecold(to_cpu(yb))
        logits, _ = model((data_idx, xb), ps, eval_state)
        predictions = onecold(to_cpu(logits))
        n_correct += count(labels .== predictions)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end

Training

julia
# Display name for dataset index `idx` (1 => MNIST, otherwise FashionMNIST).
_dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

# Train and test accuracy (percent, rounded to 2 digits) of the current parameters and
# states in `train_state`, computed with the compiled forward pass `model_compiled` for
# the dataset selected by `concrete_data_idx`. Train accuracy is computed first.
function _train_test_accuracy(
    model_compiled, train_state, train_dataloader, test_dataloader, concrete_data_idx
)
    percent(loader) = round(
        accuracy(
            model_compiled,
            train_state.parameters,
            train_state.states,
            loader,
            concrete_data_idx,
        ) * 100;
        digits=2,
    )
    return percent(train_dataloader), percent(test_dataloader)
end

"""
    train(; nepochs::Int=50)

Train the HyperNet from `create_model` on MNIST and FashionMNIST on a Reactant device.
Each epoch visits MNIST and then FashionMNIST, taking one
`Training.single_train_step!` per batch (Enzyme AD, logit cross-entropy, Adam with
learning rate `3f-4`). After every dataset pass it prints the train/test accuracy, and at
the end it prints a final summary.

Returns a 2-element `Vector{Float64}` with the final test accuracies (percent, rounded to
2 digits) for MNIST and FashionMNIST, in that order.
"""
function train(; nepochs::Int=50)
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference pass once, using a sample batch and a traced dataset index.
    # These get distinct names so they are not confused with the loop variables below.
    x_sample = first(first(dataloaders[1][1]))
    idx_sample = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((idx_sample, x_sample), ps, Lux.testmode(st))
    end

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc, test_acc = _train_test_accuracy(
            model_compiled, train_state, train_dataloader, test_dataloader,
            concrete_data_idx,
        )
        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc, test_acc = _train_test_accuracy(
            model_compiled, train_state, train_dataloader, test_dataloader,
            concrete_data_idx,
        )
        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training loop. `test_acc_list` holds the final test accuracies (percent)
# for MNIST and FashionMNIST, in that order.
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760327431.499524 1442466 service.cc:158] XLA service 0x36311e50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760327431.499603 1442466 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760327431.500605 1442466 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760327431.500644 1442466 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760327431.500684 1442466 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760327431.511429 1442466 cuda_dnn.cc:463] Loaded cuDNN version 91200
[  1/ 50]	       MNIST	Time 44.16144s	Training Accuracy: 35.64%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04697s	Training Accuracy: 31.74%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.04045s	Training Accuracy: 35.64%	Test Accuracy: 37.50%
[  2/ 50]	FashionMNIST	Time 0.04474s	Training Accuracy: 45.21%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.04061s	Training Accuracy: 40.14%	Test Accuracy: 43.75%
[  3/ 50]	FashionMNIST	Time 0.03794s	Training Accuracy: 53.91%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.03827s	Training Accuracy: 47.95%	Test Accuracy: 37.50%
[  4/ 50]	FashionMNIST	Time 0.03606s	Training Accuracy: 60.64%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.04552s	Training Accuracy: 55.57%	Test Accuracy: 46.88%
[  5/ 50]	FashionMNIST	Time 0.04194s	Training Accuracy: 65.92%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.03920s	Training Accuracy: 60.74%	Test Accuracy: 50.00%
[  6/ 50]	FashionMNIST	Time 0.03929s	Training Accuracy: 72.27%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.03858s	Training Accuracy: 67.19%	Test Accuracy: 46.88%
[  7/ 50]	FashionMNIST	Time 0.05045s	Training Accuracy: 77.34%	Test Accuracy: 56.25%
[  8/ 50]	       MNIST	Time 0.03932s	Training Accuracy: 72.46%	Test Accuracy: 50.00%
[  8/ 50]	FashionMNIST	Time 0.04704s	Training Accuracy: 81.25%	Test Accuracy: 56.25%
[  9/ 50]	       MNIST	Time 0.03749s	Training Accuracy: 77.73%	Test Accuracy: 46.88%
[  9/ 50]	FashionMNIST	Time 0.03637s	Training Accuracy: 83.69%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.03652s	Training Accuracy: 83.59%	Test Accuracy: 46.88%
[ 10/ 50]	FashionMNIST	Time 0.03730s	Training Accuracy: 87.50%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.04709s	Training Accuracy: 88.38%	Test Accuracy: 50.00%
[ 11/ 50]	FashionMNIST	Time 0.03982s	Training Accuracy: 89.16%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.04864s	Training Accuracy: 90.04%	Test Accuracy: 53.12%
[ 12/ 50]	FashionMNIST	Time 0.03661s	Training Accuracy: 91.02%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.03445s	Training Accuracy: 91.60%	Test Accuracy: 46.88%
[ 13/ 50]	FashionMNIST	Time 0.03442s	Training Accuracy: 93.75%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.03414s	Training Accuracy: 94.63%	Test Accuracy: 53.12%
[ 14/ 50]	FashionMNIST	Time 0.04496s	Training Accuracy: 94.53%	Test Accuracy: 62.50%
[ 15/ 50]	       MNIST	Time 0.03390s	Training Accuracy: 95.31%	Test Accuracy: 56.25%
[ 15/ 50]	FashionMNIST	Time 0.04374s	Training Accuracy: 95.80%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03302s	Training Accuracy: 97.56%	Test Accuracy: 53.12%
[ 16/ 50]	FashionMNIST	Time 0.03282s	Training Accuracy: 95.80%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03355s	Training Accuracy: 98.93%	Test Accuracy: 50.00%
[ 17/ 50]	FashionMNIST	Time 0.03405s	Training Accuracy: 96.68%	Test Accuracy: 71.88%
[ 18/ 50]	       MNIST	Time 0.04450s	Training Accuracy: 99.41%	Test Accuracy: 53.12%
[ 18/ 50]	FashionMNIST	Time 0.03373s	Training Accuracy: 97.36%	Test Accuracy: 71.88%
[ 19/ 50]	       MNIST	Time 0.04407s	Training Accuracy: 99.51%	Test Accuracy: 53.12%
[ 19/ 50]	FashionMNIST	Time 0.03423s	Training Accuracy: 98.54%	Test Accuracy: 71.88%
[ 20/ 50]	       MNIST	Time 0.03257s	Training Accuracy: 99.80%	Test Accuracy: 50.00%
[ 20/ 50]	FashionMNIST	Time 0.03313s	Training Accuracy: 99.22%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.03418s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 21/ 50]	FashionMNIST	Time 0.04374s	Training Accuracy: 99.61%	Test Accuracy: 75.00%
[ 22/ 50]	       MNIST	Time 0.03329s	Training Accuracy: 99.90%	Test Accuracy: 50.00%
[ 22/ 50]	FashionMNIST	Time 0.04691s	Training Accuracy: 99.22%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.03269s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 23/ 50]	FashionMNIST	Time 0.03413s	Training Accuracy: 99.41%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.03324s	Training Accuracy: 99.90%	Test Accuracy: 50.00%
[ 24/ 50]	FashionMNIST	Time 0.03345s	Training Accuracy: 99.80%	Test Accuracy: 75.00%
[ 25/ 50]	       MNIST	Time 0.04341s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 25/ 50]	FashionMNIST	Time 0.03367s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 26/ 50]	       MNIST	Time 0.04409s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 26/ 50]	FashionMNIST	Time 0.03411s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 27/ 50]	       MNIST	Time 0.03305s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 27/ 50]	FashionMNIST	Time 0.03389s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 28/ 50]	       MNIST	Time 0.03307s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 28/ 50]	FashionMNIST	Time 0.04396s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 29/ 50]	       MNIST	Time 0.03289s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 29/ 50]	FashionMNIST	Time 0.04458s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 30/ 50]	       MNIST	Time 0.03305s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 30/ 50]	FashionMNIST	Time 0.03334s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.03341s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 31/ 50]	FashionMNIST	Time 0.03396s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 32/ 50]	       MNIST	Time 0.04374s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 32/ 50]	FashionMNIST	Time 0.03353s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 33/ 50]	       MNIST	Time 0.04334s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 33/ 50]	FashionMNIST	Time 0.03371s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.03367s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 34/ 50]	FashionMNIST	Time 0.03382s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 35/ 50]	       MNIST	Time 0.03343s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 35/ 50]	FashionMNIST	Time 0.04723s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 36/ 50]	       MNIST	Time 0.03216s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 36/ 50]	FashionMNIST	Time 0.04361s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.03154s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 37/ 50]	FashionMNIST	Time 0.03306s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.03213s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 38/ 50]	FashionMNIST	Time 0.03335s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 39/ 50]	       MNIST	Time 0.04400s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 39/ 50]	FashionMNIST	Time 0.03626s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.04176s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 40/ 50]	FashionMNIST	Time 0.03400s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.03407s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 41/ 50]	FashionMNIST	Time 0.03420s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.03370s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 42/ 50]	FashionMNIST	Time 0.04494s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 43/ 50]	       MNIST	Time 0.03360s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 43/ 50]	FashionMNIST	Time 0.04574s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.03407s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 44/ 50]	FashionMNIST	Time 0.03244s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.03424s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 45/ 50]	FashionMNIST	Time 0.03535s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.04443s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 46/ 50]	FashionMNIST	Time 0.03421s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.04485s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 47/ 50]	FashionMNIST	Time 0.03423s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 48/ 50]	       MNIST	Time 0.03363s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 48/ 50]	FashionMNIST	Time 0.03394s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 49/ 50]	       MNIST	Time 0.03419s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 49/ 50]	FashionMNIST	Time 0.03465s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 50/ 50]	       MNIST	Time 0.03257s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 50/ 50]	FashionMNIST	Time 0.04516s	Training Accuracy: 100.00%	Test Accuracy: 71.88%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 71.88%

Appendix

julia
# Report the Julia build, then the GPU backend's version information when MLDataDevices
# and that backend package are loaded and the backend is functional.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.