Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    _load_split(dset, partition, n)

Load one partition (`:train` or `:test`) of the MLDatasets dataset type `dset`, keeping
the first `n` samples, or every sample when `n === nothing`. Returns `(x, y)` where `x`
is reshaped to `28×28×1×N` and `y` is one-hot encoded over the labels `0:9`.
"""
function _load_split(::Type{dset}, partition::Symbol, n::Union{Nothing,Int}) where {dset}
    data = dset(partition)
    # `nothing` means "use the whole partition".
    (; features, targets) = data[1:something(n, length(data))]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end

"""
    load_dataset(dset, n_train, n_eval, batchsize)

Return `(train_loader, test_loader)` for the MLDatasets dataset type `dset`.

# Arguments
- `dset`: dataset type constructible as `dset(:train)` / `dset(:test)`.
- `n_train`, `n_eval`: number of leading samples to keep from the train / test split,
  or `nothing` to keep them all.
- `batchsize`: requested batch size; it is clamped to the number of available samples.

The training loader shuffles; the test loader does not.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)

    # Clamp the batch size so small (e.g. CI) subsets still yield a batch, and use
    # `partial=false` so every batch has the same shape (the Reactant-compiled model
    # in `train` is compiled for a single fixed input shape).
    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

"""
    load_datasets(batchsize=32)

Return one `(train_loader, test_loader)` pair per dataset, for MNIST and FashionMNIST
in that order. When the `CI` environment variable parses as `true`, only the first 1024
training and 32 test samples of each dataset are used; otherwise the full splits are.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer. On input `(x, y)`, `weight_generator(x)` produces a flat
vector. That vector is mapped (via `ComponentArray`) onto the parameter layout of
`core_network`, which is then applied to `y` using those generated parameters.

Only the axes of `core_network`'s freshly initialized parameters are kept; the values
themselves are discarded. `dispatch=:HyperNet` tags the resulting `CompactLuxLayer`, so
methods can be specialized on `CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Template parameters, used only to record the nested layout (axes) that the
    # generator's flat output is mapped onto.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        # NOTE(review): assumes the generator emits exactly
        # `Lux.parameterlength(core_network)` values (create_model sizes it that way).
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. Here we only override `Lux.initialparameters` so that it skips initializing the `core_network` parameters, since those are produced by the `weight_generator` on every forward pass.

julia
# Trainable parameters of a HyperNet: only the weight generator's. The core network's
# weights are generated from the input on each call, so they are intentionally omitted.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial. The weight generator maps a dataset index
(via a 2-entry `Embedding`) to the complete flat parameter vector of a small
convolutional classifier, which emits 10 logits per image.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's last layer must output one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class (the `onecold` of
the model output) matches the `onecold` of the target. For every batch `(x, y)`, the
model is run in test mode on the input tuple `(data_idx, x)`. Labels and predictions are
moved to the CPU before comparison.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        logits = first(model((data_idx, x), ps, eval_st))
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end

Training

julia
"""
    _accuracy_percent(model, train_state, dataloader, data_idx)

Evaluate `accuracy` with the parameters and states currently held in `train_state`,
returned as a percentage rounded to two decimal places (the form printed in the logs).
"""
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

"""
    train()

Jointly train the HyperNet on MNIST and FashionMNIST, alternating between the two
datasets within every epoch. Per-epoch and final train/test accuracies are printed.

