Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the dataset type `dset` (e.g. `MNIST`) and wrap
them in `DataLoader`s.

# Arguments
- `dset`: dataset type, constructed as `dset(:train)` / `dset(:test)`.
- `n_train`, `n_eval`: number of leading observations to keep from the train / test
  split, or `nothing` to keep the whole split. Values larger than the split size are
  clamped to it.
- `batchsize`: requested batch size; capped at the number of kept observations.

# Returns
A tuple `(train_loader, test_loader)`. Features are reshaped to `28 × 28 × 1 × N` and
targets are one-hot encoded over the labels `0:9`. The train loader shuffles, the test
loader does not. Both use `partial=false`, so a trailing incomplete batch is dropped.
This is presumably so that every batch matches the fixed shape the compiled model is
specialized on.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)
    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

# Load split `split` of `dset`, keep its first `n` observations (all of them when
# `n === nothing`, otherwise at most the split size), and return
# `(features reshaped to 28×28×1×N, one-hot targets over 0:9)`.
function _load_split(::Type{dset}, split::Symbol, n::Union{Nothing,Int}) where {dset}
    data = dset(split)
    n_keep = n === nothing ? length(data) : min(n, length(data))
    (; features, targets) = data[1:n_keep]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end
"""
    load_datasets(batchsize=32)

Return `(mnist_loaders, fashion_mnist_loaders)`, where each element is the
`(train_loader, test_loader)` pair produced by `load_dataset`. When running under CI
(see `_is_ci`), only the first 1024 training and 32 evaluation observations are kept,
to keep the run short.
"""
function load_datasets(batchsize=32)
    is_ci = _is_ci()
    n_train = is_ci ? 1024 : nothing
    n_eval = is_ci ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

"""
    _is_ci()

Return `true` when the `CI` environment variable holds a truthy value (`true` in any
letter case, or `1`). An unset, empty or unrecognized value yields `false` instead of
throwing, as `parse(Bool, ...)` would for e.g. `CI=True`.
"""
function _is_ci()
    flag = lowercase(strip(get(ENV, "CI", "false")))
    return something(tryparse(Bool, flag), false)
