Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` tuple of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type, constructed as `dset(:train)` and `dset(:test)`.
- `n_train`, `n_eval`: number of leading observations to keep from the train / test
  split; `nothing` keeps the whole split.
- `batchsize`: requested batch size; capped at the number of observations kept.

Images are reshaped to `28 × 28 × 1 × N` and labels are one-hot encoded over `0:9`.
Only the training loader shuffles; both loaders are built with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load one split, keep the first `limit` observations (all of them when `limit` is
    # `nothing`) and convert them to the array layout used by the model.
    function prepare(split::Symbol, limit)
        data = dset(split)
        n_obs = limit === nothing ? length(data) : limit
        (; features, targets) = data[1:n_obs]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Wrap an `(x, y)` pair in a loader whose batch size never exceeds the sample count.
    function make_loader(x, y, do_shuffle::Bool)
        return DataLoader(
            (x, y);
            batchsize=min(batchsize, size(x, 4)),
            shuffle=do_shuffle,
            partial=false,
        )
    end

    x_train, y_train = prepare(:train, n_train)
    x_test, y_test = prepare(:test, n_eval)

    return (make_loader(x_train, y_train, true), make_loader(x_test, y_test, false))
end

"""
    load_datasets(batchsize=32)

Return the `load_dataset` results for `MNIST` and `FashionMNIST`, in that order.

When the `CI` environment variable parses to `true`, each dataset is limited to 1024
training and 32 evaluation observations; otherwise (including when `CI` is unset) the
full splits are used.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return broadcast(load_dataset, (MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)

Build a `@compact` layer (dispatch tag `:HyperNet`) in which the output of
`weight_generator` is used as the parameters that `core_network` is called with.

The resulting layer takes a 2-tuple `(x, y)`: `x` is fed to `weight_generator`, and `y`
is fed to `core_network` together with the generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Record the axes of `core_network`'s parameters once. The parameter values created
    # here are discarded; only the axes are kept so that a flat vector can later be
    # wrapped as a `ComponentArray` with the same layout.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Flatten the generator output and attach the core network's parameter axes.
        # NOTE(review): assumes the generator emits exactly as many values as
        # `core_network` has parameters — confirm where the generator is constructed.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining methods on a `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload `Lux.initialparameters` so that it skips initializing the `core_network` parameters, since those are produced by the weight generator at call time.

julia
# Parameter initialisation for the `:HyperNet` compact layer: only the weight generator's
# parameters are created, and they are returned under the `weight_generator` key.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Assemble the `HyperNet` used in this example.

