Skip to content
0.6k

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
# Load one split (`:train` or `:test`) of `dset`, keeping the first `n` samples (or every
# sample when `n === nothing`), and wrap it in a `DataLoader` of `(images, onehot_labels)`.
function _load_split(
    ::Type{dset}, split::Symbol, n::Union{Nothing,Int}, batchsize::Int, shuffle::Bool
) where {dset}
    data = dset(split)
    n_samples = n === nothing ? length(data) : n
    (; features, targets) = data[1:n_samples]
    # Add a singleton channel dimension so images become 28×28×1×N.
    x = reshape(features, 28, 28, 1, :)
    y = onehotbatch(targets, 0:9)
    return DataLoader(
        (x, y);
        # Clamp the batch size so that, with `partial=false`, a subset smaller than
        # `batchsize` still yields one full batch instead of none.
        batchsize=min(batchsize, size(x, 4)),
        shuffle=shuffle,
        partial=false,
    )
end

"""
    load_dataset(dset, n_train, n_eval, batchsize)

Return `(train_loader, test_loader)` for the dataset type `dset` (e.g. `MNIST` or
`FashionMNIST`). `n_train` / `n_eval` restrict the train / test split to its first `n`
samples; `nothing` keeps the whole split. Labels are one-hot encoded over the classes
`0:9`. The training loader shuffles and the test loader does not. Incomplete final batches
are dropped.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    train_loader = _load_split(dset, :train, n_train, batchsize, true)
    test_loader = _load_split(dset, :test, n_eval, batchsize, false)
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize=32)

Return the `(train_loader, test_loader)` pairs for MNIST and FashionMNIST, in that order.
When the `CI` environment variable parses as `true`, only the first 1024 training and 32
test samples of each dataset are loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    return map((MNIST, FashionMNIST)) do dset
        load_dataset(dset, n_train, n_eval, batchsize)
    end
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer. When called on a tuple `(x, y)`, it:

1. runs `weight_generator` on `x`;
2. flattens the result with `vec`;
3. wraps it in a `ComponentArray` with the same axes as `core_network`'s parameters;
4. applies `core_network` to `y` using those generated parameters.

The layer is tagged with `dispatch=:HyperNet`, so methods can target it through
`CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Initialize `core_network` once only to record its parameter layout (axes); the
    # values are discarded. This uses the global default RNG, so it may advance its state.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        # NOTE(review): assumes `weight_generator(x)` yields exactly
        # `Lux.parameterlength(core_network)` values — confirm when swapping generators.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to skip initializing the `core_network` parameters, since those are generated at runtime by the `weight_generator`.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    # Only the weight generator owns trainable parameters. The core network's parameters
    # are produced by the generator on every forward pass, so they are deliberately not
    # initialized here.
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial. It pairs a small strided CNN classifier (10
outputs) with an MLP that maps a dataset index to a flat vector containing every parameter
of that CNN.
"""
function create_model()
    # For 28×28 inputs the unpadded stride-2 convolutions shrink the spatial size
    # 28 → 13 → 6 → 2 before global pooling.
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # One embedding row per dataset index (two datasets), expanded to the full set of
    # core-network parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class matches the target.
The prediction is `onecold` of the model output and the target is `onecold` of the one-hot
label. The model runs in test mode on inputs `(data_idx, x)`, and classes are compared on
the CPU.

NOTE: if `dataloader` yields no batches the result is `0 / 0`, i.e. `NaN`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_eval = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (xb, yb) in dataloader
        labels = onecold(to_cpu(yb))
        logits = first(model((data_idx, xb), ps, st_eval))
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
# Compute `(train_acc, test_acc)` for one dataset as percentages rounded to two decimals,
# evaluating `model` with the parameters and states currently held in `train_state`.
function _evaluate_accuracies(
    model, train_state, train_dataloader, test_dataloader, data_idx
)
    percent_correct(loader) = round(
        accuracy(
            model, train_state.parameters, train_state.states, loader, data_idx
        ) * 100;
        digits=2,
    )
    return percent_correct(train_dataloader), percent_correct(test_dataloader)
end

