Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

The first `n_train` samples of the `:train` split and the first `n_eval` samples of the
`:test` split are used; passing `nothing` for either count selects the whole split.
Features are reshaped to `28 × 28 × 1 × N` and targets are one-hot encoded over the
labels `0:9`.

The batch size of each loader is capped at the number of samples it holds. Only the
training loader shuffles, and both loaders are built with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Read one split, keeping the first `n` samples (all of them when `n === nothing`).
    function read_split(split::Symbol, n)
        data = dset(split)
        last_idx = n === nothing ? length(data) : n
        (; features, targets) = data[1:last_idx]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Wrap `(x, y)` in a loader whose batch size never exceeds the sample count.
    function make_loader(x, y; shuffle::Bool)
        return DataLoader(
            (x, y);
            batchsize=min(batchsize, size(x, 4)),
            shuffle=shuffle,
            partial=false,
        )
    end

    x_train, y_train = read_split(:train, n_train)
    x_test, y_test = read_split(:test, n_eval)

    return (
        make_loader(x_train, y_train; shuffle=true),
        make_loader(x_test, y_test; shuffle=false),
    )
end

"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST via `load_dataset` and return one result per dataset, in
that order.

When the `CI` environment variable parses to `true`, only 1024 training and 32 evaluation
samples are requested per dataset; otherwise `nothing` is passed for both counts.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a compact layer that takes a tuple `(x, y)`: the output of `weight_generator(x)` is
flattened and wrapped as a `ComponentArray` with the parameter layout of `core_network`,
and `core_network` is then called on `y` with those parameters.

The layout (`ca_axes`) is taken once, at construction time, from a `ComponentArray` of
freshly initialized `core_network` parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes are kept; the initialized values themselves are not used afterwards.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Flatten the generator output and give it the core network's parameter layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end

Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips the initialization of the core_network parameters.

