Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the dataset type `dset`.

`dset(:train)` and `dset(:test)` are indexed to obtain `features` and `targets`. The
features are reshaped to `28 × 28 × 1 × N` and the targets are one-hot encoded over the
labels `0:9`.

# Arguments
- `dset`: dataset type; it is called as `dset(:train)` and `dset(:test)`.
- `n_train`, `n_eval`: number of leading samples to keep from each split, or `nothing`
  to keep the whole split.
- `batchsize`: requested batch size; each loader uses at most the number of samples
  available in its split.

# Returns
A tuple of two `DataLoader`s. The training loader shuffles, the test loader does not,
and both are created with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load one split, keep its first `n` samples (all of them when `n` is `nothing`),
    # and convert it to the array layout used by the loaders.
    function split_arrays(split::Symbol, n)
        data = dset(split)
        n_samples = n === nothing ? length(data) : n
        (; features, targets) = data[1:n_samples]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    x_train, y_train = split_arrays(:train, n_train)
    x_test, y_test = split_arrays(:test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST via `load_dataset` and return one
`(train_loader, test_loader)` pair per dataset, in that order.

When the environment variable `CI` parses to `true`, only 1024 training and 32
evaluation samples are used; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    is_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = is_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer

julia
# Build a hypernetwork layer in which `weight_generator` produces the parameters that
# `core_network` is evaluated with.
#
# Arguments:
#   - weight_generator: layer whose output is flattened and used as the parameters of
#     `core_network`.
#   - core_network: layer that is applied to the data input with the generated
#     parameters.
#
# The returned layer expects a tuple `(x, y)` as input: `x` is passed to
# `weight_generator` and `y` is passed to `core_network`.
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (the named layout) of `core_network`'s parameters are kept; the
    # parameter values created here are not used afterwards.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    # `dispatch=:HyperNet` tags the layer so that methods can be defined for
    # `CompactLuxLayer{:HyperNet}`.
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights: flatten the generator output and give it the parameter
        # layout recorded in `ca_axes`.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips the initialization of the core_network parameters, since those are produced by the weight_generator.

julia
"""
    Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})

Initialize the parameters of a `HyperNet` layer.

Only the `weight_generator` sub-layer receives parameters; the returned named tuple has
no entry for `core_network`.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network is a small convolutional classifier: three stride-2 `3 × 3`
convolutions (`1 => 16 => 32 => 64` channels, `relu`), a global mean pool, a flatten
layer, and a `Dense(64, 10)` head. The weight generator embeds one of two indices into
32 dimensions and maps it through two dense layers to a vector with one entry per core
network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator has to emit exactly as many values as the core network has
    # parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the class predicted by `model`
matches the target class.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its result is taken as the prediction. Targets and predictions are
moved to the CPU before being decoded with `onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        output = first(model((data_idx, x), ps, st))
        guesses = onecold(to_cpu(output))
        n_correct += sum(labels .== guesses)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end

Training

julia
# Accuracy of `model` on `dataloader` as a percentage rounded to two decimal places,
# evaluated with the parameters and states currently stored in `train_state`.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    fraction = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(fraction * 100; digits=2)
end

"""
    train()

Train the hypernetwork on MNIST and FashionMNIST and return the final test accuracies.

The model, data, parameters, and states are moved to a Reactant device. The forward pass
is compiled once for evaluation. Training then runs for 50 epochs, and every epoch
makes one pass over each dataset in turn, printing the time taken and the training and
test accuracy. After training, both datasets are evaluated once more.

# Returns
A two-element `Vector{Float64}` holding the final test accuracy in percent for MNIST
and FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the evaluation forward pass once, using the first training batch of the
    # first dataset and a dataset index as example inputs.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    dataset_names = ("MNIST", "FashionMNIST")
    backend = AutoEnzyme()
    loss_fn = CrossEntropyLoss(; logits=Val(true))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### Wrapping the index lets it be traced; otherwise it would be embedded as a
        ### constant in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                backend,
                loss_fn,
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the trained parameters.
    test_acc_list = zeros(Float64, 2)
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training procedure. `train()` returns the final test accuracies in
# percent, one entry per dataset.
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760386940.595062 2911630 service.cc:158] XLA service 0x309d94f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760386940.595125 2911630 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760386940.596210 2911630 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760386940.596277 2911630 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760386940.596351 2911630 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760386940.608535 2911630 cuda_dnn.cc:463] Loaded cuDNN version 91200
[  1/ 50]	       MNIST	Time 45.51114s	Training Accuracy: 35.45%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04402s	Training Accuracy: 32.81%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.04789s	Training Accuracy: 35.55%	Test Accuracy: 31.25%
[  2/ 50]	FashionMNIST	Time 0.03627s	Training Accuracy: 46.68%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.03822s	Training Accuracy: 41.70%	Test Accuracy: 40.62%
[  3/ 50]	FashionMNIST	Time 0.03935s	Training Accuracy: 53.22%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.03257s	Training Accuracy: 49.51%	Test Accuracy: 31.25%
[  4/ 50]	FashionMNIST	Time 0.03333s	Training Accuracy: 61.52%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.06480s	Training Accuracy: 55.37%	Test Accuracy: 59.38%
[  5/ 50]	FashionMNIST	Time 0.03497s	Training Accuracy: 66.41%	Test Accuracy: 62.50%
[  6/ 50]	       MNIST	Time 0.03264s	Training Accuracy: 61.43%	Test Accuracy: 56.25%
[  6/ 50]	FashionMNIST	Time 0.03122s	Training Accuracy: 69.43%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.04244s	Training Accuracy: 66.80%	Test Accuracy: 46.88%
[  7/ 50]	FashionMNIST	Time 0.03168s	Training Accuracy: 73.44%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.04922s	Training Accuracy: 72.66%	Test Accuracy: 53.12%
[  8/ 50]	FashionMNIST	Time 0.03242s	Training Accuracy: 80.57%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.03147s	Training Accuracy: 76.86%	Test Accuracy: 46.88%
[  9/ 50]	FashionMNIST	Time 0.03252s	Training Accuracy: 83.11%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.03266s	Training Accuracy: 82.81%	Test Accuracy: 59.38%
[ 10/ 50]	FashionMNIST	Time 0.04259s	Training Accuracy: 86.82%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.03197s	Training Accuracy: 87.89%	Test Accuracy: 62.50%
[ 11/ 50]	FashionMNIST	Time 0.04283s	Training Accuracy: 89.84%	Test Accuracy: 65.62%
[ 12/ 50]	       MNIST	Time 0.03215s	Training Accuracy: 87.11%	Test Accuracy: 65.62%
[ 12/ 50]	FashionMNIST	Time 0.03225s	Training Accuracy: 91.41%	Test Accuracy: 71.88%
[ 13/ 50]	       MNIST	Time 0.03201s	Training Accuracy: 91.21%	Test Accuracy: 56.25%
[ 13/ 50]	FashionMNIST	Time 0.03231s	Training Accuracy: 93.36%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.04293s	Training Accuracy: 93.46%	Test Accuracy: 50.00%
[ 14/ 50]	FashionMNIST	Time 0.03224s	Training Accuracy: 95.51%	Test Accuracy: 68.75%
[ 15/ 50]	       MNIST	Time 0.04448s	Training Accuracy: 95.41%	Test Accuracy: 53.12%
[ 15/ 50]	FashionMNIST	Time 0.03222s	Training Accuracy: 95.51%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03162s	Training Accuracy: 97.17%	Test Accuracy: 59.38%
[ 16/ 50]	FashionMNIST	Time 0.03206s	Training Accuracy: 96.39%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.03201s	Training Accuracy: 98.93%	Test Accuracy: 56.25%
[ 17/ 50]	FashionMNIST	Time 0.04189s	Training Accuracy: 97.66%	Test Accuracy: 75.00%
[ 18/ 50]	       MNIST	Time 0.03212s	Training Accuracy: 99.12%	Test Accuracy: 59.38%
[ 18/ 50]	FashionMNIST	Time 0.04147s	Training Accuracy: 96.58%	Test Accuracy: 75.00%
[ 19/ 50]	       MNIST	Time 0.03180s	Training Accuracy: 99.22%	Test Accuracy: 56.25%
[ 19/ 50]	FashionMNIST	Time 0.03175s	Training Accuracy: 98.63%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.03191s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 20/ 50]	FashionMNIST	Time 0.03190s	Training Accuracy: 98.83%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.03208s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.03207s	Training Accuracy: 99.12%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.04156s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 22/ 50]	FashionMNIST	Time 0.03175s	Training Accuracy: 99.71%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.04212s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 23/ 50]	FashionMNIST	Time 0.03173s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.03235s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 24/ 50]	FashionMNIST	Time 0.03173s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.03215s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 25/ 50]	FashionMNIST	Time 0.04164s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 26/ 50]	       MNIST	Time 0.03622s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 26/ 50]	FashionMNIST	Time 0.04840s	Training Accuracy: 99.90%	Test Accuracy: 75.00%
[ 27/ 50]	       MNIST	Time 0.03701s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 27/ 50]	FashionMNIST	Time 0.03662s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.03629s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 28/ 50]	FashionMNIST	Time 0.03663s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.04744s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 29/ 50]	FashionMNIST	Time 0.03592s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.04738s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 30/ 50]	FashionMNIST	Time 0.03661s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.03607s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 31/ 50]	FashionMNIST	Time 0.03657s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	       MNIST	Time 0.03675s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 32/ 50]	FashionMNIST	Time 0.04829s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	       MNIST	Time 0.03686s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 33/ 50]	FashionMNIST	Time 0.04756s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	       MNIST	Time 0.03623s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 34/ 50]	FashionMNIST	Time 0.03631s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	       MNIST	Time 0.03691s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 35/ 50]	FashionMNIST	Time 0.03504s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 36/ 50]	       MNIST	Time 0.04534s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 36/ 50]	FashionMNIST	Time 0.03318s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 37/ 50]	       MNIST	Time 0.04464s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 37/ 50]	FashionMNIST	Time 0.04773s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 38/ 50]	       MNIST	Time 0.02918s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 38/ 50]	FashionMNIST	Time 0.03112s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 39/ 50]	       MNIST	Time 0.03137s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 39/ 50]	FashionMNIST	Time 0.03261s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 40/ 50]	       MNIST	Time 0.03265s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 40/ 50]	FashionMNIST	Time 0.04339s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 41/ 50]	       MNIST	Time 0.03240s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 41/ 50]	FashionMNIST	Time 0.04280s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 42/ 50]	       MNIST	Time 0.03231s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 42/ 50]	FashionMNIST	Time 0.03237s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 43/ 50]	       MNIST	Time 0.03219s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 43/ 50]	FashionMNIST	Time 0.03274s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	       MNIST	Time 0.04291s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	FashionMNIST	Time 0.03238s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 45/ 50]	       MNIST	Time 0.04320s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 45/ 50]	FashionMNIST	Time 0.03244s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 46/ 50]	       MNIST	Time 0.03297s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 46/ 50]	FashionMNIST	Time 0.03286s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	       MNIST	Time 0.03270s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	FashionMNIST	Time 0.04404s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	       MNIST	Time 0.03267s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	FashionMNIST	Time 0.04342s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 49/ 50]	       MNIST	Time 0.03251s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 49/ 50]	FashionMNIST	Time 0.03255s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 50/ 50]	       MNIST	Time 0.03156s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 50/ 50]	FashionMNIST	Time 0.03187s	Training Accuracy: 100.00%	Test Accuracy: 59.38%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 59.38%

Appendix

julia
# Print the Julia version and platform details for reproducibility.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Accelerator details are printed only when `MLDataDevices` is defined in this scope.
if @isdefined(MLDataDevices)
    # `@isdefined` is checked first so that `CUDADevice` / `CUDA` are only referenced
    # when the CUDA package has been loaded.
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    # Same guard for the AMDGPU backend.
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.