Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, eval_loader)` tuple of `DataLoader`s for the dataset type `dset`.

`n_train` and `n_eval` select the first `n` samples of the `:train` and `:test` splits;
passing `nothing` selects the whole split. Features are reshaped to `28 × 28 × 1 × N` and
targets are one-hot encoded over the labels `0:9`. The batch size is capped at the number
of selected samples. Only the training loader shuffles; both loaders are created with
`partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load one split, optionally truncated to its first `n` samples, and convert it
    # into an `(images, one_hot_labels)` pair.
    function load_split(split::Symbol, n)
        data = dset(split)
        (; features, targets) = data[1:(n === nothing ? length(data) : n)]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # The 4th dimension of the image array indexes samples.
    function make_loader(data, shuffle::Bool)
        return DataLoader(
            data; batchsize=min(batchsize, size(first(data), 4)), shuffle, partial=false
        )
    end

    train_data = load_split(:train, n_train)
    eval_data = load_split(:test, n_eval)
    return (make_loader(train_data, true), make_loader(eval_data, false))
end
"""
    load_datasets(batchsize=32)

Return the `load_dataset` results for `MNIST` and `FashionMNIST`, in that order.

When the `CI` environment variable parses to `true`, only the first 1024 training and 32
evaluation samples are requested; otherwise (including when `CI` is unset) the full
splits are used.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end

Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Create a `@compact` layer (dispatch tag `:HyperNet`) that takes an `(x, y)` input. The
output of `weight_generator(x)` is flattened and reinterpreted, via the axes of
`core_network`'s parameter `ComponentArray`, as the parameters used to evaluate
`core_network` on `y`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (the parameter layout) of this throw-away initialization are kept.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
"""
    Lux.initialparameters(rng, hn::CompactLuxLayer{:HyperNet})

Initialize parameters for a `HyperNet` layer. Only the `weight_generator` sub-layer
receives parameters; no parameters are created for `core_network` here.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network is three stride-2 `3 × 3` convolutions (1 → 16 → 32 → 64 channels, `relu`),
followed by global mean pooling, flattening and a `Dense(64, 10)` layer. The weight
generator embeds one of 2 indices into 32 dimensions and maps it through two dense layers
to a vector with `Lux.parameterlength(core_network)` entries.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator has to emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the `onecold` class of the
model output equals the `onecold` class of the target.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode.
Outputs and targets are moved to the CPU before they are compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (inputs, labels) in dataloader
        # The model returns `(output, state)`; only the output is needed.
        outputs = first(model((data_idx, inputs), ps, eval_st))
        expected = onecold(to_cpu(labels))
        predicted = onecold(to_cpu(outputs))
        n_correct += count(expected .== predicted)
        n_seen += length(expected)
    end
    return n_correct / n_seen
end

Training
julia
"""
    accuracy_percent(model, train_state, dataloader, data_idx)

Evaluate `accuracy` using the parameters and states stored in `train_state` and return
the result as a percentage rounded to two decimal digits.
"""
function accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

