Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the dataset type `dset` and wrap each in a
`DataLoader`.

`n_train` / `n_eval` give how many leading samples of the train / test split to keep;
`nothing` keeps the whole split. Features are reshaped to `28×28×1×N` arrays and targets
are one-hot encoded over the classes `0:9`.

Returns `(train_loader, test_loader)`. Only the training loader shuffles. Both use
`partial=false`, so a trailing incomplete batch is dropped. The batch size is capped at the
number of available samples.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Take the first `n` samples of `split` (all of them when `n === nothing`) and turn
    # them into an image tensor plus one-hot labels.
    function prepare_split(split::Symbol, n::Union{Nothing,Int})
        data = dset(split)
        (; features, targets) = data[1:(n === nothing ? length(data) : n)]
        images = reshape(features, 28, 28, 1, :)
        labels = onehotbatch(targets, 0:9)
        return images, labels
    end

    x_train, y_train = prepare_split(:train, n_train)
    x_test, y_test = prepare_split(:test, n_eval)

    # The batch size is capped so that small subsets still produce one full batch.
    # NOTE(review): evaluation also drops partial batches, presumably so every batch
    # matches the fixed shape the compiled model expects. Confirm before changing.
    make_loader(x, y, shuffle) = DataLoader(
        (x, y); batchsize=min(batchsize, size(x, 4)), shuffle=shuffle, partial=false
    )

    return (make_loader(x_train, y_train, true), make_loader(x_test, y_test, false))
end
"""
    load_datasets(batchsize=32)

Return `(mnist_loaders, fashion_mnist_loaders)`. Each element is the `(train, test)` pair
of `DataLoader`s produced by `load_dataset` for that dataset.

When the `CI` environment variable parses as `true`, only 1024 training and 32 evaluation
samples are used per dataset. Otherwise the full splits are loaded. `parse` throws if `CI`
is set to a string that is not a valid `Bool`.
"""
function load_datasets(batchsize=32)
    running_on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer that is called with a tuple `(x, y)`.

`weight_generator(x)` is flattened with `vec` and reinterpreted as the complete parameter
set of `core_network`. `core_network` is then applied to `y` with those generated
parameters.

The output of `weight_generator` must have exactly as many elements as `core_network` has
parameters. A mismatch presumably errors when the `ComponentArray` is built; TODO confirm.