Returns a 2-element vector holding the final test accuracy (in %) for MNIST and
FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())
    # Indexed by the same `data_idx` (1 or 2) that selects the generator's embedding.
    dataset_names = ("MNIST", "FashionMNIST")

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference pass once from example inputs (the first MNIST batch and a
    # traced dataset index). The compiled function is reused for both datasets.
    x_example = first(first(dataloaders[1][1]))
    idx_example = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((idx_example, x_example), ps, Lux.testmode(st))
    end

    # Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        # NOTE(review): `dataloaders` already went through `dev` above; confirm whether
        # applying `dev` again here is needed.
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        # This allows us to trace the data index, else it will be embedded as a constant
        # in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training + evaluation loop. `train` returns the final test accuracies
# (in %) for MNIST and FashionMNIST, in that order.
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1758301779.915928 1237731 service.cc:158] XLA service 0x38a5e360 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758301779.916017 1237731 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1758301779.917083 1237731 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1758301779.917155 1237731 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1758301779.917217 1237731 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1758301779.927690 1237731 cuda_dnn.cc:463] Loaded cuDNN version 91200
[  1/ 50]	       MNIST	Time 44.91185s	Training Accuracy: 34.86%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.03526s	Training Accuracy: 32.62%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.03441s	Training Accuracy: 37.50%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.03214s	Training Accuracy: 46.09%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.03528s	Training Accuracy: 40.14%	Test Accuracy: 43.75%
[  3/ 50]	FashionMNIST	Time 0.03829s	Training Accuracy: 52.54%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.03235s	Training Accuracy: 52.44%	Test Accuracy: 50.00%
[  4/ 50]	FashionMNIST	Time 0.03246s	Training Accuracy: 60.55%	Test Accuracy: 53.12%
[  5/ 50]	       MNIST	Time 0.03494s	Training Accuracy: 58.01%	Test Accuracy: 43.75%
[  5/ 50]	FashionMNIST	Time 0.03838s	Training Accuracy: 66.02%	Test Accuracy: 53.12%
[  6/ 50]	       MNIST	Time 0.03769s	Training Accuracy: 62.89%	Test Accuracy: 46.88%
[  6/ 50]	FashionMNIST	Time 0.03613s	Training Accuracy: 71.78%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.03435s	Training Accuracy: 70.51%	Test Accuracy: 50.00%
[  7/ 50]	FashionMNIST	Time 0.03377s	Training Accuracy: 76.37%	Test Accuracy: 59.38%
[  8/ 50]	       MNIST	Time 0.03410s	Training Accuracy: 76.17%	Test Accuracy: 46.88%
[  8/ 50]	FashionMNIST	Time 0.03720s	Training Accuracy: 79.69%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.04663s	Training Accuracy: 79.49%	Test Accuracy: 53.12%
[  9/ 50]	FashionMNIST	Time 0.03418s	Training Accuracy: 82.91%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.04292s	Training Accuracy: 85.25%	Test Accuracy: 43.75%
[ 10/ 50]	FashionMNIST	Time 0.03467s	Training Accuracy: 85.94%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.03336s	Training Accuracy: 89.55%	Test Accuracy: 46.88%
[ 11/ 50]	FashionMNIST	Time 0.03361s	Training Accuracy: 87.89%	Test Accuracy: 59.38%
[ 12/ 50]	       MNIST	Time 0.03349s	Training Accuracy: 89.55%	Test Accuracy: 46.88%
[ 12/ 50]	FashionMNIST	Time 0.05185s	Training Accuracy: 90.43%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.03251s	Training Accuracy: 92.38%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.03311s	Training Accuracy: 92.97%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.04378s	Training Accuracy: 95.21%	Test Accuracy: 56.25%
[ 14/ 50]	FashionMNIST	Time 0.03325s	Training Accuracy: 93.95%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.03277s	Training Accuracy: 97.07%	Test Accuracy: 56.25%
[ 15/ 50]	FashionMNIST	Time 0.03380s	Training Accuracy: 95.41%	Test Accuracy: 65.62%
[ 16/ 50]	       MNIST	Time 0.03270s	Training Accuracy: 98.24%	Test Accuracy: 59.38%
[ 16/ 50]	FashionMNIST	Time 0.04128s	Training Accuracy: 96.39%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03232s	Training Accuracy: 98.83%	Test Accuracy: 53.12%
[ 17/ 50]	FashionMNIST	Time 0.03327s	Training Accuracy: 97.75%	Test Accuracy: 75.00%
[ 18/ 50]	       MNIST	Time 0.04257s	Training Accuracy: 99.32%	Test Accuracy: 56.25%
[ 18/ 50]	FashionMNIST	Time 0.03422s	Training Accuracy: 97.95%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03325s	Training Accuracy: 99.61%	Test Accuracy: 62.50%
[ 19/ 50]	FashionMNIST	Time 0.03438s	Training Accuracy: 98.34%	Test Accuracy: 75.00%
[ 20/ 50]	       MNIST	Time 0.03402s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 20/ 50]	FashionMNIST	Time 0.04011s	Training Accuracy: 99.12%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.03425s	Training Accuracy: 99.80%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.03350s	Training Accuracy: 99.12%	Test Accuracy: 75.00%
[ 22/ 50]	       MNIST	Time 0.04300s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 22/ 50]	FashionMNIST	Time 0.03303s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.03452s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 23/ 50]	FashionMNIST	Time 0.04322s	Training Accuracy: 99.90%	Test Accuracy: 75.00%
[ 24/ 50]	       MNIST	Time 0.03920s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.04076s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 25/ 50]	       MNIST	Time 0.03256s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 25/ 50]	FashionMNIST	Time 0.03373s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.04086s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.03341s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 27/ 50]	       MNIST	Time 0.03206s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	FashionMNIST	Time 0.04449s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 28/ 50]	       MNIST	Time 0.03732s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	FashionMNIST	Time 0.04651s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 29/ 50]	       MNIST	Time 0.03596s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.03495s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 30/ 50]	       MNIST	Time 0.03957s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.03293s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.03332s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	FashionMNIST	Time 0.04431s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 32/ 50]	       MNIST	Time 0.04419s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	FashionMNIST	Time 0.03834s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 33/ 50]	       MNIST	Time 0.03549s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	FashionMNIST	Time 0.03447s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.04177s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	FashionMNIST	Time 0.03574s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 35/ 50]	       MNIST	Time 0.03798s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	FashionMNIST	Time 0.04611s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 36/ 50]	       MNIST	Time 0.03718s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	FashionMNIST	Time 0.04562s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 37/ 50]	       MNIST	Time 0.03355s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	FashionMNIST	Time 0.03521s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	       MNIST	Time 0.04263s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.03302s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 39/ 50]	       MNIST	Time 0.03432s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	FashionMNIST	Time 0.04227s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.03469s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	FashionMNIST	Time 0.04237s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.03344s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	FashionMNIST	Time 0.03709s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.04165s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	FashionMNIST	Time 0.03256s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.03509s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 43/ 50]	FashionMNIST	Time 0.04235s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	       MNIST	Time 0.03321s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	FashionMNIST	Time 0.04133s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	       MNIST	Time 0.03385s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	FashionMNIST	Time 0.03415s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	       MNIST	Time 0.04073s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	FashionMNIST	Time 0.03730s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	       MNIST	Time 0.03849s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	FashionMNIST	Time 0.05852s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	       MNIST	Time 0.04812s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	FashionMNIST	Time 0.04774s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.03226s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 49/ 50]	FashionMNIST	Time 0.03401s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.04350s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	FashionMNIST	Time 0.03333s	Training Accuracy: 100.00%	Test Accuracy: 65.62%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%

Appendix

julia
using InteractiveUtils
# Report the Julia version, build, platform, and environment used to generate this page.
InteractiveUtils.versioninfo()

# Additionally report GPU toolchain versions. This happens only for backends whose
# module names are defined here and which MLDataDevices reports as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.