Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
      Reactant
Loading Datasets
"""
    load_dataset(::Type{dset}, n_train, n_eval, batchsize) where {dset}

Load the `:train` and `:test` splits of the dataset type `dset` (e.g. `MNIST`) and wrap
them in `DataLoader`s.

# Arguments
- `dset`: dataset constructor, called as `dset(:train)` / `dset(:test)`.
- `n_train`: number of leading training samples to keep, or `nothing` to keep all of them.
- `n_eval`: number of leading test samples to keep, or `nothing` to keep all of them.
- `batchsize`: requested batch size; clamped to the number of available samples.

# Returns
A tuple `(train_loader, test_loader)`.
- Images are reshaped to `28 × 28 × 1 × N`.
- Labels are one-hot encoded over `0:9`.
- Only the training loader shuffles.
- Both loaders drop an incomplete final batch (`partial = false`).
"""
function load_dataset(
        ::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int
    ) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize = min(batchsize, size(x_train, 4)), shuffle = true, partial = false
        ),
        DataLoader(
            (x_test, y_test);
            batchsize = min(batchsize, size(x_test, 4)), shuffle = false, partial = false
        ),
    )
end

# Load `split` of `dset`, keep the first `n` samples (all when `n === nothing`, clamped to
# the split size), and return `(images, onehot_labels)` in the layout the CNN expects.
function _load_split(::Type{dset}, split::Symbol, n::Union{Nothing, Int}) where {dset}
    data = dset(split)
    nkeep = n === nothing ? length(data) : min(n, length(data))
    (; features, targets) = data[1:nkeep]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end
"""
    load_datasets(batchsize = 32)

Return `(mnist_loaders, fashion_mnist_loaders)`. Each element is the
`(train_loader, test_loader)` pair produced by `load_dataset`.

When the `CI` environment variable parses as `true`, only 1024 training and 32 test
samples per dataset are loaded, to keep CI runs short. Otherwise the full splits are used.
"""
function load_datasets(batchsize = 32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    # `nothing` broadcasts as a scalar, so both datasets get the same sample limits.
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
"""
    HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)

Build a hypernetwork layer that is called with a tuple `(x, y)`.
1. `weight_generator(x)` produces a flat vector.
2. That vector is reinterpreted with `core_network`'s parameter structure.
3. `core_network` is then evaluated on `y` with those generated parameters.

The output of `weight_generator` must have exactly `Lux.parameterlength(core_network)`
entries, so that it fits the cached `ComponentArray` axes.

The layer is tagged with `dispatch = :HyperNet`, so methods such as
`Lux.initialparameters` can be specialized for it.
"""
function HyperNet(
        weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer
    )
    # Only the *structure* of the core network's parameters is needed; the values
    # from this throwaway initialization are discarded.
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |> getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch = :HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
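Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here we write `Lux.initialparameters` so that it skips initializing the `core_network`. Its parameters are generated on every forward pass by the `weight_generator`, so only the weight generator's parameters need to be stored and trained.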
"""
    Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})

Initialize only the weight generator's parameters.

The core network's parameters are produced on every call by the weight generator, so no
trainable parameters are stored for it.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator = Lux.initialparameters(rng, hn.layers.weight_generator))
end
Create and Initialize the HyperNet
"""
    create_model()

Build the HyperNet used in this tutorial.

- Core network: a small CNN mapping `28 × 28 × 1 × N` images to `10 × N` logits. Its
  parameters are not stored; they are generated per call.
- Weight generator: embeds the dataset index (`1` or `2`) into 32 dimensions, then an
  MLP maps it to exactly `Lux.parameterlength(core_network)` values.
"""
function create_model()
    # Three unpadded stride-2 3×3 convs shrink the spatial size 28 → 13 → 6 → 2, so the
    # last conv yields a 2×2×64×N tensor. Global mean pooling plus flattening reduce it to
    # the 64×N matrix that `Dense(64, 10)` requires.
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride = 2),
        Conv((3, 3), 16 => 32, relu; stride = 2),
        Conv((3, 3), 32 => 64, relu; stride = 2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10)
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network))
        ),
        core_network
    )
end
create_model (generic function with 1 method)
Define Utility Functions
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class matches its one-hot
target.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode. The
first element of its result is treated as per-class scores. Predictions and targets are
moved to the CPU before `onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += count(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
"""
    train()

Train one `HyperNet` (from `create_model`) alternately on MNIST and FashionMNIST, using
Reactant for compilation and Enzyme for gradients. Prints per-epoch timing and accuracies.

The hypernetwork is conditioned on the dataset index (`1` = MNIST, `2` = FashionMNIST),
so a single set of weight-generator parameters serves both tasks.

# Returns
A two-element `Vector{Float64}` with the final test accuracy (percent, rounded to two
digits) for MNIST and FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force = true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once; a sample batch fixes the input shapes.
    x_sample = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, x_sample), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), CrossEntropyLoss(; logits = Val(true)),
                ((concrete_data_idx, x), y), train_state; return_gradients = Val(false)
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc

        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Accuracy of `model_compiled` on `dataloader` with the current training parameters and
# states, as a percentage rounded to two digits.
function _accuracy_percent(model_compiled, train_state, dataloader, concrete_data_idx)
    acc = accuracy(
        model_compiled, train_state.parameters,
        train_state.states, dataloader, concrete_data_idx
    )
    return round(acc * 100; digits = 2)
end

# Human-readable name for a dataset index (1 = MNIST, anything else = FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"
# Run the full training procedure; `test_acc_list` receives the final test accuracies
# (in percent) for MNIST and FashionMNIST, in that order.
test_acc_list = train()
[ 1/ 50] MNIST Time 35.91804s Training Accuracy: 35.06% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.03871s Training Accuracy: 34.18% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.04073s Training Accuracy: 34.96% Test Accuracy: 37.50%
[ 2/ 50] FashionMNIST Time 0.07132s Training Accuracy: 51.17% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.03580s Training Accuracy: 36.52% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.03345s Training Accuracy: 57.23% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.04741s Training Accuracy: 45.41% Test Accuracy: 34.38%
[ 4/ 50] FashionMNIST Time 0.03490s Training Accuracy: 64.45% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.03310s Training Accuracy: 50.78% Test Accuracy: 46.88%
[ 5/ 50] FashionMNIST Time 0.03279s Training Accuracy: 70.02% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.03607s Training Accuracy: 56.74% Test Accuracy: 40.62%
[ 6/ 50] FashionMNIST Time 0.02994s Training Accuracy: 74.02% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.03922s Training Accuracy: 62.40% Test Accuracy: 40.62%
[ 7/ 50] FashionMNIST Time 0.02892s Training Accuracy: 78.81% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.03913s Training Accuracy: 71.00% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.02928s Training Accuracy: 83.69% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.04293s Training Accuracy: 74.61% Test Accuracy: 43.75%
[ 9/ 50] FashionMNIST Time 0.03059s Training Accuracy: 83.11% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.03220s Training Accuracy: 78.61% Test Accuracy: 46.88%
[ 10/ 50] FashionMNIST Time 0.03007s Training Accuracy: 88.09% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.03057s Training Accuracy: 84.08% Test Accuracy: 37.50%
[ 11/ 50] FashionMNIST Time 0.04358s Training Accuracy: 90.33% Test Accuracy: 59.38%
[ 12/ 50] MNIST Time 0.03569s Training Accuracy: 87.11% Test Accuracy: 43.75%
[ 12/ 50] FashionMNIST Time 0.03765s Training Accuracy: 91.80% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03755s Training Accuracy: 90.23% Test Accuracy: 46.88%
[ 13/ 50] FashionMNIST Time 0.03633s Training Accuracy: 94.63% Test Accuracy: 62.50%
[ 14/ 50] MNIST Time 0.03024s Training Accuracy: 91.31% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.04041s Training Accuracy: 94.82% Test Accuracy: 59.38%
[ 15/ 50] MNIST Time 0.03308s Training Accuracy: 95.12% Test Accuracy: 53.12%
[ 15/ 50] FashionMNIST Time 0.02890s Training Accuracy: 95.90% Test Accuracy: 62.50%
[ 16/ 50] MNIST Time 0.02936s Training Accuracy: 96.68% Test Accuracy: 50.00%
[ 16/ 50] FashionMNIST Time 0.02841s Training Accuracy: 95.90% Test Accuracy: 59.38%
[ 17/ 50] MNIST Time 0.03651s Training Accuracy: 97.17% Test Accuracy: 56.25%
[ 17/ 50] FashionMNIST Time 0.03268s Training Accuracy: 97.85% Test Accuracy: 62.50%
[ 18/ 50] MNIST Time 0.03884s Training Accuracy: 98.14% Test Accuracy: 53.12%
[ 18/ 50] FashionMNIST Time 0.02806s Training Accuracy: 97.66% Test Accuracy: 62.50%
[ 19/ 50] MNIST Time 0.03930s Training Accuracy: 99.22% Test Accuracy: 56.25%
[ 19/ 50] FashionMNIST Time 0.02913s Training Accuracy: 99.32% Test Accuracy: 59.38%
[ 20/ 50] MNIST Time 0.03430s Training Accuracy: 99.71% Test Accuracy: 53.12%
[ 20/ 50] FashionMNIST Time 0.03048s Training Accuracy: 99.32% Test Accuracy: 59.38%
[ 21/ 50] MNIST Time 0.03063s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 21/ 50] FashionMNIST Time 0.02979s Training Accuracy: 99.51% Test Accuracy: 59.38%
[ 22/ 50] MNIST Time 0.03091s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 22/ 50] FashionMNIST Time 0.02789s Training Accuracy: 99.51% Test Accuracy: 59.38%
[ 23/ 50] MNIST Time 0.03142s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 23/ 50] FashionMNIST Time 0.03848s Training Accuracy: 99.61% Test Accuracy: 62.50%
[ 24/ 50] MNIST Time 0.03005s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 24/ 50] FashionMNIST Time 0.03746s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] MNIST Time 0.03042s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 25/ 50] FashionMNIST Time 0.03885s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] MNIST Time 0.03081s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 26/ 50] FashionMNIST Time 0.03063s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] MNIST Time 0.03073s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 27/ 50] FashionMNIST Time 0.02767s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] MNIST Time 0.02941s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 28/ 50] FashionMNIST Time 0.02802s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 29/ 50] MNIST Time 0.03935s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 29/ 50] FashionMNIST Time 0.02998s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] MNIST Time 0.03779s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 30/ 50] FashionMNIST Time 0.02995s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] MNIST Time 0.03823s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 31/ 50] FashionMNIST Time 0.03050s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] MNIST Time 0.03040s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 32/ 50] FashionMNIST Time 0.03044s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] MNIST Time 0.03360s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 33/ 50] FashionMNIST Time 0.02847s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] MNIST Time 0.02868s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 34/ 50] FashionMNIST Time 0.03713s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] MNIST Time 0.02848s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 35/ 50] FashionMNIST Time 0.04032s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] MNIST Time 0.02940s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 36/ 50] FashionMNIST Time 0.03713s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] MNIST Time 0.02845s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 37/ 50] FashionMNIST Time 0.03695s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] MNIST Time 0.02979s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 38/ 50] FashionMNIST Time 0.02973s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] MNIST Time 0.02844s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 39/ 50] FashionMNIST Time 0.02848s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] MNIST Time 0.02931s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 40/ 50] FashionMNIST Time 0.02836s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] MNIST Time 0.03853s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 41/ 50] FashionMNIST Time 0.02855s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] MNIST Time 0.03617s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 42/ 50] FashionMNIST Time 0.02867s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] MNIST Time 0.03773s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 43/ 50] FashionMNIST Time 0.02900s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] MNIST Time 0.02822s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 44/ 50] FashionMNIST Time 0.02832s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] MNIST Time 0.02983s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 45/ 50] FashionMNIST Time 0.02831s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] MNIST Time 0.02834s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 46/ 50] FashionMNIST Time 0.03677s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] MNIST Time 0.02830s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 47/ 50] FashionMNIST Time 0.03632s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] MNIST Time 0.02826s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 48/ 50] FashionMNIST Time 0.03593s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] MNIST Time 0.02867s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 49/ 50] FashionMNIST Time 0.02949s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] MNIST Time 0.02858s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 50/ 50] FashionMNIST Time 0.02897s Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 56.25%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
# Print the Julia/system environment used to run this tutorial. If a functional CUDA or
# AMDGPU device is available, also print that backend's details.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_DEBUG = Literate
