Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
   1547.1 ms  ✓ Lux → LuxComponentArraysExt
  1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling ComponentArraysReactantExt...
  17145.5 ms  ✓ ComponentArrays → ComponentArraysReactantExt
  1 dependency successfully precompiled in 17 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
  17390.7 ms  ✓ Reactant → ReactantOneHotArraysExt
  1 dependency successfully precompiled in 18 seconds. 104 already precompiled.

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

`n_train` / `n_eval` limit how many leading samples of the `:train` / `:test` split are
used; passing `nothing` uses the whole split. Features are reshaped to `28 × 28 × 1 × N`
and targets are one-hot encoded over the labels `0:9`. The batch size is capped at the
number of selected samples, only the training loader shuffles, and both loaders are built
with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load one split and return it as (image array, one-hot label matrix).
    function split_arrays(split::Symbol, n)
        data = dset(split)
        n_used = n === nothing ? length(data) : n
        chunk = data[1:n_used]
        return reshape(chunk.features, 28, 28, 1, :), onehotbatch(chunk.targets, 0:9)
    end

    # Wrap an (x, y) pair in a DataLoader; the sample count is the 4th dim of `x`.
    function make_loader(xy, shuffle::Bool)
        return DataLoader(
            xy; batchsize=min(batchsize, size(first(xy), 4)), shuffle, partial=false
        )
    end

    train_xy = split_arrays(:train, n_train)
    test_xy = split_arrays(:test, n_eval)

    return (make_loader(train_xy, true), make_loader(test_xy, false))
end

"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST via `load_dataset` and return one result per dataset, in
that order.

When the environment variable `CI` parses as `true`, only 1024 training and 32 evaluation
samples are used per dataset; otherwise the complete splits are loaded.
"""
function load_datasets(batchsize=32)
    running_in_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_in_ci ? (1024, 32) : (nothing, nothing)
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a layer in which `weight_generator` produces the parameters used by `core_network`.

The resulting layer takes a tuple `(x, y)`: `x` is passed through `weight_generator`, the
output is flattened and given the axes of `core_network`'s parameters, and `core_network`
is then applied to `y` with those generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (layout) of the core network's parameters are kept; the values
    # initialised here are discarded.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate a flat weight vector and restructure it into the core network's layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
HyperNet (generic function with 1 method)

Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips the initialization of the core_network parameters.

julia
"""
    Lux.initialparameters(rng, hn::CompactLuxLayer{:HyperNet})

Initialise only the `weight_generator` parameters of a HyperNet layer.

No entry is created for `core_network`: the returned named tuple has the single field
`weight_generator`.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the HyperNet used in this tutorial.

The core network is a three-layer strided convolutional classifier with 10 outputs. The
weight generator embeds one of 2 indices into 32 dimensions and maps it through two dense
layers to a vector with `Lux.parameterlength(core_network)` entries.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's output size must equal the number of core-network parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

