Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
   1588.6 ms  ✓ Lux → LuxComponentArraysExt
  1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling LuxMLUtilsExt...
   2155.7 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 164 already precompiled.

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of dataset type `dset` and wrap them in DataLoaders.
`n_train` / `n_eval` limit each split to its first `n` samples (`nothing` keeps them all).
Returns `(train_loader, test_loader)`; the training loader shuffles, the test loader does
not, and both drop incomplete trailing batches.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Take the first `n` samples of a split (every sample when `n === nothing`) and return
    # them as a 28×28×1×N image array together with one-hot labels over 0:9.
    function take_split(split_name::Symbol, n::Union{Nothing,Int})
        data = dset(split_name)
        last_idx = n === nothing ? length(data) : n
        (; features, targets) = data[1:last_idx]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Full batches only; the batch size is capped at the sample count so a split that is
    # smaller than `batchsize` still yields one batch.
    function make_loader(x, y, shuffle::Bool)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle, partial=false
        )
    end

    x_train, y_train = take_split(:train, n_train)
    x_test, y_test = take_split(:test, n_eval)
    return make_loader(x_train, y_train, true), make_loader(x_test, y_test, false)
end

"""
    load_datasets(batchsize=32)

Return `(mnist_loaders, fashion_mnist_loaders)`, each a `(train, test)` pair from
`load_dataset`. When the `CI` environment variable parses as `true`, only 1024 training
and 32 evaluation samples are used per dataset.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
# Build a hypernetwork layer. `weight_generator` maps the first input `x` to a flat vector
# that is reinterpreted as the complete parameter set of `core_network`, which is then
# applied to the second input `y` with those generated parameters.
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Capture the ComponentArray axes (parameter names and shapes) of `core_network` from a
    # throwaway initialization; only the axes are kept. Note this draws from the global
    # default RNG.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    # `dispatch=:HyperNet` tags the resulting CompactLuxLayer so methods can be specialized
    # on `CompactLuxLayer{:HyperNet}` (see the `Lux.initialparameters` override).
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights: view the generator output as a structured parameter set
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we override `Lux.initialparameters` so that it initializes only the `weight_generator` parameters and skips the `core_network` parameters, which are instead generated by the weight generator on every forward pass.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    # Only the weight generator owns trainable parameters; the core network's weights are
    # produced by the generator during the forward pass, so they are not initialized here.
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial: a small strided-convolution classifier with 10
outputs whose parameters are emitted by an `Embedding(2 => 32)`-based generator, so the
generator is indexed by a dataset id in `1:2`.
"""
function create_model()
    classifier = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    # The last layer's width equals the classifier's parameter count, so its output can be
    # reshaped into a full set of classifier weights.
    n_classifier_params = Lux.parameterlength(classifier)
    generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_classifier_params),
    )
    return HyperNet(generator, classifier)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose arg-max prediction matches the
one-hot label. The model runs in test mode on inputs `(data_idx, x)`; labels and outputs
are moved to the CPU before decoding.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_test = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for batch in dataloader
        x, y = batch
        labels = onecold(to_cpu(y))
        logits = first(model((data_idx, x), ps, st_test))
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
"""
    _percent_accuracies(model, train_state, train_dataloader, test_dataloader, data_idx)

Evaluate `model` with the parameters and states held in `train_state` on both loaders and
return `(train_acc, test_acc)` as percentages rounded to 2 decimal places.
"""
function _percent_accuracies(
    model, train_state, train_dataloader, test_dataloader, data_idx
)
    percent(loader) = round(
        accuracy(model, train_state.parameters, train_state.states, loader, data_idx) * 100;
        digits=2,
    )
    return percent(train_dataloader), percent(test_dataloader)
end

