Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
   1546.5 ms  ✓ Lux → LuxComponentArraysExt
  1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling MLDatasets...
   2102.5 ms  ✓ HDF5_jll
   3320.0 ms  ✓ DataDeps
   7469.4 ms  ✓ HDF5
   2376.6 ms  ✓ MAT
  10840.0 ms  ✓ MLDatasets
  5 dependencies successfully precompiled in 23 seconds. 198 already precompiled.
Precompiling LuxMLUtilsExt...
   2231.3 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 3 seconds. 164 already precompiled.
Precompiling ComponentArraysReactantExt...
  17647.2 ms  ✓ ComponentArrays → ComponentArraysReactantExt
  1 dependency successfully precompiled in 18 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
  17463.5 ms  ✓ Reactant → ReactantOneHotArraysExt
  1 dependency successfully precompiled in 18 seconds. 104 already precompiled.

Loading Datasets

julia
# Load one split (`:train` or `:test`) of dataset type `dset`, keeping only the first `n`
# samples when `n` is an `Int` (all samples when `n === nothing`). Images are reshaped to
# 28×28×1×N arrays and labels are one-hot encoded over the classes 0:9.
function _load_split(::Type{dset}, split::Symbol, n::Union{Nothing,Int}) where {dset}
    data = dset(split)
    (; features, targets) = data[1:(n === nothing ? length(data) : n)]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end

"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build `(train_loader, test_loader)` for the dataset type `dset` (e.g. `MNIST`).

`n_train` / `n_eval` limit the number of training / test samples (`nothing` = use all).
The training loader shuffles; the test loader does not.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)

    # Clamp the batch size to the split size so small (e.g. CI) subsets still produce at
    # least one full batch; incomplete trailing batches are dropped (`partial=false`).
    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return train_loader, test_loader
end