end
# Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a compact layer that takes an input `(x, y)`. It runs `weight_generator` on `x`,
reinterprets the flattened output as a `ComponentArray` with the parameter layout of
`core_network`, and returns `core_network` applied to `y` with those generated
parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the layout (axes) of the core network's parameters is kept; the sampled
    # values are thrown away.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))

    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the core network's weights from `x`
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end
# Defining methods on a `CompactLuxLayer` requires understanding how the layer is
# structured internally, so we don't recommend it unless you are familiar with those
# internals. Here we only override `initialparameters` so that no parameters are
# initialized for `core_network`; its weights are produced by the weight generator
# at call time.
julia
# Initialize only the weight generator's parameters for a `:HyperNet` compact layer;
# the core network's parameters are generated in the forward pass, so none are
# allocated for it here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end
# Create and Initialize the HyperNet
julia
"""
    create_model()

Return the tutorial's `HyperNet`. The core network is a small strided CNN classifier
with 10 outputs. Its parameters are produced by a generator that embeds a dataset
index (an embedding table with 2 entries, one per dataset) and maps it through two
`Dense` layers to one value per core-network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's last layer must emit exactly one value per core parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
# Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class matches the
target class. Both classes are taken with `onecold` after moving the arrays to the CPU.
`model` is called as `model((data_idx, x), ps, st)`, with `st` switched to test mode.
Yields `NaN` (from `0 / 0`) if `dataloader` produces no samples.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    n_correct = 0
    n_seen = 0
    to_cpu = cpu_device()
    eval_states = Lux.testmode(st)
    for batch in dataloader
        x, y = batch
        logits, _ = model((data_idx, x), ps, eval_states)
        labels = onecold(to_cpu(y))
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
# Training
julia
# Accuracy of `model` over `dataloader`, using the parameters and states currently held
# in `train_state`, expressed in percent and rounded to 2 digits.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

# Display name of the dataset selected by `data_idx` (1 => MNIST, otherwise FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train()

Train the HyperNet on MNIST and FashionMNIST for 50 epochs, alternating between the two
datasets within each epoch. After each dataset pass it prints the pass's wall time and
the train/test accuracy; at the end it prints the final accuracies.

Returns `[mnist_test_acc, fashion_mnist_test_acc]` in percent, rounded to 2 digits.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once from a sample batch. The dataset index is a traced
    # `ConcreteRNumber`, so the same compiled model is reused for both datasets.
    x_sample = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, x_sample), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop; `train` returns the final test accuracy (in percent) for
# MNIST and FashionMNIST, in that order. The per-epoch log it printed follows.
test_acc_list = train()
[ 1/ 50] MNIST Time 52.14702s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.09470s Training Accuracy: 32.52% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.09591s Training Accuracy: 36.33% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.09175s Training Accuracy: 46.19% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.09947s Training Accuracy: 42.77% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.09327s Training Accuracy: 56.74% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.09146s Training Accuracy: 51.17% Test Accuracy: 37.50%
[ 4/ 50] FashionMNIST Time 0.08886s Training Accuracy: 63.48% Test Accuracy: 62.50%
[ 5/ 50] MNIST Time 0.09005s Training Accuracy: 55.96% Test Accuracy: 43.75%
[ 5/ 50] FashionMNIST Time 0.08774s Training Accuracy: 71.19% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.08640s Training Accuracy: 63.09% Test Accuracy: 34.38%
[ 6/ 50] FashionMNIST Time 0.09766s Training Accuracy: 74.32% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.08802s Training Accuracy: 67.97% Test Accuracy: 50.00%
[ 7/ 50] FashionMNIST Time 0.08768s Training Accuracy: 76.37% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.09663s Training Accuracy: 74.80% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.08709s Training Accuracy: 81.93% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.08843s Training Accuracy: 81.05% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.09636s Training Accuracy: 85.16% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.08886s Training Accuracy: 82.42% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.08932s Training Accuracy: 87.11% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.10146s Training Accuracy: 87.30% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.09334s Training Accuracy: 89.45% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.08959s Training Accuracy: 89.75% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.10023s Training Accuracy: 90.82% Test Accuracy: 71.88%
[ 13/ 50] MNIST Time 0.09370s Training Accuracy: 93.85% Test Accuracy: 59.38%
[ 13/ 50] FashionMNIST Time 0.09390s Training Accuracy: 94.34% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.09533s Training Accuracy: 93.95% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.08964s Training Accuracy: 95.51% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.08720s Training Accuracy: 95.61% Test Accuracy: 62.50%
[ 15/ 50] FashionMNIST Time 0.08979s Training Accuracy: 94.73% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.08736s Training Accuracy: 97.95% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.08851s Training Accuracy: 96.48% Test Accuracy: 75.00%
[ 17/ 50] MNIST Time 0.09153s Training Accuracy: 99.32% Test Accuracy: 62.50%
[ 17/ 50] FashionMNIST Time 0.09371s Training Accuracy: 97.66% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.09028s Training Accuracy: 99.51% Test Accuracy: 59.38%
[ 18/ 50] FashionMNIST Time 0.09046s Training Accuracy: 97.46% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.09221s Training Accuracy: 99.61% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.09063s Training Accuracy: 98.34% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.09383s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 20/ 50] FashionMNIST Time 0.09382s Training Accuracy: 98.54% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.10026s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.09206s Training Accuracy: 99.22% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.09153s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.09799s Training Accuracy: 99.32% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.09689s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.08956s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.09785s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.08759s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.08893s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.10239s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.09288s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.08850s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.09575s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.08667s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 28/ 50] MNIST Time 0.08701s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.09736s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 29/ 50] MNIST Time 0.08862s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.09008s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.08951s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.09020s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 31/ 50] MNIST Time 0.08647s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.08745s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.08798s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.08694s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.08676s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.09037s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.08792s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.08892s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.08873s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.08914s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.08768s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.08973s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.09986s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.09130s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 38/ 50] MNIST Time 0.09234s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.09978s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.08899s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.09000s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.09923s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.09245s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.09214s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.10876s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 42/ 50] MNIST Time 0.09406s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.09157s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.11084s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.08797s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.08799s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.09904s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 45/ 50] MNIST Time 0.08659s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.09036s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 46/ 50] MNIST Time 0.09229s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.09210s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 47/ 50] MNIST Time 0.09108s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.09164s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.09003s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.08685s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 49/ 50] MNIST Time 0.08910s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.08759s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.08768s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 50/ 50] FashionMNIST Time 0.08916s Training Accuracy: 100.00% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 75.00%
Appendix
julia
# Print Julia's version information, then CUDA / AMDGPU version information when those
# backends are usable.
using InteractiveUtils
InteractiveUtils.versioninfo()
# The `@isdefined` guards skip each section when the corresponding module name is not
# defined in this scope (presumably: the package was never loaded).
if @isdefined(MLDataDevices)
# Only report a GPU backend that MLDataDevices considers functional.
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.