Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type; its splits are obtained with `dset(:train)` and `dset(:test)`.
- `n_train`, `n_eval`: number of leading samples to keep from the train / test split.
  `nothing` keeps every sample of the split.
- `batchsize`: requested batch size; it is capped at the number of samples in the split.

# Returns
A 2-tuple of `DataLoader`s over `(features, one_hot_targets)`. Features are reshaped to
`28 × 28 × 1 × N` and targets are one-hot encoded over the labels `0:9`. The training
loader is created with `shuffle=true`, the test loader with `shuffle=false`; both are
created with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Read the first `n_samples` observations of `split` (all of them when `n_samples`
    # is `nothing`) and convert them to the array layout used by the model.
    function load_split(split::Symbol, n_samples)
        data = dset(split)
        last_idx = n_samples === nothing ? length(data) : n_samples
        (; features, targets) = data[1:last_idx]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Wrap one split in a DataLoader; the batch size never exceeds the sample count.
    function make_loader(x, y; shuffle::Bool)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle=shuffle, partial=false
        )
    end

    x_train, y_train = load_split(:train, n_train)
    x_test, y_test = load_split(:test, n_eval)
    return (
        make_loader(x_train, y_train; shuffle=true),
        make_loader(x_test, y_test; shuffle=false),
    )
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST through `load_dataset` and return one
`(train_loader, test_loader)` pair per dataset, in that order.

When the `CI` environment variable parses to `true`, only 1024 training samples and 32
evaluation samples are requested per dataset; otherwise the full splits are used.
"""
function load_datasets(batchsize=32)
    running_on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_on_ci ? (1024, 32) : (nothing, nothing)
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end

Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Create a compact layer (dispatch tag `:HyperNet`) that takes a tuple `(x, y)` as input.
`weight_generator` is applied to `x`; its output is flattened and given the parameter
structure of `core_network`, which is then evaluated on `y` with those parameters.

The parameter structure (`ca_axes`) is taken from a `ComponentArray` built out of
`core_network`'s initial parameters, generated here with `Random.default_rng()`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes of these parameters are kept; they describe how a flat vector maps
    # onto the parameters of `core_network`.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights as a flat vector and give them the core network's layout
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
# Parameter initialisation for the `:HyperNet` compact layer: only the weight generator
# gets parameters; no entry is created for the core network.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet
julia
"""
    create_model()

Build the `HyperNet` used in this tutorial.

