# Training a HyperNetwork on MNIST and FashionMNIST
## Package Imports
```julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
    Reactant
```
## Loading Datasets
```julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load either the full training split or only the first `n_train` samples
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    # Same for the test split
    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    # Use a small data subset when running on CI to keep the build fast
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
```
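To make the layout of the returned loaders concrete, here is a small usage sketch (illustrative only, assuming the datasets can be downloaded by MLDatasets): `load_datasets()` returns one `(train_loader, test_loader)` pair per dataset, and each batch carries `28×28×1` images together with one-hot labels over `0:9`.

```julia
# Illustrative sketch, not part of the original tutorial.
dataloaders = load_datasets()        # ((mnist_train, mnist_test), (fashion_train, fashion_test))
x, y = first(dataloaders[1][1])      # first MNIST training batch
size(x), size(y)                     # ((28, 28, 1, 32), (10, 32)) with the default batchsize of 32
```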
## Implement a HyperNet Layer
```julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # x is the task index fed to the weight generator; y is the input to the core network.
        # Generate the weights and reinterpret them in the core network's parameter structure.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
```
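The key trick here is the `ComponentArray` axes: `getaxes` records the parameter structure of the core network once, and on every forward pass the flat vector produced by the weight generator is reinterpreted into that structure. A standalone illustration of the mechanism with a toy `Dense` layer (not part of the tutorial):

```julia
# Illustrative only: reinterpret a flat vector as structured layer parameters.
layer = Dense(3, 2)
ax = getaxes(ComponentArray(Lux.initialparameters(Random.default_rng(), layer)))
flat = randn(Float32, Lux.parameterlength(layer))    # 3 * 2 weights + 2 biases = 8 values
ps_flat = ComponentArray(flat, ax)
size(ps_flat.weight), size(ps_flat.bias)             # ((2, 3), (2,))
```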
Defining methods on a `CompactLuxLayer` requires some understanding of how the layer is structured internally, so we don't recommend doing this unless you are familiar with the Lux internals. Here we override `Lux.initialparameters` so that only the weight generator's parameters are initialized; the `core_network` parameters are never stored, since they are produced by the weight generator on every forward pass.
```julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
```
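As a quick sanity check (hypothetical, not part of the original tutorial), with this override in place the initialized parameters contain only the weight generator's weights; the core network never shows up in `ps`:

```julia
# Hypothetical check with a toy HyperNet: only the weight generator is trainable.
core = Dense(4, 2)
wgen = Chain(Embedding(2 => 8), Dense(8, Lux.parameterlength(core)))
hn = HyperNet(wgen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)    # (:weight_generator,)
```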
## Create and Initialize the HyperNet
```julia
function create_model()
    # The core network that classifies the 28×28 images; its weights are generated, not trained.
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The weight generator maps the dataset index (1 = MNIST, 2 = FashionMNIST) to a flat
    # vector containing all of the core network's parameters.
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
```
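Before compiling anything, it can help to see the data flow on plain CPU arrays. In the following sketch (illustrative, with a dummy batch; not part of the original tutorial) the first element of the input tuple is the dataset index consumed by the `Embedding`, and the second is the image batch passed through the generated core network:

```julia
# Illustrative forward pass on the CPU with randomly initialized weights.
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x_dummy = randn(Float32, 28, 28, 1, 4)    # a fake batch of 4 "images"
logits, _ = model((1, x_dummy), ps, st)   # dataset index 1 selects the MNIST embedding
size(logits)                              # (10, 4): one logit vector per image
```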
## Define Utility Functions
```julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
```
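Note that `onecold` is applied symmetrically to the one-hot targets and to the raw logits, so both come back as 1-based class positions and can be compared directly even though the labels were encoded from `0:9`. A quick illustration (not from the tutorial):

```julia
# Illustrative only: both calls return 1-based positions in the label set.
onecold(onehotbatch([3, 7], 0:9))     # [4, 8]
onecold([0.1 0.2; 0.9 0.3; 0.0 0.5])  # [2, 3] -- argmax of each column
```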
## Training
```julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end

test_acc_list = train()
```
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1758241995.694755 3706750 service.cc:158] XLA service 0x22560740 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758241995.694828 3706750 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1758241995.695621 3706750 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1758241995.695680 3706750 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1758241995.695728 3706750 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1758241995.708141 3706750 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 44.64796s Training Accuracy: 35.45% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.03934s Training Accuracy: 31.74% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.04115s Training Accuracy: 35.74% Test Accuracy: 31.25%
[ 2/ 50] FashionMNIST Time 0.07768s Training Accuracy: 47.07% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.03701s Training Accuracy: 40.72% Test Accuracy: 37.50%
[ 3/ 50] FashionMNIST Time 0.03356s Training Accuracy: 55.57% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.03282s Training Accuracy: 48.73% Test Accuracy: 34.38%
[ 4/ 50] FashionMNIST Time 0.03413s Training Accuracy: 62.21% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.04897s Training Accuracy: 56.15% Test Accuracy: 46.88%
[ 5/ 50] FashionMNIST Time 0.03753s Training Accuracy: 68.65% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.04907s Training Accuracy: 61.62% Test Accuracy: 40.62%
[ 6/ 50] FashionMNIST Time 0.03733s Training Accuracy: 74.61% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.04755s Training Accuracy: 67.87% Test Accuracy: 46.88%
[ 7/ 50] FashionMNIST Time 0.03761s Training Accuracy: 76.76% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.04804s Training Accuracy: 72.17% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.03595s Training Accuracy: 81.54% Test Accuracy: 59.38%
[ 9/ 50] MNIST Time 0.03919s Training Accuracy: 77.83% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.05106s Training Accuracy: 84.28% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.03953s Training Accuracy: 82.62% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.04714s Training Accuracy: 88.38% Test Accuracy: 65.62%
[ 11/ 50] MNIST Time 0.04005s Training Accuracy: 87.89% Test Accuracy: 56.25%
[ 11/ 50] FashionMNIST Time 0.03406s Training Accuracy: 89.84% Test Accuracy: 59.38%
[ 12/ 50] MNIST Time 0.03417s Training Accuracy: 85.74% Test Accuracy: 62.50%
[ 12/ 50] FashionMNIST Time 0.03705s Training Accuracy: 91.70% Test Accuracy: 71.88%
[ 13/ 50] MNIST Time 0.04598s Training Accuracy: 88.09% Test Accuracy: 50.00%
[ 13/ 50] FashionMNIST Time 0.03665s Training Accuracy: 93.26% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.03967s Training Accuracy: 92.48% Test Accuracy: 59.38%
[ 14/ 50] FashionMNIST Time 0.03070s Training Accuracy: 94.04% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.03510s Training Accuracy: 95.51% Test Accuracy: 62.50%
[ 15/ 50] FashionMNIST Time 0.03412s Training Accuracy: 94.53% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.03197s Training Accuracy: 96.39% Test Accuracy: 53.12%
[ 16/ 50] FashionMNIST Time 0.04597s Training Accuracy: 95.90% Test Accuracy: 71.88%
[ 17/ 50] MNIST Time 0.03163s Training Accuracy: 98.63% Test Accuracy: 62.50%
[ 17/ 50] FashionMNIST Time 0.03896s Training Accuracy: 97.46% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.03537s Training Accuracy: 99.22% Test Accuracy: 56.25%
[ 18/ 50] FashionMNIST Time 0.03598s Training Accuracy: 96.29% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03230s Training Accuracy: 99.61% Test Accuracy: 62.50%
[ 19/ 50] FashionMNIST Time 0.03326s Training Accuracy: 98.14% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.04105s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 20/ 50] FashionMNIST Time 0.03296s Training Accuracy: 98.93% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.04249s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.03280s Training Accuracy: 98.93% Test Accuracy: 65.62%
[ 22/ 50] MNIST Time 0.03191s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.03313s Training Accuracy: 99.32% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.03378s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.04375s Training Accuracy: 99.80% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.03504s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 24/ 50] FashionMNIST Time 0.04022s Training Accuracy: 99.80% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.03795s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.03355s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.03699s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.03499s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.04573s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.03637s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.04304s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.03420s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.03358s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.03345s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.03359s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.03352s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.03435s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.04218s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.03553s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.04625s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.03553s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.03530s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.03570s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.04480s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.04923s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.03280s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.04464s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.03503s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.03507s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.03757s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.03757s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.04606s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.03421s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.07067s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.04400s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.04168s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.03754s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.03731s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.04663s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.03683s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.04755s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.03857s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.03612s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.03841s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.03639s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.04628s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.03519s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 46/ 50] FashionMNIST Time 0.04898s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] MNIST Time 0.04647s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 47/ 50] FashionMNIST Time 0.03154s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.03435s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 48/ 50] FashionMNIST Time 0.03427s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.04323s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 49/ 50] FashionMNIST Time 0.03301s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.04271s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 50/ 50] FashionMNIST Time 0.03275s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 59.38%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
```
## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```
```
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
```
This page was generated using Literate.jl.