"""
    train(; nepochs=50)

Train the HyperNet on MNIST and FashionMNIST with Reactant, using Adam (learning rate
`3f-4`) and logit cross-entropy. Every epoch trains on MNIST and then on FashionMNIST.
After each pass, the elapsed time and train/test accuracy are printed. The first epoch's
time likely includes compilation.

Returns a `Vector{Float64}` of final test accuracies (percent, rounded to two decimals)
ordered as MNIST, FashionMNIST.
"""
function train(; nepochs::Integer=50)
    dev = reactant_device(; force=true)
    data_names = ("MNIST", "FashionMNIST")

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference pass once from a sample batch and a traced dataset index;
    # the compiled function is reused for every accuracy evaluation below.
    x_sample = first(first(dataloaders[1][1]))
    idx_sample = ConcreteRNumber(1)
    model_compiled = @compile model((idx_sample, x_sample), ps, Lux.testmode(st))

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in eachindex(data_names)
        # NOTE(review): `dataloaders` already went through `dev` above; this second
        # `dev` call may be redundant — confirm before removing.
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc, test_acc = _evaluate_accuracies(
            model_compiled, train_state, train_dataloader, test_dataloader,
            concrete_data_idx,
        )
        data_name = data_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = zeros(length(data_names))
    for data_idx in eachindex(data_names)
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc, test_acc = _evaluate_accuracies(
            model_compiled, train_state, train_dataloader, test_dataloader,
            concrete_data_idx,
        )
        data_name = data_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training loop; `train` returns the final test accuracies (in percent) for
