# Training a HyperNetwork on MNIST and FashionMNIST

## Package Imports

```julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
```
## Loading Datasets

```julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
```
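For orientation, here is a small sketch (not part of the original tutorial) showing how the returned loaders can be inspected; the shapes assume the default `batchsize=32`:

```julia
# Hypothetical quick check of the loaders returned by load_datasets().
mnist_loaders, fmnist_loaders = load_datasets()  # (train, test) pair per dataset
train_loader, test_loader = mnist_loaders

x, y = first(train_loader)
size(x)  # (28, 28, 1, 32) -- WHCN image batch
size(y)  # (10, 32)        -- one-hot encoded labels for the classes 0:9
```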
## Implement a HyperNet Layer

```julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
```
Defining functions on a `CompactLuxLayer` requires some understanding of how the layer is structured internally; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply override `Lux.initialparameters` to skip initializing the `core_network` parameters, since those are produced by the `weight_generator` at runtime.
```julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
```
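As a quick sanity check (a minimal sketch, not part of the original tutorial; the layer sizes below are arbitrary), you can build a small HyperNet on the CPU and run a single forward pass:

```julia
# Hypothetical smoke test: a tiny HyperNet over a dense core network.
core = Chain(Dense(784, 32, relu), Dense(32, 10))
wgen = Chain(Embedding(2 => 16), Dense(16, Lux.parameterlength(core)))
hn = HyperNet(wgen, core)

ps, st = Lux.setup(Random.default_rng(), hn)  # only the weight_generator is initialized
x = randn(Float32, 784, 4)                    # a batch of 4 flattened "images"
y, _ = hn((1, x), ps, st)                     # task index 1 selects the first embedding
@assert size(y) == (10, 4)
```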
## Create and Initialize the HyperNet

```julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
```
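The final `Dense` layer of the weight generator must emit exactly `Lux.parameterlength(core_network)` values, one for every weight and bias of the core network. As a rough check (a sketch using the layer definitions above), the count can be inspected directly:

```julia
# Hypothetical check of how many parameters the weight generator has to produce.
core_network = Chain(
    Conv((3, 3), 1 => 16, relu; stride=2),
    Conv((3, 3), 16 => 32, relu; stride=2),
    Conv((3, 3), 32 => 64, relu; stride=2),
    GlobalMeanPool(),
    FlattenLayer(),
    Dense(64, 10),
)
Lux.parameterlength(core_network)  # 23_946 for these layers (assuming the default use_bias=true)
```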
## Define Utility Functions

```julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
```
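Note that `onecold` is applied without an explicit label set, so both targets and predictions are decoded to 1-based positions, which keeps the comparison consistent. A tiny illustration (a sketch, independent of the model):

```julia
# Hypothetical illustration of onecold on a small one-hot batch.
labels = onehotbatch([3, 0, 7], 0:9)  # 10×3 one-hot matrix
onecold(labels)                       # [4, 1, 8]: 1-based positions of the hot entries
```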
## Training

```julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled, train_state.parameters, train_state.states,
                train_dataloader, concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled, train_state.parameters, train_state.states,
                test_dataloader, concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled, train_state.parameters, train_state.states,
                train_dataloader, concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled, train_state.parameters, train_state.states,
                test_dataloader, concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
```

```
2025-08-02 22:52:03.754867: I external/xla/xla/service/service.cc:163] XLA service 0x29c5f0a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-02 22:52:03.754917: I external/xla/xla/service/service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754175123.755912 541672 se_gpu_pjrt_client.cc:1380] Using BFC allocator.
I0000 00:00:1754175123.755999 541672 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754175123.756036 541672 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-02 22:52:03.784723: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
[ 1/ 50] MNIST Time 48.35461s Training Accuracy: 35.16% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.05158s Training Accuracy: 32.91% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.04933s Training Accuracy: 35.25% Test Accuracy: 37.50%
[ 2/ 50] FashionMNIST Time 0.30378s Training Accuracy: 46.48% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.04965s Training Accuracy: 42.48% Test Accuracy: 31.25%
[ 3/ 50] FashionMNIST Time 0.04336s Training Accuracy: 53.81% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.18753s Training Accuracy: 52.15% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.03714s Training Accuracy: 61.82% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.03955s Training Accuracy: 57.71% Test Accuracy: 43.75%
[ 5/ 50] FashionMNIST Time 0.04418s Training Accuracy: 67.68% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.04251s Training Accuracy: 63.28% Test Accuracy: 50.00%
[ 6/ 50] FashionMNIST Time 0.04186s Training Accuracy: 74.71% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.03328s Training Accuracy: 72.36% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.03312s Training Accuracy: 74.80% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.03298s Training Accuracy: 76.17% Test Accuracy: 50.00%
[ 8/ 50] FashionMNIST Time 0.05678s Training Accuracy: 80.66% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.03372s Training Accuracy: 80.18% Test Accuracy: 56.25%
[ 9/ 50] FashionMNIST Time 0.03201s Training Accuracy: 83.01% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.04281s Training Accuracy: 85.45% Test Accuracy: 59.38%
[ 10/ 50] FashionMNIST Time 0.03385s Training Accuracy: 86.62% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.04379s Training Accuracy: 86.33% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.03565s Training Accuracy: 87.89% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.03655s Training Accuracy: 91.02% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.03838s Training Accuracy: 90.04% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03864s Training Accuracy: 93.55% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.04656s Training Accuracy: 91.80% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.03584s Training Accuracy: 95.31% Test Accuracy: 59.38%
[ 14/ 50] FashionMNIST Time 0.04604s Training Accuracy: 93.75% Test Accuracy: 62.50%
[ 15/ 50] MNIST Time 0.03807s Training Accuracy: 96.29% Test Accuracy: 59.38%
[ 15/ 50] FashionMNIST Time 0.04020s Training Accuracy: 94.92% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03811s Training Accuracy: 98.63% Test Accuracy: 65.62%
[ 16/ 50] FashionMNIST Time 0.03614s Training Accuracy: 96.97% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03897s Training Accuracy: 99.51% Test Accuracy: 65.62%
[ 17/ 50] FashionMNIST Time 0.03533s Training Accuracy: 97.85% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.04543s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 18/ 50] FashionMNIST Time 0.03541s Training Accuracy: 97.27% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03597s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 19/ 50] FashionMNIST Time 0.04515s Training Accuracy: 98.93% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.03635s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 20/ 50] FashionMNIST Time 0.04648s Training Accuracy: 99.41% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.03533s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 21/ 50] FashionMNIST Time 0.03712s Training Accuracy: 99.51% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.03477s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.03834s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.04557s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.03901s Training Accuracy: 99.80% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.04464s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 24/ 50] FashionMNIST Time 0.03433s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03397s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.03400s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.03577s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 26/ 50] FashionMNIST Time 0.04462s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] MNIST Time 0.03558s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.03634s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] MNIST Time 0.03557s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.03785s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.04760s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 29/ 50] FashionMNIST Time 0.03795s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.04645s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 30/ 50] FashionMNIST Time 0.03590s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.03601s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.03664s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.03673s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 32/ 50] FashionMNIST Time 0.04720s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.03415s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 33/ 50] FashionMNIST Time 0.04319s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.03745s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 34/ 50] FashionMNIST Time 0.03704s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.04601s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 35/ 50] FashionMNIST Time 0.04599s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] MNIST Time 0.05098s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 36/ 50] FashionMNIST Time 0.07133s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.03926s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 37/ 50] FashionMNIST Time 0.03870s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.03492s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 38/ 50] FashionMNIST Time 0.04366s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.03533s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 39/ 50] FashionMNIST Time 0.04341s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] MNIST Time 0.03603s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 40/ 50] FashionMNIST Time 0.03649s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.03704s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 41/ 50] FashionMNIST Time 0.03729s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.04816s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 42/ 50] FashionMNIST Time 0.03947s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.04503s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 43/ 50] FashionMNIST Time 0.03645s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.03611s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 44/ 50] FashionMNIST Time 0.04499s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.03547s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 45/ 50] FashionMNIST Time 0.04758s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.03338s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 46/ 50] FashionMNIST Time 0.03297s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] MNIST Time 0.03528s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 47/ 50] FashionMNIST Time 0.03517s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.04368s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 48/ 50] FashionMNIST Time 0.03506s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.04569s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.03487s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.03529s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.03437s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
```
## Appendix

```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
```

This page was generated using Literate.jl.