Each batch `(x, y)` is evaluated as `model((data_idx, x), ps, st)` with `st` switched to
test mode. Targets and predictions are passed through `cpu_device()` before being decoded
with `onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        truth = onecold(to_cpu(y))
        # The model call returns multiple values; the first one is the prediction.
        output = first(model((data_idx, x), ps, eval_st))
        guess = onecold(to_cpu(output))
        n_correct += sum(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
"""
    train()

Train one HyperNet on MNIST and FashionMNIST and report accuracies.

For 50 epochs the model is trained on dataset 1 (MNIST) and then dataset 2
(FashionMNIST), printing the training time and the training/test accuracy after each
pass. After training, the final accuracies are printed once more.

Returns a two-element `Vector{Float64}` with the final test accuracies in percent,
ordered `[MNIST, FashionMNIST]`.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference call once, using one training batch and a traced index as
    # example inputs.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    # Accuracy of the compiled model in percent, rounded to two digits. The train state
    # is passed explicitly because `train_state` is reassigned inside the training loop.
    function percent_accuracy(state, loader, idx)
        frac = accuracy(model_compiled, state.parameters, state.states, loader, idx)
        return round(frac * 100; digits=2)
    end

    dataset_names = ("MNIST", "FashionMNIST")
    loss_fn = CrossEntropyLoss(; logits=Val(true))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### Wrapping the index lets it be traced; otherwise it would be embedded as a
        ### constant in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            step_result = Training.single_train_step!(
                AutoEnzyme(),
                loss_fn,
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
            # The updated train state is the fourth returned value.
            train_state = step_result[4]
        end
        ttime = time() - stime

        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation pass over both datasets.
    test_acc_list = zeros(Float64, 2)
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the training loop and keep the final per-dataset test accuracies it returns.
test_acc_list = train()
2025-07-14 00:14:02.399863: I external/xla/xla/service/service.cc:153] XLA service 0x44f4cf00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-14 00:14:02.399906: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752452042.400753 2780017 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752452042.400826 2780017 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752452042.400899 2780017 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752452042.414801 2780017 cuda_dnn.cc:471] Loaded cuDNN version 90800
[  1/ 50]	       MNIST	Time 55.35623s	Training Accuracy: 34.77%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04084s	Training Accuracy: 32.42%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.04593s	Training Accuracy: 36.43%	Test Accuracy: 31.25%
[  2/ 50]	FashionMNIST	Time 0.10024s	Training Accuracy: 45.51%	Test Accuracy: 43.75%
[  3/ 50]	       MNIST	Time 0.03791s	Training Accuracy: 39.26%	Test Accuracy: 40.62%
[  3/ 50]	FashionMNIST	Time 0.03511s	Training Accuracy: 51.27%	Test Accuracy: 62.50%
[  4/ 50]	       MNIST	Time 0.05604s	Training Accuracy: 49.51%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.03575s	Training Accuracy: 60.25%	Test Accuracy: 53.12%
[  5/ 50]	       MNIST	Time 0.03147s	Training Accuracy: 56.35%	Test Accuracy: 37.50%
[  5/ 50]	FashionMNIST	Time 0.03641s	Training Accuracy: 65.43%	Test Accuracy: 53.12%
[  6/ 50]	       MNIST	Time 0.03436s	Training Accuracy: 61.82%	Test Accuracy: 46.88%
[  6/ 50]	FashionMNIST	Time 0.03308s	Training Accuracy: 69.24%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.04164s	Training Accuracy: 67.87%	Test Accuracy: 46.88%
[  7/ 50]	FashionMNIST	Time 0.03288s	Training Accuracy: 75.00%	Test Accuracy: 59.38%
[  8/ 50]	       MNIST	Time 0.04463s	Training Accuracy: 74.90%	Test Accuracy: 50.00%
[  8/ 50]	FashionMNIST	Time 0.03537s	Training Accuracy: 78.32%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.04308s	Training Accuracy: 78.52%	Test Accuracy: 53.12%
[  9/ 50]	FashionMNIST	Time 0.03216s	Training Accuracy: 79.49%	Test Accuracy: 65.62%
[ 10/ 50]	       MNIST	Time 0.04361s	Training Accuracy: 83.11%	Test Accuracy: 46.88%
[ 10/ 50]	FashionMNIST	Time 0.03323s	Training Accuracy: 84.77%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.04545s	Training Accuracy: 86.62%	Test Accuracy: 56.25%
[ 11/ 50]	FashionMNIST	Time 0.03466s	Training Accuracy: 86.82%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.03893s	Training Accuracy: 87.99%	Test Accuracy: 56.25%
[ 12/ 50]	FashionMNIST	Time 0.02956s	Training Accuracy: 89.06%	Test Accuracy: 68.75%
[ 13/ 50]	       MNIST	Time 0.03331s	Training Accuracy: 91.80%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.03299s	Training Accuracy: 92.29%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.03315s	Training Accuracy: 94.73%	Test Accuracy: 65.62%
[ 14/ 50]	FashionMNIST	Time 0.04797s	Training Accuracy: 94.04%	Test Accuracy: 68.75%
[ 15/ 50]	       MNIST	Time 0.03137s	Training Accuracy: 96.48%	Test Accuracy: 65.62%
[ 15/ 50]	FashionMNIST	Time 0.03372s	Training Accuracy: 95.02%	Test Accuracy: 71.88%
[ 16/ 50]	       MNIST	Time 0.03325s	Training Accuracy: 96.97%	Test Accuracy: 65.62%
[ 16/ 50]	FashionMNIST	Time 0.03351s	Training Accuracy: 95.61%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.03279s	Training Accuracy: 98.54%	Test Accuracy: 59.38%
[ 17/ 50]	FashionMNIST	Time 0.03353s	Training Accuracy: 96.58%	Test Accuracy: 65.62%
[ 18/ 50]	       MNIST	Time 0.03309s	Training Accuracy: 98.63%	Test Accuracy: 65.62%
[ 18/ 50]	FashionMNIST	Time 0.04281s	Training Accuracy: 97.56%	Test Accuracy: 71.88%
[ 19/ 50]	       MNIST	Time 0.03279s	Training Accuracy: 99.61%	Test Accuracy: 65.62%
[ 19/ 50]	FashionMNIST	Time 0.04275s	Training Accuracy: 98.63%	Test Accuracy: 65.62%
[ 20/ 50]	       MNIST	Time 0.03207s	Training Accuracy: 99.71%	Test Accuracy: 65.62%
[ 20/ 50]	FashionMNIST	Time 0.04248s	Training Accuracy: 98.73%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.03217s	Training Accuracy: 99.80%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.04091s	Training Accuracy: 98.73%	Test Accuracy: 65.62%
[ 22/ 50]	       MNIST	Time 0.03037s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 22/ 50]	FashionMNIST	Time 0.03914s	Training Accuracy: 99.22%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.03207s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 23/ 50]	FashionMNIST	Time 0.04149s	Training Accuracy: 99.61%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.03298s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 24/ 50]	FashionMNIST	Time 0.03368s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.03343s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 25/ 50]	FashionMNIST	Time 0.03209s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.03289s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.03258s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 27/ 50]	       MNIST	Time 0.03287s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	FashionMNIST	Time 0.03252s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 28/ 50]	       MNIST	Time 0.03398s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	FashionMNIST	Time 0.03327s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 29/ 50]	       MNIST	Time 0.03219s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.03230s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	       MNIST	Time 0.04004s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.03319s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	       MNIST	Time 0.04118s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	FashionMNIST	Time 0.03222s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 32/ 50]	       MNIST	Time 0.04338s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	FashionMNIST	Time 0.03341s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.04328s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	FashionMNIST	Time 0.03280s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 34/ 50]	       MNIST	Time 0.03890s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	FashionMNIST	Time 0.03373s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.04283s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	FashionMNIST	Time 0.03416s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 36/ 50]	       MNIST	Time 0.03327s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 36/ 50]	FashionMNIST	Time 0.03296s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.03300s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	FashionMNIST	Time 0.03536s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	       MNIST	Time 0.03277s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	FashionMNIST	Time 0.03588s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.03306s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	FashionMNIST	Time 0.03356s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	       MNIST	Time 0.03386s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	FashionMNIST	Time 0.03312s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	       MNIST	Time 0.03102s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	FashionMNIST	Time 0.04152s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.03327s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	FashionMNIST	Time 0.04359s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.03270s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	FashionMNIST	Time 0.03975s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.03293s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	FashionMNIST	Time 0.04130s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	       MNIST	Time 0.03165s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	FashionMNIST	Time 0.04076s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.03385s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	FashionMNIST	Time 0.04526s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	       MNIST	Time 0.03352s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	FashionMNIST	Time 0.03405s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	       MNIST	Time 0.03415s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	FashionMNIST	Time 0.03307s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 49/ 50]	       MNIST	Time 0.03261s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	FashionMNIST	Time 0.03317s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	       MNIST	Time 0.03306s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	FashionMNIST	Time 0.04050s	Training Accuracy: 100.00%	Test Accuracy: 65.62%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%

Appendix

julia
using InteractiveUtils

# Print the Julia build and platform information.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# both defined and MLDataDevices reports the corresponding device as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.