"""
    train()

Jointly train the HyperNet on MNIST (data index 1) and FashionMNIST (data index 2),
alternating between the two datasets within every epoch. Print per-epoch and final
train/test accuracies.

Returns a 2-element `Vector{Float64}` with the final test accuracy (in %) for MNIST and
FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)
    dataset_names = ("MNIST", "FashionMNIST")

    model = create_model()
    # Moving the DataLoaders to `dev` makes every batch they yield a device array, so the
    # loaders are iterated directly below without a second (per-epoch) transfer.
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once with a sample batch. The data index is a traced
    # value, so the same compiled function serves both datasets during evaluation.
    x_sample = first(first(dataloaders[1][1]))
    idx_sample = ConcreteRNumber(1)
    model_compiled = @compile model((idx_sample, x_sample), ps, Lux.testmode(st))

    # The loss and AD backend are loop-invariant.
    loss_fn = CrossEntropyLoss(; logits=Val(true))
    ad_backend = AutoEnzyme()

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                ad_backend,
                loss_fn,
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc, test_acc = _percent_accuracies(
            model_compiled, train_state, train_dataloader, test_dataloader, concrete_data_idx
        )
        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc, test_acc = _percent_accuracies(
            model_compiled, train_state, train_dataloader, test_dataloader, concrete_data_idx
        )
        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training procedure; `test_acc_list` holds the final test accuracy (in %)
# for MNIST and FashionMNIST, respectively.
test_acc_list = train()
2025-05-23 23:07:11.234361: I external/xla/xla/service/service.cc:152] XLA service 0x2c38c500 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-23 23:07:11.234398: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
2025-05-23 23:07:11.234405: I external/xla/xla/service/service.cc:160]   StreamExecutor device (1): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748041631.239268  846235 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1748041631.239345  846235 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 0 for BFCAllocator.
I0000 00:00:1748041631.239391  846235 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 1 for BFCAllocator.
I0000 00:00:1748041631.239410  846235 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1748041631.239428  846235 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1748041631.254557  846235 cuda_dnn.cc:529] Loaded cuDNN version 90400
[  1/ 50]	       MNIST	Time 34.55751s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.02479s	Training Accuracy: 32.62%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.02557s	Training Accuracy: 36.91%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.07014s	Training Accuracy: 46.00%	Test Accuracy: 46.88%
[  3/ 50]	       MNIST	Time 0.02439s	Training Accuracy: 41.60%	Test Accuracy: 34.38%
[  3/ 50]	FashionMNIST	Time 0.02883s	Training Accuracy: 53.32%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.02958s	Training Accuracy: 52.25%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.02947s	Training Accuracy: 61.33%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.02905s	Training Accuracy: 58.69%	Test Accuracy: 43.75%
[  5/ 50]	FashionMNIST	Time 0.02855s	Training Accuracy: 68.85%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.02970s	Training Accuracy: 62.89%	Test Accuracy: 37.50%
[  6/ 50]	FashionMNIST	Time 0.03771s	Training Accuracy: 73.73%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.02992s	Training Accuracy: 68.85%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.05951s	Training Accuracy: 75.68%	Test Accuracy: 53.12%
[  8/ 50]	       MNIST	Time 0.02622s	Training Accuracy: 77.64%	Test Accuracy: 40.62%
[  8/ 50]	FashionMNIST	Time 0.03957s	Training Accuracy: 79.88%	Test Accuracy: 59.38%
[  9/ 50]	       MNIST	Time 0.03195s	Training Accuracy: 80.66%	Test Accuracy: 53.12%
[  9/ 50]	FashionMNIST	Time 0.02470s	Training Accuracy: 83.69%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.02585s	Training Accuracy: 83.69%	Test Accuracy: 43.75%
[ 10/ 50]	FashionMNIST	Time 0.03771s	Training Accuracy: 87.60%	Test Accuracy: 56.25%
[ 11/ 50]	       MNIST	Time 0.03004s	Training Accuracy: 88.67%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.04026s	Training Accuracy: 89.45%	Test Accuracy: 56.25%
[ 12/ 50]	       MNIST	Time 0.03049s	Training Accuracy: 90.43%	Test Accuracy: 56.25%
[ 12/ 50]	FashionMNIST	Time 0.03039s	Training Accuracy: 91.89%	Test Accuracy: 59.38%
[ 13/ 50]	       MNIST	Time 0.04104s	Training Accuracy: 92.48%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.03036s	Training Accuracy: 93.46%	Test Accuracy: 62.50%
[ 14/ 50]	       MNIST	Time 0.02906s	Training Accuracy: 95.61%	Test Accuracy: 56.25%
[ 14/ 50]	FashionMNIST	Time 0.03097s	Training Accuracy: 94.92%	Test Accuracy: 59.38%
[ 15/ 50]	       MNIST	Time 0.03053s	Training Accuracy: 96.78%	Test Accuracy: 65.62%
[ 15/ 50]	FashionMNIST	Time 0.03917s	Training Accuracy: 95.51%	Test Accuracy: 65.62%
[ 16/ 50]	       MNIST	Time 0.02944s	Training Accuracy: 97.85%	Test Accuracy: 59.38%
[ 16/ 50]	FashionMNIST	Time 0.03343s	Training Accuracy: 96.19%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03745s	Training Accuracy: 99.02%	Test Accuracy: 53.12%
[ 17/ 50]	FashionMNIST	Time 0.02951s	Training Accuracy: 97.27%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.04196s	Training Accuracy: 99.71%	Test Accuracy: 65.62%
[ 18/ 50]	FashionMNIST	Time 0.02592s	Training Accuracy: 97.27%	Test Accuracy: 62.50%
[ 19/ 50]	       MNIST	Time 0.02518s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 19/ 50]	FashionMNIST	Time 0.03722s	Training Accuracy: 99.12%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.02847s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.02841s	Training Accuracy: 99.32%	Test Accuracy: 65.62%
[ 21/ 50]	       MNIST	Time 0.02773s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 21/ 50]	FashionMNIST	Time 0.02903s	Training Accuracy: 99.41%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.03598s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 22/ 50]	FashionMNIST	Time 0.02797s	Training Accuracy: 99.61%	Test Accuracy: 65.62%
[ 23/ 50]	       MNIST	Time 0.02877s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 23/ 50]	FashionMNIST	Time 0.03658s	Training Accuracy: 99.61%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.02395s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 24/ 50]	FashionMNIST	Time 0.03672s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.02609s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 25/ 50]	FashionMNIST	Time 0.03857s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.03572s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 26/ 50]	FashionMNIST	Time 0.02930s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	       MNIST	Time 0.03031s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 27/ 50]	FashionMNIST	Time 0.02976s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	       MNIST	Time 0.02959s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 28/ 50]	FashionMNIST	Time 0.03953s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.02737s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 29/ 50]	FashionMNIST	Time 0.04044s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	       MNIST	Time 0.03484s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 30/ 50]	FashionMNIST	Time 0.03259s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	       MNIST	Time 0.02419s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 31/ 50]	FashionMNIST	Time 0.03003s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	       MNIST	Time 0.03117s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 32/ 50]	FashionMNIST	Time 0.04124s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	       MNIST	Time 0.03089s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 33/ 50]	FashionMNIST	Time 0.03005s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	       MNIST	Time 0.04107s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 34/ 50]	FashionMNIST	Time 0.03090s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	       MNIST	Time 0.03894s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 35/ 50]	FashionMNIST	Time 0.03089s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	       MNIST	Time 0.03043s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 36/ 50]	FashionMNIST	Time 0.03887s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	       MNIST	Time 0.03014s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 37/ 50]	FashionMNIST	Time 0.03003s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	       MNIST	Time 0.02927s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 38/ 50]	FashionMNIST	Time 0.02976s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	       MNIST	Time 0.03921s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 39/ 50]	FashionMNIST	Time 0.02982s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	       MNIST	Time 0.03057s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 40/ 50]	FashionMNIST	Time 0.03996s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	       MNIST	Time 0.02957s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 41/ 50]	FashionMNIST	Time 0.03859s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	       MNIST	Time 0.03112s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 42/ 50]	FashionMNIST	Time 0.03009s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 43/ 50]	       MNIST	Time 0.03947s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 43/ 50]	FashionMNIST	Time 0.02931s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	       MNIST	Time 0.03843s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	FashionMNIST	Time 0.04902s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	       MNIST	Time 0.03064s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 45/ 50]	FashionMNIST	Time 0.04207s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.03110s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.04098s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	       MNIST	Time 0.03015s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	FashionMNIST	Time 0.03457s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.02820s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	FashionMNIST	Time 0.02453s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.02659s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.02330s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	       MNIST	Time 0.02833s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.02822s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU toolchain details only for backends whose packages are loaded and whose
# devices MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.