Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type; it is called as `dset(:train)` and `dset(:test)`.
- `n_train`, `n_eval`: number of leading samples to keep from the respective split, or
  `nothing` to keep the whole split.
- `batchsize`: requested batch size; capped at the number of samples in each split.

# Returns
A 2-tuple of `DataLoader`s over `(x, y)`, where `x` is the features reshaped to
`28 × 28 × 1 × N` and `y` is the targets one-hot encoded over the labels `0:9`. Only the
training loader shuffles; both are built with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Read the first `n` samples of a split (every sample when `n` is `nothing`) and
    # convert them to the array layout used by the model.
    function fetch_split(split::Symbol, n)
        data = dset(split)
        n_samples = n === nothing ? length(data) : n
        (; features, targets) = data[1:n_samples]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    x_train, y_train = fetch_split(:train, n_train)
    x_test, y_test = fetch_split(:test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST through `load_dataset`, returning one result per dataset in
that order.

When the `CI` environment variable parses to `true`, only the first 1024 training and 32
evaluation samples are used; otherwise both splits are loaded in full.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return map((MNIST, FashionMNIST)) do dset
        return load_dataset(dset, n_train, n_eval, batchsize)
    end
end

Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a layer in which `weight_generator` produces the parameters consumed by
`core_network`.

The returned layer takes a tuple `(x, y)`: `x` is fed to `weight_generator`, its output is
flattened and given the parameter structure of `core_network`, and `core_network` is then
evaluated on `y` with those parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (the parameter layout) of the core network are kept; the values
    # themselves are discarded.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the flat weight vector and view it through the core network's layout
        flat_weights = vec(weight_generator(x))
        @return core_network(y, ComponentArray(flat_weights, ca_axes))
    end
end

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
# Only the weight generator contributes parameters to a HyperNet layer: the parameters of
# `core_network` are produced by the generator in the forward pass, so they are
# deliberately left out of the initial parameter set.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet
julia
"""
    create_model()

Assemble the `HyperNet` used in this tutorial.

The core network is a three-layer strided convolutional classifier with 10 outputs. The
weight generator embeds a dataset index (one of 2) into 32 dimensions and maps it through
two dense layers to a vector with `Lux.parameterlength(core_network)` entries.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit exactly one value per core-network parameter
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the class predicted by `model`
equals the target class.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode;
predictions and targets are moved to the CPU device before being decoded with `onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st_test = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        # The model returns `(output, state)`; only the output is needed here
        ŷ = first(model((data_idx, x), ps, st_test))
        labels = onecold(cdev(y))
        preds = onecold(cdev(ŷ))
        n_correct += sum(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end

Training
julia
"""
    train()

Train the HyperNet alternately on MNIST and FashionMNIST and return the final test
accuracies (in percent) as a two-element vector ordered `[MNIST, FashionMNIST]`.

Per-epoch and final training/test accuracies are printed to standard output.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, from one training batch of the first dataset
    sample_x = first(first(dataloaders[1][1]))
    compile_idx = ConcreteRNumber(1)
    model_compiled = @compile model((compile_idx, sample_x), ps, Lux.testmode(st))

    # Accuracy of the compiled model in percent, rounded to two decimals. The train state
    # is passed in explicitly because `train_state` is reassigned during training.
    accuracy_pct(ts, loader, idx) = round(
        accuracy(model_compiled, ts.parameters, ts.states, loader, idx) * 100; digits=2
    )
    dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for data_idx in 1:2
            train_loader, test_loader = dev.(dataloaders[data_idx])

            ### This allows us to trace the data index, else it will be embedded as a
            ### constant in the IR
            traced_idx = ConcreteRNumber(data_idx)

            t_start = time()
            for (x, y) in train_loader
                _, _, _, train_state = Training.single_train_step!(
                    AutoEnzyme(),
                    CrossEntropyLoss(; logits=Val(true)),
                    ((traced_idx, x), y),
                    train_state;
                    return_gradients=Val(false),
                )
            end
            elapsed = time() - t_start

            train_acc = accuracy_pct(train_state, train_loader, traced_idx)
            test_acc = accuracy_pct(train_state, test_loader, traced_idx)

            @printf(
                "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                Accuracy: %3.2f%%\n",
                epoch,
                nepochs,
                dataset_name(data_idx),
                elapsed,
                train_acc,
                test_acc
            )
        end
    end

    println()

    # Final evaluation pass over both datasets with the trained parameters
    final_test_acc = [0.0, 0.0]
    for data_idx in 1:2
        train_loader, test_loader = dev.(dataloaders[data_idx])
        traced_idx = ConcreteRNumber(data_idx)

        train_acc = accuracy_pct(train_state, train_loader, traced_idx)
        test_acc = accuracy_pct(train_state, test_loader, traced_idx)

        @printf(
            "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
            %3.2f%%\n",
            dataset_name(data_idx),
            train_acc,
            test_acc
        )
        final_test_acc[data_idx] = test_acc
    end
    return final_test_acc
end
test_acc_list = train()
[ 1/ 50] MNIST Time 64.72477s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.11760s Training Accuracy: 32.52% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.11684s Training Accuracy: 36.33% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.11614s Training Accuracy: 46.19% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.11733s Training Accuracy: 42.58% Test Accuracy: 31.25%
[ 3/ 50] FashionMNIST Time 0.11789s Training Accuracy: 56.84% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.13591s Training Accuracy: 51.86% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.11681s Training Accuracy: 64.65% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.11202s Training Accuracy: 56.84% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.13073s Training Accuracy: 71.48% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.12079s Training Accuracy: 62.89% Test Accuracy: 34.38%
[ 6/ 50] FashionMNIST Time 0.11888s Training Accuracy: 74.80% Test Accuracy: 53.12%
[ 7/ 50] MNIST Time 0.11846s Training Accuracy: 69.14% Test Accuracy: 34.38%
[ 7/ 50] FashionMNIST Time 0.11919s Training Accuracy: 77.25% Test Accuracy: 59.38%
[ 8/ 50] MNIST Time 0.12205s Training Accuracy: 75.59% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.11644s Training Accuracy: 81.84% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.11798s Training Accuracy: 79.39% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.12159s Training Accuracy: 83.98% Test Accuracy: 68.75%
[ 10/ 50] MNIST Time 0.11988s Training Accuracy: 82.42% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.12456s Training Accuracy: 86.91% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.11881s Training Accuracy: 87.30% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.11442s Training Accuracy: 89.65% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.12077s Training Accuracy: 90.53% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.11595s Training Accuracy: 92.09% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.12426s Training Accuracy: 93.36% Test Accuracy: 59.38%
[ 13/ 50] FashionMNIST Time 0.12408s Training Accuracy: 93.46% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.11910s Training Accuracy: 93.55% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.12066s Training Accuracy: 94.43% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.11652s Training Accuracy: 95.61% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.12224s Training Accuracy: 95.61% Test Accuracy: 75.00%
[ 16/ 50] MNIST Time 0.12290s Training Accuracy: 97.95% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.12084s Training Accuracy: 96.29% Test Accuracy: 75.00%
[ 17/ 50] MNIST Time 0.12638s Training Accuracy: 99.22% Test Accuracy: 62.50%
[ 17/ 50] FashionMNIST Time 0.11948s Training Accuracy: 97.46% Test Accuracy: 75.00%
[ 18/ 50] MNIST Time 0.12063s Training Accuracy: 99.12% Test Accuracy: 68.75%
[ 18/ 50] FashionMNIST Time 0.12151s Training Accuracy: 97.66% Test Accuracy: 75.00%
[ 19/ 50] MNIST Time 0.11850s Training Accuracy: 99.61% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.12725s Training Accuracy: 98.63% Test Accuracy: 78.12%
[ 20/ 50] MNIST Time 0.11775s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 20/ 50] FashionMNIST Time 0.11497s Training Accuracy: 99.12% Test Accuracy: 78.12%
[ 21/ 50] MNIST Time 0.11894s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 21/ 50] FashionMNIST Time 0.11733s Training Accuracy: 99.12% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.12429s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.11534s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.11776s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.11601s Training Accuracy: 99.80% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.12020s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 24/ 50] FashionMNIST Time 0.12607s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 25/ 50] MNIST Time 0.11578s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.11682s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.11837s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.11792s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.12724s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.11828s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.11467s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.11749s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 29/ 50] MNIST Time 0.11657s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.12593s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 30/ 50] MNIST Time 0.11706s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] FashionMNIST Time 0.11969s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.12001s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.11649s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.12324s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 32/ 50] FashionMNIST Time 0.11838s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.12054s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.12105s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.11897s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.12795s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.11917s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.12005s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.12128s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] FashionMNIST Time 0.12137s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.12910s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.11948s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.11948s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.11996s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.11990s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.12697s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.11952s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.11986s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.12138s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] FashionMNIST Time 0.12340s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.12862s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.11908s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.11878s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.12169s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.11867s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.12836s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.12016s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.11943s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.12237s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.12413s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.12968s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.12330s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.12000s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.12148s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.12236s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.12698s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.12195s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.11960s Training Accuracy: 100.00% Test Accuracy: 71.88%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 71.88%

Appendix
julia
using InteractiveUtils
# Report the Julia build and platform details.
InteractiveUtils.versioninfo()

# Accelerator details are printed only when MLDataDevices and the backend package are
# both defined in this scope and the backend is reported as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end

Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, icelake-server)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.