The core network is a three-layer strided convolutional classifier with 10 outputs. The
weight generator embeds one of 2 indices into 32 dimensions and maps it through two dense
layers to a vector with `Lux.parameterlength(core_network)` entries.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit exactly one value per parameter of the core network.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the class predicted by `model`
matches the target class.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its result is used as the prediction. Predictions and targets are
moved to the CPU device before being decoded with `onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        ŷ = first(model((data_idx, x), ps, st))
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(ŷ))
        n_correct += sum(target_class .== predicted_class)
        n_seen += length(target_class)
    end
    return n_correct / n_seen
end

Training
julia
"""
    train()

Train the hypernetwork on MNIST and FashionMNIST and return the final test accuracies.

For each of 50 epochs the model is trained for one pass over the MNIST training loader and
then one pass over the FashionMNIST training loader. After every pass the training time
and the train / test accuracy (in percent, rounded to 2 digits) are printed. After
training, a `[FINAL]` line is printed per dataset.

# Returns
A 2-element vector holding the final test accuracy (percent) for MNIST and FashionMNIST,
in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using the first MNIST training batch and
    # a dataset index as example inputs.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    # Accuracy of the compiled model on `loader`, in percent rounded to two digits.
    function percent_accuracy(tstate, loader, idx)
        frac = accuracy(model_compiled, tstate.parameters, tstate.states, loader, idx)
        return round(frac * 100; digits=2)
    end

    dataset_names = ("MNIST", "FashionMNIST")

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for data_idx in 1:2
            train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

            ### This allows us to trace the data index, else it will be embedded as a
            ### constant in the IR
            concrete_data_idx = ConcreteRNumber(data_idx)

            stime = time()
            for (x, y) in train_dataloader
                (_, _, _, train_state) = Training.single_train_step!(
                    AutoEnzyme(),
                    CrossEntropyLoss(; logits=Val(true)),
                    ((concrete_data_idx, x), y),
                    train_state;
                    return_gradients=Val(false),
                )
            end
            ttime = time() - stime

            train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
            test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

            data_name = dataset_names[data_idx]

            @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                     Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
        end
    end

    println()

    # Final evaluation of both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()

[ 1/ 50] MNIST Time 47.84256s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.09449s Training Accuracy: 32.62% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.08968s Training Accuracy: 36.72% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.08966s Training Accuracy: 45.61% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.08574s Training Accuracy: 41.02% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.08902s Training Accuracy: 57.03% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.09090s Training Accuracy: 52.25% Test Accuracy: 40.62%
[ 4/ 50] FashionMNIST Time 0.08786s Training Accuracy: 63.96% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.10203s Training Accuracy: 57.32% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.08967s Training Accuracy: 70.51% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.08777s Training Accuracy: 63.67% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.08844s Training Accuracy: 75.49% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.09791s Training Accuracy: 71.00% Test Accuracy: 34.38%
[ 7/ 50] FashionMNIST Time 0.08679s Training Accuracy: 76.66% Test Accuracy: 53.12%
[ 8/ 50] MNIST Time 0.08752s Training Accuracy: 75.68% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.08736s Training Accuracy: 81.84% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.08481s Training Accuracy: 80.08% Test Accuracy: 43.75%
[ 9/ 50] FashionMNIST Time 0.08862s Training Accuracy: 83.89% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.08804s Training Accuracy: 84.96% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.08839s Training Accuracy: 88.57% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.08976s Training Accuracy: 89.45% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.08641s Training Accuracy: 90.04% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.08718s Training Accuracy: 91.11% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.09024s Training Accuracy: 91.11% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.08832s Training Accuracy: 93.46% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.09509s Training Accuracy: 91.89% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.08765s Training Accuracy: 95.51% Test Accuracy: 53.12%
[ 14/ 50] FashionMNIST Time 0.08664s Training Accuracy: 94.53% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.09742s Training Accuracy: 96.58% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.08635s Training Accuracy: 93.85% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.08818s Training Accuracy: 98.73% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.09615s Training Accuracy: 96.29% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.08826s Training Accuracy: 99.12% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.08863s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.08764s Training Accuracy: 99.51% Test Accuracy: 56.25%
[ 18/ 50] FashionMNIST Time 0.08746s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.08800s Training Accuracy: 99.80% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.08738s Training Accuracy: 99.32% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.08873s Training Accuracy: 99.80% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.08971s Training Accuracy: 98.73% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.08948s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.08704s Training Accuracy: 98.73% Test Accuracy: 65.62%
[ 22/ 50] MNIST Time 0.08582s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.08711s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.08903s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.08892s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 24/ 50] MNIST Time 0.08859s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.09520s Training Accuracy: 99.51% Test Accuracy: 65.62%
[ 25/ 50] MNIST Time 0.08627s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.09295s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.09970s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.08892s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] MNIST Time 0.08762s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.09563s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] MNIST Time 0.08651s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.08887s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.09012s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.08719s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.08743s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.09277s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.08967s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.09122s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] MNIST Time 0.09112s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.08939s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.09186s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.08934s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.09649s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.08883s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.08876s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.09569s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.08617s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] FashionMNIST Time 0.09046s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.10282s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] FashionMNIST Time 0.08810s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.08619s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.08784s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.09203s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.09036s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.08479s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.08526s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.08758s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] FashionMNIST Time 0.08710s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.08910s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.08647s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.08971s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.08556s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.08723s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.08938s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.09859s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.08703s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] MNIST Time 0.08748s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.09853s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.08802s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.08556s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.09696s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.08997s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.08769s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.08965s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.08959s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.08726s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%

Appendix
julia
using InteractiveUtils

# Print the Julia version / platform report for this session.
InteractiveUtils.versioninfo()

# Additionally print GPU backend version info, but only when MLDataDevices and the
# backend package are defined in this session and the backend reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end

Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.