# MNIST and FashionMNIST, in that order.
test_acc_list = train()
2025-05-23 22:16:49.242547: I external/xla/xla/service/service.cc:152] XLA service 0xb2dd090 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-23 22:16:49.242592: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748038609.243417  691513 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1748038609.243494  691513 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 0 for BFCAllocator.
I0000 00:00:1748038609.243539  691513 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1748038609.255071  691513 cuda_dnn.cc:529] Loaded cuDNN version 90400
[  1/ 50]	       MNIST	Time 32.08024s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.03595s	Training Accuracy: 32.52%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.02783s	Training Accuracy: 36.33%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.02969s	Training Accuracy: 46.09%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.08189s	Training Accuracy: 41.41%	Test Accuracy: 31.25%
[  3/ 50]	FashionMNIST	Time 0.02853s	Training Accuracy: 57.32%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.02572s	Training Accuracy: 51.76%	Test Accuracy: 37.50%
[  4/ 50]	FashionMNIST	Time 0.02870s	Training Accuracy: 64.26%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.02710s	Training Accuracy: 58.01%	Test Accuracy: 37.50%
[  5/ 50]	FashionMNIST	Time 0.02782s	Training Accuracy: 70.90%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.02819s	Training Accuracy: 62.50%	Test Accuracy: 34.38%
[  6/ 50]	FashionMNIST	Time 0.04260s	Training Accuracy: 75.59%	Test Accuracy: 53.12%
[  7/ 50]	       MNIST	Time 0.02871s	Training Accuracy: 68.65%	Test Accuracy: 37.50%
[  7/ 50]	FashionMNIST	Time 0.02890s	Training Accuracy: 76.17%	Test Accuracy: 53.12%
[  8/ 50]	       MNIST	Time 0.02827s	Training Accuracy: 75.59%	Test Accuracy: 40.62%
[  8/ 50]	FashionMNIST	Time 0.02826s	Training Accuracy: 81.15%	Test Accuracy: 68.75%
[  9/ 50]	       MNIST	Time 0.02862s	Training Accuracy: 79.49%	Test Accuracy: 46.88%
[  9/ 50]	FashionMNIST	Time 0.03942s	Training Accuracy: 82.62%	Test Accuracy: 65.62%
[ 10/ 50]	       MNIST	Time 0.02774s	Training Accuracy: 82.23%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.03785s	Training Accuracy: 87.01%	Test Accuracy: 53.12%
[ 11/ 50]	       MNIST	Time 0.02743s	Training Accuracy: 87.99%	Test Accuracy: 56.25%
[ 11/ 50]	FashionMNIST	Time 0.02790s	Training Accuracy: 89.84%	Test Accuracy: 65.62%
[ 12/ 50]	       MNIST	Time 0.02756s	Training Accuracy: 90.23%	Test Accuracy: 53.12%
[ 12/ 50]	FashionMNIST	Time 0.02814s	Training Accuracy: 90.92%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.03810s	Training Accuracy: 94.43%	Test Accuracy: 62.50%
[ 13/ 50]	FashionMNIST	Time 0.02821s	Training Accuracy: 92.58%	Test Accuracy: 71.88%
[ 14/ 50]	       MNIST	Time 0.03819s	Training Accuracy: 94.53%	Test Accuracy: 68.75%
[ 14/ 50]	FashionMNIST	Time 0.02572s	Training Accuracy: 94.43%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.02406s	Training Accuracy: 95.90%	Test Accuracy: 68.75%
[ 15/ 50]	FashionMNIST	Time 0.03694s	Training Accuracy: 94.53%	Test Accuracy: 71.88%
[ 16/ 50]	       MNIST	Time 0.02931s	Training Accuracy: 98.34%	Test Accuracy: 65.62%
[ 16/ 50]	FashionMNIST	Time 0.03475s	Training Accuracy: 94.73%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.02785s	Training Accuracy: 99.51%	Test Accuracy: 68.75%
[ 17/ 50]	FashionMNIST	Time 0.02741s	Training Accuracy: 97.27%	Test Accuracy: 75.00%
[ 18/ 50]	       MNIST	Time 0.02832s	Training Accuracy: 99.41%	Test Accuracy: 65.62%
[ 18/ 50]	FashionMNIST	Time 0.02587s	Training Accuracy: 96.68%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03045s	Training Accuracy: 99.80%	Test Accuracy: 62.50%
[ 19/ 50]	FashionMNIST	Time 0.02800s	Training Accuracy: 98.34%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.04251s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 20/ 50]	FashionMNIST	Time 0.02937s	Training Accuracy: 98.34%	Test Accuracy: 75.00%
[ 21/ 50]	       MNIST	Time 0.02413s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 21/ 50]	FashionMNIST	Time 0.02610s	Training Accuracy: 98.34%	Test Accuracy: 71.88%
[ 22/ 50]	       MNIST	Time 0.03518s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 22/ 50]	FashionMNIST	Time 0.03803s	Training Accuracy: 99.51%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.02993s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 23/ 50]	FashionMNIST	Time 0.03643s	Training Accuracy: 99.41%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.02417s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.02424s	Training Accuracy: 99.61%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.03291s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 25/ 50]	FashionMNIST	Time 0.02487s	Training Accuracy: 99.80%	Test Accuracy: 75.00%
[ 26/ 50]	       MNIST	Time 0.03400s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	FashionMNIST	Time 0.02359s	Training Accuracy: 99.90%	Test Accuracy: 75.00%
[ 27/ 50]	       MNIST	Time 0.02423s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 27/ 50]	FashionMNIST	Time 0.02527s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.02535s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	FashionMNIST	Time 0.03752s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.02791s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	FashionMNIST	Time 0.03506s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 30/ 50]	       MNIST	Time 0.02722s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	FashionMNIST	Time 0.02799s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 31/ 50]	       MNIST	Time 0.02788s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	FashionMNIST	Time 0.02829s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 32/ 50]	       MNIST	Time 0.03607s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	FashionMNIST	Time 0.02770s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 33/ 50]	       MNIST	Time 0.02767s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	FashionMNIST	Time 0.02776s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.02729s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	FashionMNIST	Time 0.03773s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 35/ 50]	       MNIST	Time 0.02755s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	FashionMNIST	Time 0.03549s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 36/ 50]	       MNIST	Time 0.02509s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.02437s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.02467s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	FashionMNIST	Time 0.02402s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.03380s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	FashionMNIST	Time 0.02363s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 39/ 50]	       MNIST	Time 0.03350s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	FashionMNIST	Time 0.02377s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 40/ 50]	       MNIST	Time 0.02312s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	FashionMNIST	Time 0.02884s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.02530s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	FashionMNIST	Time 0.03879s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.02852s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.02867s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 43/ 50]	       MNIST	Time 0.02872s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.02898s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 44/ 50]	       MNIST	Time 0.03926s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 44/ 50]	FashionMNIST	Time 0.02440s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 45/ 50]	       MNIST	Time 0.03108s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.02536s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 46/ 50]	       MNIST	Time 0.02559s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.02612s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 47/ 50]	       MNIST	Time 0.02663s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	FashionMNIST	Time 0.03483s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 48/ 50]	       MNIST	Time 0.02572s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	FashionMNIST	Time 0.03517s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 49/ 50]	       MNIST	Time 0.02648s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.02635s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 50/ 50]	       MNIST	Time 0.03640s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.02595s	Training Accuracy: 100.00%	Test Accuracy: 75.00%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 75.00%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print accelerator details only when MLDataDevices is loaded and the matching backend
# package is both loaded and functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.