Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf,
    Random, Reactant
Loading Datasets
julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
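As a quick sanity check, we can peek at the first MNIST batch. The sketch below is only illustrative (it assumes the datasets can be downloaded by MLDatasets) and is not needed for training.
julia
# Hypothetical sanity check: each entry of `load_datasets()` is a (train, test) pair of
# DataLoaders yielding (28, 28, 1, batch) image arrays and 10-class one-hot label matrices.
mnist_loaders, fmnist_loaders = load_datasets(32)
x, y = first(mnist_loaders[1])      # first training batch of MNIST
size(x), size(y)                    # expected: ((28, 28, 1, 32), (10, 32))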
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
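The key mechanism here is the ComponentArrays round-trip: we record the axes of the core network's parameters once, then reinterpret the flat vector produced by the weight generator as a structured parameter object with those axes. A minimal standalone sketch of that idea, using a hypothetical toy Dense layer rather than the networks from this tutorial:
julia
# Minimal sketch of the axes round-trip used inside `HyperNet` (toy layer for illustration).
toy = Dense(3 => 2)
ps = Lux.initialparameters(Random.default_rng(), toy)
ax = getaxes(ComponentArray(ps))                   # captured once, like `ca_axes` above
flat = randn(Float32, Lux.parameterlength(toy))    # stand-in for the weight generator output
ps_new = ComponentArray(flat, ax)                  # structured view: ps_new.weight, ps_new.bias
The reconstructed ps_new can then be passed to the core network in place of its usual NamedTuple of parameters, which is exactly what the @return line above does.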
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend doing this unless you are familiar with those internals. In this case, we simply override parameter initialization to skip the core_network parameters entirely: they are produced by the weight generator at run time and are never stored as trainable parameters of the HyperNet.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
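To see the effect of this override, a hypothetical check (assuming the HyperNet definition above) is that Lux.setup now returns trainable parameters only for the weight generator:
julia
# Hypothetical check: only the weight generator receives trainable parameters.
core = Chain(Dense(4 => 4, relu), Dense(4 => 2))
gen = Chain(Embedding(2 => 8), Dense(8, Lux.parameterlength(core)))
hn = HyperNet(gen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)    # expected: (:weight_generator,)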
Create and Initialize the HyperNet
julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
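Before compiling anything with Reactant, it can help to smoke-test the model on the CPU. The hypothetical check below exercises the calling convention: the first element of the input tuple is the dataset index consumed by the Embedding, the second is the image batch.
julia
# Hypothetical CPU smoke test of the `(data_idx, image_batch)` calling convention.
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 28, 28, 1, 4)       # dummy batch of 4 "images"
y, _ = model((1, x), ps, st)           # data_idx = 1 selects the first embedding row
size(y)                                # expected: (10, 4)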
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
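The helper above uses onecold from OneHotArrays to turn both the one-hot targets and the raw logits into class indices; a tiny standalone illustration:
julia
# Hypothetical illustration: `onecold` returns the index of the largest entry in each column.
onecold(Float32[0 1; 1 0; 0 0])    # expected: [2, 1]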
Training
julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
            Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
            %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760327431.499524 1442466 service.cc:158] XLA service 0x36311e50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760327431.499603 1442466 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760327431.500605 1442466 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760327431.500644 1442466 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760327431.500684 1442466 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760327431.511429 1442466 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 44.16144s Training Accuracy: 35.64% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04697s Training Accuracy: 31.74% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.04045s Training Accuracy: 35.64% Test Accuracy: 37.50%
[ 2/ 50] FashionMNIST Time 0.04474s Training Accuracy: 45.21% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.04061s Training Accuracy: 40.14% Test Accuracy: 43.75%
[ 3/ 50] FashionMNIST Time 0.03794s Training Accuracy: 53.91% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.03827s Training Accuracy: 47.95% Test Accuracy: 37.50%
[ 4/ 50] FashionMNIST Time 0.03606s Training Accuracy: 60.64% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.04552s Training Accuracy: 55.57% Test Accuracy: 46.88%
[ 5/ 50] FashionMNIST Time 0.04194s Training Accuracy: 65.92% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.03920s Training Accuracy: 60.74% Test Accuracy: 50.00%
[ 6/ 50] FashionMNIST Time 0.03929s Training Accuracy: 72.27% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.03858s Training Accuracy: 67.19% Test Accuracy: 46.88%
[ 7/ 50] FashionMNIST Time 0.05045s Training Accuracy: 77.34% Test Accuracy: 56.25%
[ 8/ 50] MNIST Time 0.03932s Training Accuracy: 72.46% Test Accuracy: 50.00%
[ 8/ 50] FashionMNIST Time 0.04704s Training Accuracy: 81.25% Test Accuracy: 56.25%
[ 9/ 50] MNIST Time 0.03749s Training Accuracy: 77.73% Test Accuracy: 46.88%
[ 9/ 50] FashionMNIST Time 0.03637s Training Accuracy: 83.69% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.03652s Training Accuracy: 83.59% Test Accuracy: 46.88%
[ 10/ 50] FashionMNIST Time 0.03730s Training Accuracy: 87.50% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.04709s Training Accuracy: 88.38% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.03982s Training Accuracy: 89.16% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.04864s Training Accuracy: 90.04% Test Accuracy: 53.12%
[ 12/ 50] FashionMNIST Time 0.03661s Training Accuracy: 91.02% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03445s Training Accuracy: 91.60% Test Accuracy: 46.88%
[ 13/ 50] FashionMNIST Time 0.03442s Training Accuracy: 93.75% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.03414s Training Accuracy: 94.63% Test Accuracy: 53.12%
[ 14/ 50] FashionMNIST Time 0.04496s Training Accuracy: 94.53% Test Accuracy: 62.50%
[ 15/ 50] MNIST Time 0.03390s Training Accuracy: 95.31% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.04374s Training Accuracy: 95.80% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03302s Training Accuracy: 97.56% Test Accuracy: 53.12%
[ 16/ 50] FashionMNIST Time 0.03282s Training Accuracy: 95.80% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03355s Training Accuracy: 98.93% Test Accuracy: 50.00%
[ 17/ 50] FashionMNIST Time 0.03405s Training Accuracy: 96.68% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.04450s Training Accuracy: 99.41% Test Accuracy: 53.12%
[ 18/ 50] FashionMNIST Time 0.03373s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.04407s Training Accuracy: 99.51% Test Accuracy: 53.12%
[ 19/ 50] FashionMNIST Time 0.03423s Training Accuracy: 98.54% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.03257s Training Accuracy: 99.80% Test Accuracy: 50.00%
[ 20/ 50] FashionMNIST Time 0.03313s Training Accuracy: 99.22% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.03418s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 21/ 50] FashionMNIST Time 0.04374s Training Accuracy: 99.61% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.03329s Training Accuracy: 99.90% Test Accuracy: 50.00%
[ 22/ 50] FashionMNIST Time 0.04691s Training Accuracy: 99.22% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.03269s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 23/ 50] FashionMNIST Time 0.03413s Training Accuracy: 99.41% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.03324s Training Accuracy: 99.90% Test Accuracy: 50.00%
[ 24/ 50] FashionMNIST Time 0.03345s Training Accuracy: 99.80% Test Accuracy: 75.00%
[ 25/ 50] MNIST Time 0.04341s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 25/ 50] FashionMNIST Time 0.03367s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 26/ 50] MNIST Time 0.04409s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 26/ 50] FashionMNIST Time 0.03411s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 27/ 50] MNIST Time 0.03305s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 27/ 50] FashionMNIST Time 0.03389s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 28/ 50] MNIST Time 0.03307s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 28/ 50] FashionMNIST Time 0.04396s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 29/ 50] MNIST Time 0.03289s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 29/ 50] FashionMNIST Time 0.04458s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.03305s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 30/ 50] FashionMNIST Time 0.03334s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.03341s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 31/ 50] FashionMNIST Time 0.03396s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.04374s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 32/ 50] FashionMNIST Time 0.03353s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 33/ 50] MNIST Time 0.04334s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 33/ 50] FashionMNIST Time 0.03371s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.03367s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 34/ 50] FashionMNIST Time 0.03382s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.03343s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 35/ 50] FashionMNIST Time 0.04723s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.03216s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 36/ 50] FashionMNIST Time 0.04361s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.03154s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 37/ 50] FashionMNIST Time 0.03306s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.03213s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 38/ 50] FashionMNIST Time 0.03335s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.04400s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 39/ 50] FashionMNIST Time 0.03626s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.04176s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 40/ 50] FashionMNIST Time 0.03400s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.03407s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 41/ 50] FashionMNIST Time 0.03420s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.03370s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 42/ 50] FashionMNIST Time 0.04494s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.03360s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 43/ 50] FashionMNIST Time 0.04574s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.03407s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 44/ 50] FashionMNIST Time 0.03244s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.03424s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 45/ 50] FashionMNIST Time 0.03535s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.04443s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 46/ 50] FashionMNIST Time 0.03421s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.04485s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 47/ 50] FashionMNIST Time 0.03423s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.03363s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 48/ 50] FashionMNIST Time 0.03394s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.03419s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 49/ 50] FashionMNIST Time 0.03465s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.03257s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 50/ 50] FashionMNIST Time 0.04516s Training Accuracy: 100.00% Test Accuracy: 71.88%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 56.25%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 71.88%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.