"""
    load_datasets(batchsize=32)

Return a tuple `(mnist_loaders, fashion_mnist_loaders)`. Each element is the
`(train_loader, test_loader)` pair produced by `load_dataset`. When the `CI` environment
variable parses as `true`, the splits are truncated (1024 train / 32 test samples).
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    return map(d -> load_dataset(d, n_train, n_eval, batchsize), (MNIST, FashionMNIST))
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Wrap `core_network` so its parameters are produced on the fly by `weight_generator`.

The resulting layer takes a tuple `(x, y)`. `x` is fed to `weight_generator`, whose output
is flattened and reinterpreted with the core network's parameter layout. `y` is then run
through `core_network` with those generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the parameter *layout* (ComponentArray axes) of the core network is kept; the
    # sampled values are discarded.
    # NOTE(review): this draws from the global default RNG just to obtain the axes.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        flat_ps = vec(weight_generator(x))
        @return core_network(y, ComponentArray(flat_ps, ca_axes))
    end
end
HyperNet (generic function with 1 method)

Defining methods for a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here, we override initialparameters only so that it skips initializing the core_network parameters. Those parameters are generated by the weight_generator on every forward pass, so only the weight_generator needs parameters of its own.

julia
# Parameter initialization for layers built by `HyperNet` (tagged `dispatch=:HyperNet`).
# Only the weight generator owns trainable parameters. The core network's parameters are
# regenerated from the weight generator on every forward pass, so none are created for it.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial. A small strided CNN classifier (10 outputs) is
the core network. Its parameters are generated by an MLP that starts from a 2-entry
embedding, one entry per dataset index.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's final layer must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the arg-max of
`model((data_idx, x), ps, st)` matches the one-hot label. The model runs with
`Lux.testmode(st)`, and labels and logits are moved to the CPU before decoding.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_eval = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        logits = first(model((data_idx, x), ps, st_eval))
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
# Accuracy of the compiled `model` on `dataloader` for dataset `data_idx`, using the
# parameters/states currently held by `train_state`. Returned as a percentage rounded to
# two decimals, which is the form `train` prints and returns.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train(; nepochs=50)

Train the HyperNet on MNIST (`data_idx = 1`) and FashionMNIST (`data_idx = 2`), visiting
both datasets in every epoch. Per-epoch and final accuracies are printed.

Return a 2-element vector holding the final test accuracy (percent, rounded to two
decimals) for MNIST and FashionMNIST, in that order.
"""
function train(; nepochs::Integer=50)
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())
    dataset_names = ("MNIST", "FashionMNIST")

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, using the first MNIST training batch as the
    # example input. The data index is a `ConcreteRNumber`, so it is traced rather than
    # embedded as a constant, and the same compiled function serves both datasets.
    x_example = first(first(dataloaders[1][1]))
    idx_example = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((idx_example, x_example), ps, Lux.testmode(st))
    end

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of the trained generator on each dataset.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training + evaluation pipeline. `train` returns the final test accuracies
# (percent, rounded to two decimals) for MNIST and FashionMNIST, in that order.
test_acc_list = train()
2025-07-03 16:08:53.312374: I external/xla/xla/service/service.cc:153] XLA service 0xbf761c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-03 16:08:53.312416: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1751558933.313242  251317 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1751558933.313302  251317 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1751558933.313346  251317 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1751558933.324739  251317 cuda_dnn.cc:471] Loaded cuDNN version 90800
[  1/ 50]	       MNIST	Time 56.37556s	Training Accuracy: 34.86%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.05224s	Training Accuracy: 31.25%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.03793s	Training Accuracy: 36.62%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.02990s	Training Accuracy: 46.48%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.03149s	Training Accuracy: 39.75%	Test Accuracy: 40.62%
[  3/ 50]	FashionMNIST	Time 0.03857s	Training Accuracy: 53.03%	Test Accuracy: 65.62%
[  4/ 50]	       MNIST	Time 0.03015s	Training Accuracy: 52.44%	Test Accuracy: 46.88%
[  4/ 50]	FashionMNIST	Time 0.03397s	Training Accuracy: 62.89%	Test Accuracy: 62.50%
[  5/ 50]	       MNIST	Time 0.05155s	Training Accuracy: 56.35%	Test Accuracy: 46.88%
[  5/ 50]	FashionMNIST	Time 0.03870s	Training Accuracy: 66.11%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.03693s	Training Accuracy: 62.99%	Test Accuracy: 50.00%
[  6/ 50]	FashionMNIST	Time 0.03518s	Training Accuracy: 73.44%	Test Accuracy: 62.50%
[  7/ 50]	       MNIST	Time 0.04435s	Training Accuracy: 70.31%	Test Accuracy: 53.12%
[  7/ 50]	FashionMNIST	Time 0.03614s	Training Accuracy: 75.98%	Test Accuracy: 65.62%
[  8/ 50]	       MNIST	Time 0.04210s	Training Accuracy: 77.15%	Test Accuracy: 53.12%
[  8/ 50]	FashionMNIST	Time 0.03797s	Training Accuracy: 80.47%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.03622s	Training Accuracy: 80.27%	Test Accuracy: 56.25%
[  9/ 50]	FashionMNIST	Time 0.03619s	Training Accuracy: 83.69%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.03446s	Training Accuracy: 83.69%	Test Accuracy: 56.25%
[ 10/ 50]	FashionMNIST	Time 0.03388s	Training Accuracy: 86.13%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.03490s	Training Accuracy: 86.91%	Test Accuracy: 50.00%
[ 11/ 50]	FashionMNIST	Time 0.04449s	Training Accuracy: 88.28%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.03477s	Training Accuracy: 88.77%	Test Accuracy: 59.38%
[ 12/ 50]	FashionMNIST	Time 0.04674s	Training Accuracy: 89.55%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.03791s	Training Accuracy: 94.53%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.03539s	Training Accuracy: 92.58%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.03613s	Training Accuracy: 96.00%	Test Accuracy: 62.50%
[ 14/ 50]	FashionMNIST	Time 0.03434s	Training Accuracy: 93.55%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.04809s	Training Accuracy: 95.41%	Test Accuracy: 62.50%
[ 15/ 50]	FashionMNIST	Time 0.03416s	Training Accuracy: 94.04%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.04854s	Training Accuracy: 97.95%	Test Accuracy: 56.25%
[ 16/ 50]	FashionMNIST	Time 0.03932s	Training Accuracy: 96.00%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03546s	Training Accuracy: 99.22%	Test Accuracy: 65.62%
[ 17/ 50]	FashionMNIST	Time 0.03662s	Training Accuracy: 95.51%	Test Accuracy: 71.88%
[ 18/ 50]	       MNIST	Time 0.03472s	Training Accuracy: 99.61%	Test Accuracy: 65.62%
[ 18/ 50]	FashionMNIST	Time 0.04578s	Training Accuracy: 96.00%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03517s	Training Accuracy: 99.80%	Test Accuracy: 65.62%
[ 19/ 50]	FashionMNIST	Time 0.04614s	Training Accuracy: 98.05%	Test Accuracy: 71.88%
[ 20/ 50]	       MNIST	Time 0.03462s	Training Accuracy: 99.80%	Test Accuracy: 71.88%
[ 20/ 50]	FashionMNIST	Time 0.03546s	Training Accuracy: 99.12%	Test Accuracy: 75.00%
[ 21/ 50]	       MNIST	Time 0.03503s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 21/ 50]	FashionMNIST	Time 0.03509s	Training Accuracy: 98.54%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.04682s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 22/ 50]	FashionMNIST	Time 0.03287s	Training Accuracy: 99.22%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.04147s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 23/ 50]	FashionMNIST	Time 0.03267s	Training Accuracy: 99.71%	Test Accuracy: 75.00%
[ 24/ 50]	       MNIST	Time 0.03231s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 24/ 50]	FashionMNIST	Time 0.03555s	Training Accuracy: 99.51%	Test Accuracy: 75.00%
[ 25/ 50]	       MNIST	Time 0.03580s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 25/ 50]	FashionMNIST	Time 0.04585s	Training Accuracy: 99.90%	Test Accuracy: 75.00%
[ 26/ 50]	       MNIST	Time 0.03391s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.04533s	Training Accuracy: 99.80%	Test Accuracy: 75.00%
[ 27/ 50]	       MNIST	Time 0.03375s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 27/ 50]	FashionMNIST	Time 0.04852s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.03477s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	FashionMNIST	Time 0.03027s	Training Accuracy: 99.80%	Test Accuracy: 75.00%
[ 29/ 50]	       MNIST	Time 0.03222s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.03440s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 30/ 50]	       MNIST	Time 0.04315s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.03578s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 31/ 50]	       MNIST	Time 0.04885s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	FashionMNIST	Time 0.03527s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 32/ 50]	       MNIST	Time 0.03635s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	FashionMNIST	Time 0.03479s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 33/ 50]	       MNIST	Time 0.03558s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	FashionMNIST	Time 0.04637s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 34/ 50]	       MNIST	Time 0.03406s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	FashionMNIST	Time 0.04475s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 35/ 50]	       MNIST	Time 0.03414s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	FashionMNIST	Time 0.03459s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 36/ 50]	       MNIST	Time 0.03495s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 36/ 50]	FashionMNIST	Time 0.03181s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 37/ 50]	       MNIST	Time 0.04173s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	FashionMNIST	Time 0.03558s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 38/ 50]	       MNIST	Time 0.04769s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.03643s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 39/ 50]	       MNIST	Time 0.03438s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	FashionMNIST	Time 0.03534s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 40/ 50]	       MNIST	Time 0.03658s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	FashionMNIST	Time 0.04670s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 41/ 50]	       MNIST	Time 0.03441s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	FashionMNIST	Time 0.04477s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 42/ 50]	       MNIST	Time 0.03587s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	FashionMNIST	Time 0.04644s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 43/ 50]	       MNIST	Time 0.03574s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	FashionMNIST	Time 0.03477s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 44/ 50]	       MNIST	Time 0.03458s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	FashionMNIST	Time 0.03578s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 45/ 50]	       MNIST	Time 0.04510s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	FashionMNIST	Time 0.03415s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 46/ 50]	       MNIST	Time 0.04566s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	FashionMNIST	Time 0.03365s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 47/ 50]	       MNIST	Time 0.03363s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	FashionMNIST	Time 0.03349s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 48/ 50]	       MNIST	Time 0.03515s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	FashionMNIST	Time 0.04409s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 49/ 50]	       MNIST	Time 0.03473s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	FashionMNIST	Time 0.04580s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 50/ 50]	       MNIST	Time 0.04756s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	FashionMNIST	Time 0.03439s	Training Accuracy: 100.00%	Test Accuracy: 75.00%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 75.00%

Appendix

julia
# Report the Julia version and platform details for reproducibility.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Also report GPU toolchain details. This happens only for backends whose names are
# defined in this session (`@isdefined`) and that `MLDataDevices.functional` reports as
# usable. The names are probed at runtime because this page does not import CUDA/AMDGPU.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.