julia
"""
    Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})

Initialize the parameters of a `HyperNet` layer. The returned `NamedTuple` has a single
`weight_generator` entry, so no parameters are created for `core_network`.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    return (; weight_generator=Lux.initialparameters(rng, generator))
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Assemble the `HyperNet` used in this tutorial.

The core network is a small convolutional classifier: three 3×3 `Conv` layers with stride
2, global mean pooling, flattening, and a `Dense(64, 10)` head. The weight generator is an
`Embedding(2 => 32)` followed by two `Dense` layers, the last of which outputs
`Lux.parameterlength(core_network)` values.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class equals the target
class, where both are decoded with `onecold`.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
both the model output and the targets are moved to the CPU before decoding.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        # Only the first element of the model's return value is needed here.
        output = first(model((data_idx, x), ps, eval_st))
        labels = onecold(to_cpu(y))
        guesses = onecold(to_cpu(output))
        n_correct += count(labels .== guesses)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end

Training

julia
# Accuracy of `model` on `dataloader` as a percentage rounded to two decimals, evaluated
# with the parameters and states currently stored in `train_state`.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

# Display name for the dataset selected by `data_idx` (1 is MNIST, otherwise FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train(; nepochs=50)

Train the HyperNet on MNIST and FashionMNIST, alternating between the two datasets within
every epoch, and return a 2-element vector with the final test accuracies in percent
(`[MNIST, FashionMNIST]`).

After each per-dataset pass, the elapsed time and the training/test accuracies are
printed; a `[FINAL]` summary line per dataset is printed once training is done.

# Keywords
- `nepochs`: number of epochs to train for (default `50`).
"""
function train(; nepochs::Integer=50)
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference call once, using the first batch of the first training
    # loader as the example input. These names are distinct from the loop variables
    # below so nothing is shadowed.
    example_x = first(first(dataloaders[1][1]))
    example_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((example_idx, example_x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the training loop; `train` returns the final test accuracies (in percent).
test_acc_list = train()
2025-08-02 22:52:03.754867: I external/xla/xla/service/service.cc:163] XLA service 0x29c5f0a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-02 22:52:03.754917: I external/xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754175123.755912  541672 se_gpu_pjrt_client.cc:1380] Using BFC allocator.
I0000 00:00:1754175123.755999  541672 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754175123.756036  541672 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-02 22:52:03.784723: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
[  1/ 50]	       MNIST	Time 48.35461s	Training Accuracy: 35.16%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.05158s	Training Accuracy: 32.91%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.04933s	Training Accuracy: 35.25%	Test Accuracy: 37.50%
[  2/ 50]	FashionMNIST	Time 0.30378s	Training Accuracy: 46.48%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.04965s	Training Accuracy: 42.48%	Test Accuracy: 31.25%
[  3/ 50]	FashionMNIST	Time 0.04336s	Training Accuracy: 53.81%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.18753s	Training Accuracy: 52.15%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.03714s	Training Accuracy: 61.82%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.03955s	Training Accuracy: 57.71%	Test Accuracy: 43.75%
[  5/ 50]	FashionMNIST	Time 0.04418s	Training Accuracy: 67.68%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.04251s	Training Accuracy: 63.28%	Test Accuracy: 50.00%
[  6/ 50]	FashionMNIST	Time 0.04186s	Training Accuracy: 74.71%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.03328s	Training Accuracy: 72.36%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.03312s	Training Accuracy: 74.80%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.03298s	Training Accuracy: 76.17%	Test Accuracy: 50.00%
[  8/ 50]	FashionMNIST	Time 0.05678s	Training Accuracy: 80.66%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.03372s	Training Accuracy: 80.18%	Test Accuracy: 56.25%
[  9/ 50]	FashionMNIST	Time 0.03201s	Training Accuracy: 83.01%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.04281s	Training Accuracy: 85.45%	Test Accuracy: 59.38%
[ 10/ 50]	FashionMNIST	Time 0.03385s	Training Accuracy: 86.62%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.04379s	Training Accuracy: 86.33%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.03565s	Training Accuracy: 87.89%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.03655s	Training Accuracy: 91.02%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.03838s	Training Accuracy: 90.04%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.03864s	Training Accuracy: 93.55%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.04656s	Training Accuracy: 91.80%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.03584s	Training Accuracy: 95.31%	Test Accuracy: 59.38%
[ 14/ 50]	FashionMNIST	Time 0.04604s	Training Accuracy: 93.75%	Test Accuracy: 62.50%
[ 15/ 50]	       MNIST	Time 0.03807s	Training Accuracy: 96.29%	Test Accuracy: 59.38%
[ 15/ 50]	FashionMNIST	Time 0.04020s	Training Accuracy: 94.92%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03811s	Training Accuracy: 98.63%	Test Accuracy: 65.62%
[ 16/ 50]	FashionMNIST	Time 0.03614s	Training Accuracy: 96.97%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.03897s	Training Accuracy: 99.51%	Test Accuracy: 65.62%
[ 17/ 50]	FashionMNIST	Time 0.03533s	Training Accuracy: 97.85%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.04543s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 18/ 50]	FashionMNIST	Time 0.03541s	Training Accuracy: 97.27%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03597s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 19/ 50]	FashionMNIST	Time 0.04515s	Training Accuracy: 98.93%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.03635s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 20/ 50]	FashionMNIST	Time 0.04648s	Training Accuracy: 99.41%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.03533s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 21/ 50]	FashionMNIST	Time 0.03712s	Training Accuracy: 99.51%	Test Accuracy: 71.88%
[ 22/ 50]	       MNIST	Time 0.03477s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 22/ 50]	FashionMNIST	Time 0.03834s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.04557s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 23/ 50]	FashionMNIST	Time 0.03901s	Training Accuracy: 99.80%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.04464s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 24/ 50]	FashionMNIST	Time 0.03433s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.03397s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.03400s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.03577s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 26/ 50]	FashionMNIST	Time 0.04462s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	       MNIST	Time 0.03558s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.03634s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	       MNIST	Time 0.03557s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.03785s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.04760s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 29/ 50]	FashionMNIST	Time 0.03795s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.04645s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 30/ 50]	FashionMNIST	Time 0.03590s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	       MNIST	Time 0.03601s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 31/ 50]	FashionMNIST	Time 0.03664s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	       MNIST	Time 0.03673s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 32/ 50]	FashionMNIST	Time 0.04720s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.03415s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 33/ 50]	FashionMNIST	Time 0.04319s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	       MNIST	Time 0.03745s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 34/ 50]	FashionMNIST	Time 0.03704s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.04601s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 35/ 50]	FashionMNIST	Time 0.04599s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 36/ 50]	       MNIST	Time 0.05098s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 36/ 50]	FashionMNIST	Time 0.07133s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.03926s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 37/ 50]	FashionMNIST	Time 0.03870s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	       MNIST	Time 0.03492s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 38/ 50]	FashionMNIST	Time 0.04366s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.03533s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 39/ 50]	FashionMNIST	Time 0.04341s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	       MNIST	Time 0.03603s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 40/ 50]	FashionMNIST	Time 0.03649s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	       MNIST	Time 0.03704s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 41/ 50]	FashionMNIST	Time 0.03729s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.04816s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 42/ 50]	FashionMNIST	Time 0.03947s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.04503s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 43/ 50]	FashionMNIST	Time 0.03645s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.03611s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	FashionMNIST	Time 0.04499s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	       MNIST	Time 0.03547s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 45/ 50]	FashionMNIST	Time 0.04758s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.03338s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 46/ 50]	FashionMNIST	Time 0.03297s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	       MNIST	Time 0.03528s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	FashionMNIST	Time 0.03517s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.04368s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	FashionMNIST	Time 0.03506s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.04569s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.03487s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.03529s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	FashionMNIST	Time 0.03437s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
# Print Julia version and platform information.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are reported only when MLDataDevices is defined in this session.
if @isdefined(MLDataDevices)
    # CUDA: requires the CUDA binding to exist and the device to report as functional.
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    # AMDGPU: same check for the AMD backend.
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.