Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
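Before wiring the loaders into the model, it can help to peek at one batch. The snippet below is illustrative only (it assumes MLDatasets can download or has already cached MNIST); the shapes follow from the reshape and onehotbatch calls above.
julia
# Illustrative only: each entry returned by `load_datasets()` is a (train, test) pair of loaders.
mnist_loaders, fashionmnist_loaders = load_datasets(32)
x, y = first(mnist_loaders[1])  # first MNIST training batch
size(x)  # (28, 28, 1, 32) -- WHCN image batch
size(y)  # (10, 32)        -- one-hot targets over the labels 0:9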
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
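The ca_axes trick is what lets the flat vector emitted by the weight generator be reinterpreted as the core network's structured parameters. Here is a minimal sketch of that round trip, using a small Dense layer purely for illustration (not part of the tutorial's model):
julia
# Hypothetical small example: turn a flat vector into structured Dense(2, 3) parameters.
ps = Lux.initialparameters(Random.default_rng(), Dense(2, 3))
ca_axes = getaxes(ComponentArray(ps))
flat = randn(Float32, Lux.parameterlength(Dense(2, 3)))  # 9 values: 3×2 weight + 3 bias
ps_new = ComponentArray(flat, ca_axes)
size(ps_new.weight), size(ps_new.bias)                   # ((3, 2), (3,))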
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply override Lux.initialparameters so that it skips initializing the core_network parameters, since those are produced by the weight generator at runtime.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
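With the override in place, Lux.setup only initializes the weight generator; the core network's parameters are regenerated on every call from the task index. A minimal CPU-only sketch of the full call, with hypothetical sizes (two tasks, a Dense(2, 3) core network):
julia
# Hypothetical toy HyperNet: the first input selects the task embedding,
# the second input is the actual data batch.
core = Dense(2, 3)
hn = HyperNet(Chain(Embedding(2 => 4), Dense(4, Lux.parameterlength(core))), core)
ps, st = Lux.setup(Random.default_rng(), hn)    # ps only holds weight_generator parameters
out, _ = hn((1, randn(Float32, 2, 5)), ps, st)  # task 1, batch of 5 inputs
size(out)                                       # (3, 5)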
Create and Initialize the HyperNet
julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
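Note that only the weight generator carries trainable parameters; the core network's parameter count merely fixes the width of the generator's last Dense layer. A quick illustrative check on the CPU (not part of the original tutorial):
julia
# Illustrative check: compare the generated vs. directly trained parameter counts.
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
Lux.parameterlength(model.layers.core_network)  # values the generator must emit per task
Lux.parameterlength(ps)                         # trainable parameters (weight generator only)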
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
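Both the targets and the model's logits go through onecold, so the comparison happens on argmax positions rather than on the raw 0:9 labels. A tiny illustration:
julia
# onecold maps one-hot (or logit) columns back to class positions.
y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(y)                       # [1, 4, 10] -- 1-based argmax positions
onecold(y, 0:9)                  # [0, 3, 9]  -- mapped back onto the labels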
Training
julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760386940.595062 2911630 service.cc:158] XLA service 0x309d94f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760386940.595125 2911630 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760386940.596210 2911630 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760386940.596277 2911630 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760386940.596351 2911630 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760386940.608535 2911630 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 45.51114s Training Accuracy: 35.45% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04402s Training Accuracy: 32.81% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.04789s Training Accuracy: 35.55% Test Accuracy: 31.25%
[ 2/ 50] FashionMNIST Time 0.03627s Training Accuracy: 46.68% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.03822s Training Accuracy: 41.70% Test Accuracy: 40.62%
[ 3/ 50] FashionMNIST Time 0.03935s Training Accuracy: 53.22% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.03257s Training Accuracy: 49.51% Test Accuracy: 31.25%
[ 4/ 50] FashionMNIST Time 0.03333s Training Accuracy: 61.52% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.06480s Training Accuracy: 55.37% Test Accuracy: 59.38%
[ 5/ 50] FashionMNIST Time 0.03497s Training Accuracy: 66.41% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.03264s Training Accuracy: 61.43% Test Accuracy: 56.25%
[ 6/ 50] FashionMNIST Time 0.03122s Training Accuracy: 69.43% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.04244s Training Accuracy: 66.80% Test Accuracy: 46.88%
[ 7/ 50] FashionMNIST Time 0.03168s Training Accuracy: 73.44% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.04922s Training Accuracy: 72.66% Test Accuracy: 53.12%
[ 8/ 50] FashionMNIST Time 0.03242s Training Accuracy: 80.57% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.03147s Training Accuracy: 76.86% Test Accuracy: 46.88%
[ 9/ 50] FashionMNIST Time 0.03252s Training Accuracy: 83.11% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.03266s Training Accuracy: 82.81% Test Accuracy: 59.38%
[ 10/ 50] FashionMNIST Time 0.04259s Training Accuracy: 86.82% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.03197s Training Accuracy: 87.89% Test Accuracy: 62.50%
[ 11/ 50] FashionMNIST Time 0.04283s Training Accuracy: 89.84% Test Accuracy: 65.62%
[ 12/ 50] MNIST Time 0.03215s Training Accuracy: 87.11% Test Accuracy: 65.62%
[ 12/ 50] FashionMNIST Time 0.03225s Training Accuracy: 91.41% Test Accuracy: 71.88%
[ 13/ 50] MNIST Time 0.03201s Training Accuracy: 91.21% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.03231s Training Accuracy: 93.36% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.04293s Training Accuracy: 93.46% Test Accuracy: 50.00%
[ 14/ 50] FashionMNIST Time 0.03224s Training Accuracy: 95.51% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.04448s Training Accuracy: 95.41% Test Accuracy: 53.12%
[ 15/ 50] FashionMNIST Time 0.03222s Training Accuracy: 95.51% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03162s Training Accuracy: 97.17% Test Accuracy: 59.38%
[ 16/ 50] FashionMNIST Time 0.03206s Training Accuracy: 96.39% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03201s Training Accuracy: 98.93% Test Accuracy: 56.25%
[ 17/ 50] FashionMNIST Time 0.04189s Training Accuracy: 97.66% Test Accuracy: 75.00%
[ 18/ 50] MNIST Time 0.03212s Training Accuracy: 99.12% Test Accuracy: 59.38%
[ 18/ 50] FashionMNIST Time 0.04147s Training Accuracy: 96.58% Test Accuracy: 75.00%
[ 19/ 50] MNIST Time 0.03180s Training Accuracy: 99.22% Test Accuracy: 56.25%
[ 19/ 50] FashionMNIST Time 0.03175s Training Accuracy: 98.63% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.03191s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 20/ 50] FashionMNIST Time 0.03190s Training Accuracy: 98.83% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.03208s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.03207s Training Accuracy: 99.12% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.04156s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 22/ 50] FashionMNIST Time 0.03175s Training Accuracy: 99.71% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.04212s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 23/ 50] FashionMNIST Time 0.03173s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.03235s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 24/ 50] FashionMNIST Time 0.03173s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03215s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 25/ 50] FashionMNIST Time 0.04164s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.03622s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 26/ 50] FashionMNIST Time 0.04840s Training Accuracy: 99.90% Test Accuracy: 75.00%
[ 27/ 50] MNIST Time 0.03701s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 27/ 50] FashionMNIST Time 0.03662s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.03629s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 28/ 50] FashionMNIST Time 0.03663s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.04744s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 29/ 50] FashionMNIST Time 0.03592s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.04738s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 30/ 50] FashionMNIST Time 0.03661s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.03607s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 31/ 50] FashionMNIST Time 0.03657s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] MNIST Time 0.03675s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 32/ 50] FashionMNIST Time 0.04829s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] MNIST Time 0.03686s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 33/ 50] FashionMNIST Time 0.04756s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.03623s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 34/ 50] FashionMNIST Time 0.03631s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] MNIST Time 0.03691s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 35/ 50] FashionMNIST Time 0.03504s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 36/ 50] MNIST Time 0.04534s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 36/ 50] FashionMNIST Time 0.03318s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 37/ 50] MNIST Time 0.04464s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 37/ 50] FashionMNIST Time 0.04773s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 38/ 50] MNIST Time 0.02918s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 38/ 50] FashionMNIST Time 0.03112s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 39/ 50] MNIST Time 0.03137s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 39/ 50] FashionMNIST Time 0.03261s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 40/ 50] MNIST Time 0.03265s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 40/ 50] FashionMNIST Time 0.04339s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 41/ 50] MNIST Time 0.03240s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 41/ 50] FashionMNIST Time 0.04280s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 42/ 50] MNIST Time 0.03231s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 42/ 50] FashionMNIST Time 0.03237s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 43/ 50] MNIST Time 0.03219s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 43/ 50] FashionMNIST Time 0.03274s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 44/ 50] MNIST Time 0.04291s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 44/ 50] FashionMNIST Time 0.03238s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 45/ 50] MNIST Time 0.04320s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 45/ 50] FashionMNIST Time 0.03244s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 46/ 50] MNIST Time 0.03297s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 46/ 50] FashionMNIST Time 0.03286s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 47/ 50] MNIST Time 0.03270s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 47/ 50] FashionMNIST Time 0.04404s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 48/ 50] MNIST Time 0.03267s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 48/ 50] FashionMNIST Time 0.04342s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 49/ 50] MNIST Time 0.03251s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 49/ 50] FashionMNIST Time 0.03255s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 50/ 50] MNIST Time 0.03156s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 50/ 50] FashionMNIST Time 0.03187s Training Accuracy: 100.00% Test Accuracy: 59.38%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 59.38%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 59.38%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
This page was generated using Literate.jl.