"""
    train()

Train the model from `create_model()` on the datasets returned by `load_datasets()`.

Each of the 50 epochs visits MNIST (index 1) and then FashionMNIST (index 2), runs one
pass over the training loader and prints the elapsed time together with the training and
test accuracy. After training, the final accuracies are printed once more.

Returns a two-element vector with the final test accuracy (in percent) for MNIST and
FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using the first MNIST training batch as
    # the example input. The example index has its own name so that it is not confused
    # with the integer loop variable `data_idx` below.
    x = first(first(dataloaders[1][1]))
    example_idx = ConcreteRNumber(1)
    model_compiled = @compile model((example_idx, x), ps, Lux.testmode(st))

    # Display names, indexed by `data_idx`.
    data_names = ("MNIST", "FashionMNIST")

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = data_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = data_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()

[ 1/ 50] MNIST Time 47.92338s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.09140s Training Accuracy: 32.52% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.08455s Training Accuracy: 36.33% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.10560s Training Accuracy: 46.19% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.08393s Training Accuracy: 42.68% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.08217s Training Accuracy: 56.64% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.08624s Training Accuracy: 51.37% Test Accuracy: 37.50%
[ 4/ 50] FashionMNIST Time 0.08552s Training Accuracy: 65.23% Test Accuracy: 62.50%
[ 5/ 50] MNIST Time 0.08488s Training Accuracy: 56.64% Test Accuracy: 40.62%
[ 5/ 50] FashionMNIST Time 0.08496s Training Accuracy: 70.51% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.30614s Training Accuracy: 61.62% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.08376s Training Accuracy: 75.78% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.08380s Training Accuracy: 67.77% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.08220s Training Accuracy: 75.59% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.08520s Training Accuracy: 74.90% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.08879s Training Accuracy: 80.96% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.08283s Training Accuracy: 81.05% Test Accuracy: 53.12%
[ 9/ 50] FashionMNIST Time 0.08549s Training Accuracy: 84.77% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.08386s Training Accuracy: 82.52% Test Accuracy: 53.12%
[ 10/ 50] FashionMNIST Time 0.08421s Training Accuracy: 88.18% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.09461s Training Accuracy: 86.43% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.08155s Training Accuracy: 89.84% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.08165s Training Accuracy: 90.04% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.08282s Training Accuracy: 92.09% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.08531s Training Accuracy: 94.43% Test Accuracy: 59.38%
[ 13/ 50] FashionMNIST Time 0.08557s Training Accuracy: 93.46% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.08347s Training Accuracy: 94.53% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.08493s Training Accuracy: 95.51% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.08418s Training Accuracy: 96.29% Test Accuracy: 71.88%
[ 15/ 50] FashionMNIST Time 0.08574s Training Accuracy: 94.24% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.08408s Training Accuracy: 98.63% Test Accuracy: 59.38%
[ 16/ 50] FashionMNIST Time 0.08371s Training Accuracy: 95.51% Test Accuracy: 71.88%
[ 17/ 50] MNIST Time 0.08331s Training Accuracy: 99.32% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.09246s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.08352s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 18/ 50] FashionMNIST Time 0.08406s Training Accuracy: 97.46% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.09680s Training Accuracy: 99.61% Test Accuracy: 65.62%
[ 19/ 50] FashionMNIST Time 0.08479s Training Accuracy: 98.63% Test Accuracy: 65.62%
[ 20/ 50] MNIST Time 0.08239s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.08401s Training Accuracy: 98.63% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.08486s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 21/ 50] FashionMNIST Time 0.08380s Training Accuracy: 98.93% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.08486s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 22/ 50] FashionMNIST Time 0.08348s Training Accuracy: 99.51% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.08450s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.08156s Training Accuracy: 99.71% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.08334s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 24/ 50] FashionMNIST Time 0.08429s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 25/ 50] MNIST Time 0.08632s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.09528s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.08381s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.08569s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.09184s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] FashionMNIST Time 0.08370s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.08263s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] FashionMNIST Time 0.09316s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.08406s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.08369s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.08162s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.08411s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.08809s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.08735s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.08484s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] FashionMNIST Time 0.08535s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.08372s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.08255s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.08368s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] FashionMNIST Time 0.08276s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.09185s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.08488s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] MNIST Time 0.08576s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.09237s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.08322s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] FashionMNIST Time 0.08299s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.09251s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.08243s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.08324s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.08391s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] MNIST Time 0.08307s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.08314s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.08922s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.08223s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.08369s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.08459s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.08440s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.08607s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.09017s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 44/ 50] FashionMNIST Time 0.09493s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.08458s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.08421s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.09339s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.08915s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] MNIST Time 0.08599s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 47/ 50] FashionMNIST Time 0.09286s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.08382s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.08387s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.08379s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.08348s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.08243s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.08299s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%

Appendix
julia
using InteractiveUtils

# Print the Julia version and platform details.
InteractiveUtils.versioninfo()

# Additionally print GPU backend version details, but only for a backend whose module is
# defined in this session and which MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end

Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.