The core network is three stride-2 `3 × 3` convolutions (1 → 16 → 32 → 64 channels, `relu`),
global mean pooling, flattening, and a `Dense(64, 10)` head. The weight generator is an
`Embedding(2 => 32)` followed by two dense layers whose final width equals
`Lux.parameterlength(core_network)`.
"""
function create_model()
    strided_conv(channels) = Conv((3, 3), channels, relu; stride=2)

    core_network = Chain(
        strided_conv(1 => 16),
        strided_conv(16 => 32),
        strided_conv(32 => 64),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit one value per parameter of the core network.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of observations in `dataloader` for which the class predicted by
`model` equals the target class.

`model` is called as `model((data_idx, x), ps, st_test)` for every batch `(x, y)`, where
`st_test` is `st` switched to test mode; the first element of its result is taken as the
prediction. Targets and predictions are moved to the CPU before being decoded with
`onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_test = Lux.testmode(st)

    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        output = first(model((data_idx, x), ps, st_test))
        guesses = onecold(to_cpu(output))

        n_correct += sum(labels .== guesses)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end

Training

julia
# Accuracy of `model` on `dataloader`, evaluated with the current parameters and states of
# `train_state`, expressed as a percentage rounded to two decimals.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

# Display name of the dataset selected by `data_idx` (1 is MNIST, anything else
# FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train()

Train the hypernetwork from `create_model()` on MNIST and FashionMNIST and report progress.

For each of 50 epochs, and for each dataset in turn, one pass over the training loader is
performed with `Training.single_train_step!` (`AutoEnzyme`, cross-entropy on logits,
`Adam(0.0003f0)`), after which the training and test accuracies are printed. A final
summary line per dataset is printed once training is done.

# Returns
A 2-element `Vector{Float64}` holding the final test accuracies in percent (rounded to two
decimals), ordered `[MNIST, FashionMNIST]`.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using the first MNIST training batch and a
    # traced dataset index as example inputs; the compiled model is reused for every
    # accuracy evaluation below.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        # Wrapping the index lets it be traced; otherwise it would be embedded as a
        # constant in the IR.
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the fully trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the whole training procedure. The result holds the final test accuracies (in percent,
# rounded to two decimals) for MNIST and FashionMNIST, in that order.
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757733871.383174  558333 service.cc:163] XLA service 0x33e94510 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757733871.383222  558333 service.cc:171]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1757733871.383765  558333 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757733871.383801  558333 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757733871.383851  558333 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757733871.393709  558333 cuda_dnn.cc:463] Loaded cuDNN version 91200
[  1/ 50]	       MNIST	Time 44.56980s	Training Accuracy: 35.45%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.03309s	Training Accuracy: 31.74%	Test Accuracy: 40.62%
[  2/ 50]	       MNIST	Time 0.03185s	Training Accuracy: 36.33%	Test Accuracy: 31.25%
[  2/ 50]	FashionMNIST	Time 0.03238s	Training Accuracy: 46.00%	Test Accuracy: 56.25%
[  3/ 50]	       MNIST	Time 0.04148s	Training Accuracy: 41.89%	Test Accuracy: 34.38%
[  3/ 50]	FashionMNIST	Time 0.03182s	Training Accuracy: 51.46%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.03533s	Training Accuracy: 47.85%	Test Accuracy: 34.38%
[  4/ 50]	FashionMNIST	Time 0.03467s	Training Accuracy: 60.74%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.03430s	Training Accuracy: 54.00%	Test Accuracy: 56.25%
[  5/ 50]	FashionMNIST	Time 0.03263s	Training Accuracy: 65.92%	Test Accuracy: 62.50%
[  6/ 50]	       MNIST	Time 0.03622s	Training Accuracy: 59.18%	Test Accuracy: 46.88%
[  6/ 50]	FashionMNIST	Time 0.03244s	Training Accuracy: 70.12%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.03042s	Training Accuracy: 66.80%	Test Accuracy: 50.00%
[  7/ 50]	FashionMNIST	Time 0.03231s	Training Accuracy: 75.00%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.03201s	Training Accuracy: 72.36%	Test Accuracy: 53.12%
[  8/ 50]	FashionMNIST	Time 0.03329s	Training Accuracy: 79.79%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.03002s	Training Accuracy: 76.27%	Test Accuracy: 46.88%
[  9/ 50]	FashionMNIST	Time 0.03267s	Training Accuracy: 83.59%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.03257s	Training Accuracy: 81.45%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.04474s	Training Accuracy: 86.72%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.03266s	Training Accuracy: 85.94%	Test Accuracy: 46.88%
[ 11/ 50]	FashionMNIST	Time 0.04199s	Training Accuracy: 88.38%	Test Accuracy: 65.62%
[ 12/ 50]	       MNIST	Time 0.03297s	Training Accuracy: 88.18%	Test Accuracy: 53.12%
[ 12/ 50]	FashionMNIST	Time 0.04250s	Training Accuracy: 91.21%	Test Accuracy: 62.50%
[ 13/ 50]	       MNIST	Time 0.03014s	Training Accuracy: 90.82%	Test Accuracy: 43.75%
[ 13/ 50]	FashionMNIST	Time 0.04495s	Training Accuracy: 92.09%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.03248s	Training Accuracy: 93.46%	Test Accuracy: 46.88%
[ 14/ 50]	FashionMNIST	Time 0.04335s	Training Accuracy: 94.14%	Test Accuracy: 68.75%
[ 15/ 50]	       MNIST	Time 0.03416s	Training Accuracy: 95.41%	Test Accuracy: 46.88%
[ 15/ 50]	FashionMNIST	Time 0.03782s	Training Accuracy: 95.31%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03223s	Training Accuracy: 97.17%	Test Accuracy: 50.00%
[ 16/ 50]	FashionMNIST	Time 0.04253s	Training Accuracy: 96.88%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.03075s	Training Accuracy: 97.85%	Test Accuracy: 46.88%
[ 17/ 50]	FashionMNIST	Time 0.04167s	Training Accuracy: 97.85%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.03219s	Training Accuracy: 98.73%	Test Accuracy: 56.25%
[ 18/ 50]	FashionMNIST	Time 0.03212s	Training Accuracy: 97.85%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03253s	Training Accuracy: 99.41%	Test Accuracy: 46.88%
[ 19/ 50]	FashionMNIST	Time 0.03444s	Training Accuracy: 98.14%	Test Accuracy: 71.88%
[ 20/ 50]	       MNIST	Time 0.03319s	Training Accuracy: 99.71%	Test Accuracy: 46.88%
[ 20/ 50]	FashionMNIST	Time 0.03239s	Training Accuracy: 98.54%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.03199s	Training Accuracy: 99.80%	Test Accuracy: 46.88%
[ 21/ 50]	FashionMNIST	Time 0.03267s	Training Accuracy: 99.12%	Test Accuracy: 65.62%
[ 22/ 50]	       MNIST	Time 0.03260s	Training Accuracy: 99.80%	Test Accuracy: 50.00%
[ 22/ 50]	FashionMNIST	Time 0.03097s	Training Accuracy: 99.51%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.03227s	Training Accuracy: 99.90%	Test Accuracy: 50.00%
[ 23/ 50]	FashionMNIST	Time 0.03374s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.03311s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 24/ 50]	FashionMNIST	Time 0.03419s	Training Accuracy: 99.80%	Test Accuracy: 65.62%
[ 25/ 50]	       MNIST	Time 0.03391s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 25/ 50]	FashionMNIST	Time 0.03300s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 26/ 50]	       MNIST	Time 0.03345s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 26/ 50]	FashionMNIST	Time 0.03474s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 27/ 50]	       MNIST	Time 0.04702s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 27/ 50]	FashionMNIST	Time 0.03609s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 28/ 50]	       MNIST	Time 0.04288s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 28/ 50]	FashionMNIST	Time 0.03339s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 29/ 50]	       MNIST	Time 0.04576s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 29/ 50]	FashionMNIST	Time 0.03295s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.04665s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 30/ 50]	FashionMNIST	Time 0.03221s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	       MNIST	Time 0.04267s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 31/ 50]	FashionMNIST	Time 0.03172s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	       MNIST	Time 0.04374s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 32/ 50]	FashionMNIST	Time 0.03238s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 33/ 50]	       MNIST	Time 0.04248s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 33/ 50]	FashionMNIST	Time 0.03403s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	       MNIST	Time 0.05058s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 34/ 50]	FashionMNIST	Time 0.03103s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.03344s	Training Accuracy: 100.00%	Test Accuracy: 50.00%
[ 35/ 50]	FashionMNIST	Time 0.03267s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 36/ 50]	       MNIST	Time 0.03040s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 36/ 50]	FashionMNIST	Time 0.03182s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.03037s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 37/ 50]	FashionMNIST	Time 0.03275s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	       MNIST	Time 0.03337s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 38/ 50]	FashionMNIST	Time 0.03348s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.03353s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 39/ 50]	FashionMNIST	Time 0.03195s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	       MNIST	Time 0.03242s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 40/ 50]	FashionMNIST	Time 0.03370s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	       MNIST	Time 0.03344s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 41/ 50]	FashionMNIST	Time 0.03274s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.03316s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 42/ 50]	FashionMNIST	Time 0.03977s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.03278s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 43/ 50]	FashionMNIST	Time 0.04367s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.03240s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	FashionMNIST	Time 0.04206s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	       MNIST	Time 0.03392s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 45/ 50]	FashionMNIST	Time 0.04145s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.03313s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 46/ 50]	FashionMNIST	Time 0.04258s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	       MNIST	Time 0.03340s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	FashionMNIST	Time 0.04240s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.03493s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	FashionMNIST	Time 0.04580s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.03367s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 49/ 50]	FashionMNIST	Time 0.04849s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.03276s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 50/ 50]	FashionMNIST	Time 0.03555s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
# Print the Julia version, platform and environment details of this run.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Additionally print GPU toolchain details, but only when `MLDataDevices` and the
# corresponding GPU package are already defined in this scope (nothing is loaded here) and
# the backend reports itself as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.