The returned `@compact` layer is tagged with `dispatch=:HyperNet`. This lets methods be
specialized on `CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # The core network's parameters are initialized here only to record their structure
    # (the ComponentArray axes); the values themselves are discarded. This draws from the
    # global default RNG.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        # The flat generator output is viewed with the recorded axes, so `core_network`
        # receives a structured parameter object.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
# Parameter initialization for layers built by `HyperNet`.
# Only the `weight_generator` receives trainable parameters. The `core_network` needs
# none, because its weights are regenerated from the generator on every forward pass.
# The result is a NamedTuple with a single `weight_generator` entry.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (weight_generator=generator_ps,)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Build the HyperNet used in this example.

The `core_network` is a small CNN classifier: three stride-2 3×3 convolutions, global mean
pooling, flattening, and a dense layer with 10 outputs. Its parameters are not trained
directly. Instead, a `weight_generator` maps a dataset index to a vector with exactly
`Lux.parameterlength(core_network)` entries, and `HyperNet` uses that vector as the CNN's
weights.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    n_core_params = Lux.parameterlength(core_network)

    # `Embedding(2 => 32)` stores one 32-dimensional vector per dataset index (1 or 2).
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class equals the target
class. Both classes are obtained with `onecold`, after moving the arrays to the CPU.

The model is called as `model((data_idx, x), ps, st)` with `st` switched to test mode.
Returns `NaN` for an empty `dataloader`, because the result is `0 / 0`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_test = Lux.testmode(st)
    n_correct = 0
    n_total = 0
    for (inputs, labels) in dataloader
        logits, _ = model((data_idx, inputs), ps, st_test)
        true_class = onecold(to_cpu(labels))
        pred_class = onecold(to_cpu(logits))
        n_correct += count(true_class .== pred_class)
        n_total += length(true_class)
    end
    return n_correct / n_total
end
Training
julia
# Accuracy of `model` on `dataloader` for dataset `data_idx`, using the current parameters
# and states stored in `train_state`. Returned as a percentage rounded to 2 decimals.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

# Display name of the dataset with index `data_idx` (1 => MNIST, otherwise FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train()

Jointly train the HyperNet on MNIST (dataset index 1) and FashionMNIST (dataset index 2).
Each epoch runs one pass over MNIST and then one over FashionMNIST. After each pass the
function prints the training time and the train/test accuracies, and after all epochs it
prints a final evaluation for both datasets.

Returns a 2-element vector with the final test accuracy (in percent, rounded to 2 digits)
for MNIST and FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)
    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference pass once, with the first MNIST training batch as the example
    # input. The compiled function is reused to evaluate both datasets, since the dataset
    # index is passed as a traced `ConcreteRNumber`.
    example_x = first(first(dataloaders[1][1]))
    example_idx = ConcreteRNumber(1)
    model_compiled = @compile model((example_idx, example_x), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        # NOTE(review): `dataloaders` was already moved with `|> dev` above, so this
        # second `dev.` call looks redundant. Confirm before removing it.
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        # Times the training pass only; the evaluation below is excluded. The first pass
        # presumably also includes compilation of the training step.
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the fully trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run training. `test_acc_list` holds the final test accuracies (in percent) for MNIST
# and FashionMNIST, in that order. The log printed by `train()` follows.
test_acc_list = train()
[ 1/ 50] MNIST Time 48.84021s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.09202s Training Accuracy: 32.52% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.09128s Training Accuracy: 36.33% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.08762s Training Accuracy: 46.19% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.08880s Training Accuracy: 42.77% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.08956s Training Accuracy: 56.74% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.08838s Training Accuracy: 51.17% Test Accuracy: 37.50%
[ 4/ 50] FashionMNIST Time 0.09029s Training Accuracy: 63.48% Test Accuracy: 62.50%
[ 5/ 50] MNIST Time 0.08769s Training Accuracy: 55.96% Test Accuracy: 43.75%
[ 5/ 50] FashionMNIST Time 0.08950s Training Accuracy: 71.19% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.08884s Training Accuracy: 63.09% Test Accuracy: 34.38%
[ 6/ 50] FashionMNIST Time 0.08820s Training Accuracy: 74.32% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.09790s Training Accuracy: 67.97% Test Accuracy: 50.00%
[ 7/ 50] FashionMNIST Time 0.08662s Training Accuracy: 76.37% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.08902s Training Accuracy: 74.80% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.10055s Training Accuracy: 81.93% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.08639s Training Accuracy: 81.05% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.08597s Training Accuracy: 85.16% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.09936s Training Accuracy: 82.42% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.08901s Training Accuracy: 87.11% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.09011s Training Accuracy: 87.30% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.09869s Training Accuracy: 89.45% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.09153s Training Accuracy: 89.75% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.08788s Training Accuracy: 90.82% Test Accuracy: 71.88%
[ 13/ 50] MNIST Time 0.10019s Training Accuracy: 93.85% Test Accuracy: 59.38%
[ 13/ 50] FashionMNIST Time 0.08920s Training Accuracy: 94.34% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.09109s Training Accuracy: 93.95% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.09797s Training Accuracy: 95.51% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.08859s Training Accuracy: 95.61% Test Accuracy: 62.50%
[ 15/ 50] FashionMNIST Time 0.08875s Training Accuracy: 94.73% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.08958s Training Accuracy: 97.95% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.08959s Training Accuracy: 96.48% Test Accuracy: 75.00%
[ 17/ 50] MNIST Time 0.08911s Training Accuracy: 99.32% Test Accuracy: 62.50%
[ 17/ 50] FashionMNIST Time 0.08885s Training Accuracy: 97.66% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.08881s Training Accuracy: 99.51% Test Accuracy: 59.38%
[ 18/ 50] FashionMNIST Time 0.08755s Training Accuracy: 97.46% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.08810s Training Accuracy: 99.61% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.08844s Training Accuracy: 98.34% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.08678s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 20/ 50] FashionMNIST Time 0.10557s Training Accuracy: 98.54% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.08847s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.08642s Training Accuracy: 99.22% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.08743s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.08615s Training Accuracy: 99.32% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.08697s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.08838s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.08925s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.08695s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.08715s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.09797s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.08676s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.08827s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.09931s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.08757s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 28/ 50] MNIST Time 0.08757s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.09637s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 29/ 50] MNIST Time 0.08658s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.08639s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.09771s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.08999s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 31/ 50] MNIST Time 0.09735s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.09602s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.08743s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.08482s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.08796s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.08779s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.08921s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.08596s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.08978s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.08639s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.08804s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.09217s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.08778s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.08569s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 38/ 50] MNIST Time 0.08629s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.08771s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.08881s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.08830s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.08885s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.08816s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.09062s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.08828s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 42/ 50] MNIST Time 0.08959s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.09994s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.08845s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.08858s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.10099s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.09119s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 45/ 50] MNIST Time 0.11106s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.10201s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 46/ 50] MNIST Time 0.08924s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.08997s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 47/ 50] MNIST Time 0.09913s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.08715s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.08628s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.09982s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 49/ 50] MNIST Time 0.08876s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.08811s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.08869s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 50/ 50] FashionMNIST Time 0.08608s Training Accuracy: 100.00% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 75.00%
Appendix
julia
# Print Julia and system version information.
using InteractiveUtils
InteractiveUtils.versioninfo()

# If MLDataDevices is loaded, also print GPU toolkit details for each backend whose
# package is loaded in this session and that MLDataDevices reports as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 9V74 80-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver4)
GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.