Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling ComponentArraysReactantExt...
  17777.8 ms  ✓ ComponentArrays → ComponentArraysReactantExt
  1 dependency successfully precompiled in 18 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
  17356.2 ms  ✓ Reactant → ReactantOneHotArraysExt
  1 dependency successfully precompiled in 18 seconds. 104 already precompiled.

Loading Datasets

julia
# Load one split (`:train` or `:test`) of `dset`, keeping only the first `n` samples when
# `n` is an `Int` (`nothing` keeps the whole split). Returns `(images, onehot_labels)` with
# images reshaped to 28×28×1×N and labels one-hot encoded over the classes 0:9.
function _load_split(dset, split::Symbol, n::Union{Nothing,Int})
    raw = dset(split)
    idxs = n === nothing ? (1:length(raw)) : (1:n)
    (; features, targets) = raw[idxs]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end

"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the dataset type `dset` and wrap each in a
`DataLoader`.

`n_train` / `n_eval` limit each split to its first samples (`nothing` = use everything).
The batch size is clamped to the number of available samples. The training loader shuffles
and the test loader does not; neither yields a trailing partial batch.

Returns `(train_loader, test_loader)`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    train_x, train_y = _load_split(dset, :train, n_train)
    test_x, test_y = _load_split(dset, :test, n_eval)

    train_loader = DataLoader(
        (train_x, train_y);
        batchsize=min(batchsize, size(train_x, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (test_x, test_y);
        batchsize=min(batchsize, size(test_x, 4)),
        shuffle=false,
        partial=false,
    )
    return train_loader, test_loader
end

"""
    load_datasets(batchsize=32)

Return the `(train, test)` loader pairs for MNIST and FashionMNIST, in that order.

When the `CI` environment variable parses as `true`, each dataset is limited to its first
1024 training and 32 test samples so the tutorial runs quickly.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a `@compact` layer (dispatch tag `:HyperNet`) whose forward pass takes a tuple
`(x, y)`.

`weight_generator(x)` is flattened and reinterpreted, via the recorded `ComponentArray`
axes, as a full parameter set for `core_network`. `core_network` is then applied to `y`
with those generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the layout (axes) of the core network's parameters is kept; the actual values
    # are regenerated from `x` on every call.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we override initialparameters so that it skips initializing the core_network parameters, since those are generated by the weight_generator on every forward pass.

julia
# A HyperNet's only trainable parameters belong to its weight generator. The core network's
# weights are produced by the generator at call time, so they are intentionally omitted here.
Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet}) =
    (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the tutorial's HyperNet.

The core network is a three-layer strided CNN followed by global mean pooling and a
`Dense(64, 10)` head. Its parameters are generated by an MLP that starts from a 2-entry
embedding, one entry per dataset index.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's final layer must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose arg-max prediction equals the arg-max
of the one-hot target.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode. Both
targets and predictions are copied to the CPU before they are compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    n_correct, n_seen = 0, 0
    to_cpu = cpu_device()
    st_test = Lux.testmode(st)
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        logits, _ = model((data_idx, x), ps, st_test)
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
"""
    _rounded_accuracies(model_compiled, train_state, train_dataloader, test_dataloader,
                        data_idx)

Evaluate `model_compiled` with the current parameters and states held in `train_state`,
first on `train_dataloader` and then on `test_dataloader`.

Returns `(train_acc, test_acc)` as percentages rounded to 2 decimal digits.
"""
function _rounded_accuracies(
    model_compiled, train_state, train_dataloader, test_dataloader, data_idx
)
    pct(dataloader) = round(
        accuracy(
            model_compiled, train_state.parameters, train_state.states, dataloader, data_idx
        ) * 100;
        digits=2,
    )
    # Tuple elements are evaluated left to right, so train is evaluated before test.
    return pct(train_dataloader), pct(test_dataloader)
end

"""
    train()

Jointly train the HyperNet on MNIST (dataset index 1) and FashionMNIST (dataset index 2)
using Reactant, alternating between the two datasets within every epoch.

Prints the per-epoch training time and accuracies, then a final summary. Returns the final
test accuracies, in percent, as `[mnist_test_acc, fashionmnist_test_acc]`.
"""
function train()
    dev = reactant_device(; force=true)
    dataset_names = ("MNIST", "FashionMNIST")

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, from a sample batch and a traced dataset index.
    # The same compiled function is reused for both datasets.
    sample_x = first(first(dataloaders[1][1]))
    sample_data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_data_idx, sample_x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        # NOTE(review): `dataloaders` was already passed through `dev` above; this second
        # application looks redundant — confirm before removing.
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc, test_acc = _rounded_accuracies(
            model_compiled, train_state, train_dataloader, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc, test_acc = _rounded_accuracies(
            model_compiled, train_state, train_dataloader, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training loop; `train` returns the final per-dataset test accuracies
# (percentages rounded to 2 digits), MNIST first and FashionMNIST second.
test_acc_list = train()
2025-07-09 04:19:37.074460: I external/xla/xla/service/service.cc:153] XLA service 0x37d40e50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-09 04:19:37.074504: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752034777.075523 1223409 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752034777.075620 1223409 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752034777.075682 1223409 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752034777.090291 1223409 cuda_dnn.cc:471] Loaded cuDNN version 90800
[  1/ 50]	       MNIST	Time 57.33054s	Training Accuracy: 35.25%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04579s	Training Accuracy: 32.81%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.10136s	Training Accuracy: 36.72%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.04550s	Training Accuracy: 48.24%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.04622s	Training Accuracy: 40.72%	Test Accuracy: 37.50%
[  3/ 50]	FashionMNIST	Time 0.07281s	Training Accuracy: 54.10%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.03742s	Training Accuracy: 52.93%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.04111s	Training Accuracy: 61.82%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.07351s	Training Accuracy: 55.37%	Test Accuracy: 40.62%
[  5/ 50]	FashionMNIST	Time 0.04199s	Training Accuracy: 66.41%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.04080s	Training Accuracy: 61.52%	Test Accuracy: 37.50%
[  6/ 50]	FashionMNIST	Time 0.03549s	Training Accuracy: 71.29%	Test Accuracy: 62.50%
[  7/ 50]	       MNIST	Time 0.04441s	Training Accuracy: 67.97%	Test Accuracy: 40.62%
[  7/ 50]	FashionMNIST	Time 0.03976s	Training Accuracy: 75.59%	Test Accuracy: 65.62%
[  8/ 50]	       MNIST	Time 0.05213s	Training Accuracy: 76.66%	Test Accuracy: 46.88%
[  8/ 50]	FashionMNIST	Time 0.03696s	Training Accuracy: 79.69%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.04874s	Training Accuracy: 80.27%	Test Accuracy: 53.12%
[  9/ 50]	FashionMNIST	Time 0.03583s	Training Accuracy: 81.54%	Test Accuracy: 65.62%
[ 10/ 50]	       MNIST	Time 0.04598s	Training Accuracy: 83.89%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.03729s	Training Accuracy: 86.04%	Test Accuracy: 68.75%
[ 11/ 50]	       MNIST	Time 0.05240s	Training Accuracy: 87.40%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.03900s	Training Accuracy: 88.28%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.04762s	Training Accuracy: 89.75%	Test Accuracy: 53.12%
[ 12/ 50]	FashionMNIST	Time 0.03764s	Training Accuracy: 89.36%	Test Accuracy: 71.88%
[ 13/ 50]	       MNIST	Time 0.04787s	Training Accuracy: 92.29%	Test Accuracy: 59.38%
[ 13/ 50]	FashionMNIST	Time 0.03656s	Training Accuracy: 91.99%	Test Accuracy: 71.88%
[ 14/ 50]	       MNIST	Time 0.03615s	Training Accuracy: 95.02%	Test Accuracy: 56.25%
[ 14/ 50]	FashionMNIST	Time 0.03541s	Training Accuracy: 93.65%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.03463s	Training Accuracy: 96.29%	Test Accuracy: 56.25%
[ 15/ 50]	FashionMNIST	Time 0.03565s	Training Accuracy: 94.63%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03485s	Training Accuracy: 98.34%	Test Accuracy: 56.25%
[ 16/ 50]	FashionMNIST	Time 0.03470s	Training Accuracy: 96.09%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03618s	Training Accuracy: 99.22%	Test Accuracy: 56.25%
[ 17/ 50]	FashionMNIST	Time 0.03676s	Training Accuracy: 96.97%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.03503s	Training Accuracy: 99.71%	Test Accuracy: 62.50%
[ 18/ 50]	FashionMNIST	Time 0.03550s	Training Accuracy: 96.00%	Test Accuracy: 65.62%
[ 19/ 50]	       MNIST	Time 0.03957s	Training Accuracy: 99.61%	Test Accuracy: 56.25%
[ 19/ 50]	FashionMNIST	Time 0.03919s	Training Accuracy: 98.24%	Test Accuracy: 71.88%
[ 20/ 50]	       MNIST	Time 0.03769s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 20/ 50]	FashionMNIST	Time 0.04939s	Training Accuracy: 98.73%	Test Accuracy: 75.00%
[ 21/ 50]	       MNIST	Time 0.03725s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 21/ 50]	FashionMNIST	Time 0.04380s	Training Accuracy: 98.93%	Test Accuracy: 71.88%
[ 22/ 50]	       MNIST	Time 0.03601s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 22/ 50]	FashionMNIST	Time 0.04651s	Training Accuracy: 99.41%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.03458s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 23/ 50]	FashionMNIST	Time 0.04530s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.03364s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 24/ 50]	FashionMNIST	Time 0.04542s	Training Accuracy: 99.71%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.03633s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.04904s	Training Accuracy: 99.90%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.03446s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.04674s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 27/ 50]	       MNIST	Time 0.03522s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.05190s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.03707s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.03490s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 29/ 50]	       MNIST	Time 0.03558s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.03559s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 30/ 50]	       MNIST	Time 0.03769s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.03677s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.03631s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	FashionMNIST	Time 0.03688s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 32/ 50]	       MNIST	Time 0.03451s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	FashionMNIST	Time 0.03478s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 33/ 50]	       MNIST	Time 0.03598s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 33/ 50]	FashionMNIST	Time 0.03562s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.03684s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.03753s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 35/ 50]	       MNIST	Time 0.04911s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	FashionMNIST	Time 0.03721s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 36/ 50]	       MNIST	Time 0.04746s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	FashionMNIST	Time 0.03472s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.04412s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	FashionMNIST	Time 0.03388s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.04484s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 38/ 50]	FashionMNIST	Time 0.03401s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 39/ 50]	       MNIST	Time 0.04487s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	FashionMNIST	Time 0.03149s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.04462s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	FashionMNIST	Time 0.03384s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.04389s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.03421s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.03341s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.03376s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 43/ 50]	       MNIST	Time 0.03422s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.03456s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.03370s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 44/ 50]	FashionMNIST	Time 0.03409s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.03412s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.03500s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.03401s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.03981s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.03384s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	FashionMNIST	Time 0.03333s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 48/ 50]	       MNIST	Time 0.03302s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	FashionMNIST	Time 0.03311s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 49/ 50]	       MNIST	Time 0.03456s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.04666s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 50/ 50]	       MNIST	Time 0.03608s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.04766s	Training Accuracy: 100.00%	Test Accuracy: 71.88%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 71.88%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print GPU toolchain details only for backends that are loaded in this session and
# reported as functional by MLDataDevices.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.