Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type. It is called as `dset(:train)` and `dset(:test)`, and the result
  is indexed with a range to obtain a value with `features` and `targets` properties.
- `n_train`, `n_eval`: number of leading samples to keep from the respective split, or
  `nothing` to keep the whole split.
- `batchsize`: requested batch size; each loader clamps it to the number of samples in
  its split.

Features are reshaped to `28 × 28 × 1 × N` and targets are one-hot encoded over the
labels `0:9`. Only the training loader shuffles; both loaders use `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Return the first `n` samples of `split`, or every sample when `n` is `nothing`.
    function take_samples(split::Symbol, n)
        data = dset(split)
        stop = n === nothing ? length(data) : n
        return data[1:stop]
    end

    train_data = take_samples(:train, n_train)
    x_train = reshape(train_data.features, 28, 28, 1, :)
    y_train = onehotbatch(train_data.targets, 0:9)

    test_data = take_samples(:test, n_eval)
    x_test = reshape(test_data.features, 28, 28, 1, :)
    y_test = onehotbatch(test_data.targets, 0:9)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST through `load_dataset` and return one
`(train_loader, test_loader)` pair per dataset, MNIST first.

When the `CI` environment variable parses to `true`, only the first 1024 training
samples and the first 32 evaluation samples are used; otherwise the full splits are
loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Combine two Lux layers into a single `@compact` layer (dispatch tag `:HyperNet`).

The resulting layer takes a tuple `(x, y)`. The output of `weight_generator(x)` is
flattened and wrapped in a `ComponentArray` that reuses the axes of `core_network`'s
initial parameters; `core_network` is then applied to `y` with those parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (the layout) of the core network's parameters are kept here.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights and give them the core network's parameter layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
# Initialize parameters for the weight generator only; no parameters are created for
# the core network of a `:HyperNet` compact layer.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    return (; weight_generator=Lux.initialparameters(rng, generator))
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network is a three-layer strided convolutional classifier with 10 outputs.
The weight generator embeds one of 2 indices and maps it through two dense layers to
a vector with `Lux.parameterlength(core_network)` entries.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    # The generator must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )
    return HyperNet(weight_generator, core_network)
end
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode.
Predictions and one-hot targets are moved to the CPU device before being decoded
with `onecold` and compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st_test = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        ŷ = first(model((data_idx, x), ps, st_test))
        labels = onecold(cdev(y))
        guesses = onecold(cdev(ŷ))
        n_correct += count(labels .== guesses)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
Training
julia
"""
    train()

Train the hypernetwork on MNIST and FashionMNIST and return the final test accuracies.

The model, data loaders, parameters and states are moved to a Reactant device. The
forward pass is compiled once (with test-mode states) and reused for every accuracy
evaluation. Training runs for 50 epochs; within each epoch one pass is made over the
MNIST loader (index 1) and then one over the FashionMNIST loader (index 2). Timing and
accuracies are printed after every pass, followed by a final summary.

Returns a 2-element vector of test accuracies in percent, ordered
`[MNIST, FashionMNIST]`.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the forward pass once, using one sample batch to fix the input shapes.
    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

    # Accuracy of the compiled model on `dataloader`, in percent with two decimals.
    function accuracy_percent(ts, dataloader, idx)
        return round(
            accuracy(model_compiled, ts.parameters, ts.states, dataloader, idx) * 100;
            digits=2,
        )
    end

    # Display name of the dataset stored at position `idx` of `dataloaders`.
    dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = accuracy_percent(train_state, train_dataloader, concrete_data_idx)
        test_acc = accuracy_percent(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = accuracy_percent(train_state, train_dataloader, concrete_data_idx)
        test_acc = accuracy_percent(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 67.84203s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.14110s Training Accuracy: 32.62% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.09295s Training Accuracy: 36.72% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.09253s Training Accuracy: 45.61% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.09559s Training Accuracy: 41.02% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.08765s Training Accuracy: 57.03% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.08834s Training Accuracy: 52.25% Test Accuracy: 40.62%
[ 4/ 50] FashionMNIST Time 0.08415s Training Accuracy: 63.96% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.09107s Training Accuracy: 57.32% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.08962s Training Accuracy: 70.51% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.09344s Training Accuracy: 63.67% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.09053s Training Accuracy: 75.49% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.08993s Training Accuracy: 71.00% Test Accuracy: 34.38%
[ 7/ 50] FashionMNIST Time 0.08954s Training Accuracy: 76.66% Test Accuracy: 53.12%
[ 8/ 50] MNIST Time 0.08844s Training Accuracy: 75.68% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.09535s Training Accuracy: 81.84% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.09048s Training Accuracy: 80.08% Test Accuracy: 43.75%
[ 9/ 50] FashionMNIST Time 0.08935s Training Accuracy: 83.89% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.09194s Training Accuracy: 84.96% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.08935s Training Accuracy: 88.57% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.10289s Training Accuracy: 89.45% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.08998s Training Accuracy: 90.04% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.09116s Training Accuracy: 91.11% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.08926s Training Accuracy: 91.11% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.08872s Training Accuracy: 93.46% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.09405s Training Accuracy: 91.89% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.09314s Training Accuracy: 95.51% Test Accuracy: 53.12%
[ 14/ 50] FashionMNIST Time 0.08822s Training Accuracy: 94.53% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.08968s Training Accuracy: 96.58% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.09238s Training Accuracy: 93.85% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.08626s Training Accuracy: 98.73% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.08900s Training Accuracy: 96.29% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.09404s Training Accuracy: 99.12% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.09011s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.08848s Training Accuracy: 99.51% Test Accuracy: 56.25%
[ 18/ 50] FashionMNIST Time 0.09391s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.08784s Training Accuracy: 99.80% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.09838s Training Accuracy: 99.32% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.08595s Training Accuracy: 99.80% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.08543s Training Accuracy: 98.73% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.08655s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.08995s Training Accuracy: 98.73% Test Accuracy: 65.62%
[ 22/ 50] MNIST Time 0.09554s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.08640s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.08696s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.08693s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 24/ 50] MNIST Time 0.08712s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.09485s Training Accuracy: 99.51% Test Accuracy: 65.62%
[ 25/ 50] MNIST Time 0.08627s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.08631s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.08828s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.08661s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] MNIST Time 0.09187s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.08911s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] MNIST Time 0.08842s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.08539s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.08768s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.09321s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.08648s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.09249s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.08474s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.08740s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] MNIST Time 0.08773s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.09306s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.09040s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.08583s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.08764s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.08921s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.08914s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.09129s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.08716s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] FashionMNIST Time 0.08515s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.08539s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] FashionMNIST Time 0.08605s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.09263s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.08562s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.10237s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.08805s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.08875s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.09226s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.08483s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] FashionMNIST Time 0.08819s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.09029s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.08805s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.09410s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.08416s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.08948s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.08772s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.08788s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.09373s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] MNIST Time 0.08712s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.09230s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.08816s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.08582s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.08714s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.08770s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.09564s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.08578s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.08683s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.08947s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
Appendix
julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# Additionally print GPU backend versions, but only for backends that are both
# loaded in this session and reported as functional by MLDataDevices.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.