Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
function load_dataset(
::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
(; features, targets) = if n_train === nothing
tmp = dset(:train)
tmp[1:length(tmp)]
else
dset(:train)[1:n_train]
end
x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
(; features, targets) = if n_eval === nothing
tmp = dset(:test)
tmp[1:length(tmp)]
else
dset(:test)[1:n_eval]
end
x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
return (
DataLoader(
(x_train, y_train);
batchsize=min(batchsize, size(x_train, 4)),
shuffle=true,
partial=false,
),
DataLoader(
(x_test, y_test);
batchsize=min(batchsize, size(x_test, 4)),
shuffle=false,
partial=false,
),
)
end
function load_datasets(batchsize=32)
n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
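As a quick sanity check, each entry returned by load_datasets() is a (train, test) pair of DataLoaders whose batches are 28×28×1×batchsize image tensors paired with 10×batchsize one-hot label matrices. The snippet below is an illustrative sketch, not part of the original run, and assumes the datasets have already been downloaded.
julia
(mnist_train, mnist_test), (fashion_train, fashion_test) = load_datasets()

x, y = first(mnist_train)   # one batch from the MNIST training loader
size(x), size(y)            # ((28, 28, 1, 32), (10, 32))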
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
ca_axes = getaxes(
ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
)
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
# Generate the weights
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
@return core_network(y, ps_new)
end
end
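The key trick here is ComponentArrays.jl: ca_axes records the named structure of the core network's parameters, so the flat vector produced by the weight generator can be reinterpreted as the full parameter tree expected by core_network. A minimal standalone sketch of the same round trip, using made-up toy parameters rather than anything from the tutorial:
julia
using ComponentArrays, Random

ps = ComponentArray((; weight=randn(3, 3), bias=zeros(3)))
ax = getaxes(ps)                   # captures the named structure
flat = collect(ps)                 # plain Vector{Float64} of length 12
ps_rec = ComponentArray(flat, ax)  # reattach the structure to the flat data
ps_rec.weight == ps.weight         # true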
Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend it unless you are familiar with the internals. Here we simply override Lux.initialparameters to skip initializing the core_network parameters, since those are generated by the weight_generator at runtime.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
Create and Initialize the HyperNet
julia
function create_model()
core_network = Chain(
Conv((3, 3), 1 => 16, relu; stride=2),
Conv((3, 3), 16 => 32, relu; stride=2),
Conv((3, 3), 32 => 64, relu; stride=2),
GlobalMeanPool(),
FlattenLayer(),
Dense(64, 10),
)
return HyperNet(
Chain(
Embedding(2 => 32),
Dense(32, 64, relu),
Dense(64, Lux.parameterlength(core_network)),
),
core_network,
)
end
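Before compiling anything, it can help to run the HyperNet once on the CPU with dummy data to see the calling convention: the first element of the input tuple is the dataset index fed to the Embedding, and the second is the image batch. This is an illustrative sketch (the 4-sample random batch is made up) and not part of the original run:
julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)                            # (:weight_generator,) — no core_network weights are stored

x = randn(Float32, 28, 28, 1, 4)    # dummy batch of 4 "images"
y, _ = model((1, x), ps, st)        # data_idx = 1 selects the first embedding (MNIST)
size(y)                             # (10, 4): class logits per sample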
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
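Note that onecold is applied without an explicit label set, so it returns the column-wise argmax position for both the one-hot targets and the raw logits, which keeps the two directly comparable. A tiny illustration (not part of the tutorial):
julia
labels = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(labels)                       # [1, 4, 10] — positions within 0:9
onecold(labels, 0:9)                  # [0, 3, 9]  — decoded back to the labels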
Training
julia
function train()
dev = reactant_device(; force=true)
model = create_model()
dataloaders = dev(load_datasets())
Random.seed!(1234)
ps, st = dev(Lux.setup(Random.default_rng(), model))
train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))
x = first(first(dataloaders[1][1]))
data_idx = ConcreteRNumber(1)
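    # Compile the inference-only forward pass once; `accuracy` reuses this compiled
    # function for both datasets, with the dataset index passed as a traced argument.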
model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile model((data_idx, x), ps, Lux.testmode(st))
end
### Let's train the model
nepochs = 50
for epoch in 1:nepochs, data_idx in 1:2
train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
### This allows us to trace the data index, else it will be embedded as a constant
### in the IR
concrete_data_idx = ConcreteRNumber(data_idx)
stime = time()
for (x, y) in train_dataloader
(_, _, _, train_state) = Training.single_train_step!(
AutoEnzyme(),
CrossEntropyLoss(; logits=Val(true)),
((concrete_data_idx, x), y),
train_state;
return_gradients=Val(false),
)
end
ttime = time() - stime
train_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
train_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
test_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
test_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
@printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
end
println()
test_acc_list = [0.0, 0.0]
for data_idx in 1:2
train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
concrete_data_idx = ConcreteRNumber(data_idx)
train_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
train_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
test_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
test_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
@printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
%3.2f%%\n" data_name train_acc test_acc
test_acc_list[data_idx] = test_acc
end
return test_acc_list
end
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1758301779.915928 1237731 service.cc:158] XLA service 0x38a5e360 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758301779.916017 1237731 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1758301779.917083 1237731 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1758301779.917155 1237731 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1758301779.917217 1237731 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1758301779.927690 1237731 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 44.91185s Training Accuracy: 34.86% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.03526s Training Accuracy: 32.62% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.03441s Training Accuracy: 37.50% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.03214s Training Accuracy: 46.09% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.03528s Training Accuracy: 40.14% Test Accuracy: 43.75%
[ 3/ 50] FashionMNIST Time 0.03829s Training Accuracy: 52.54% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.03235s Training Accuracy: 52.44% Test Accuracy: 50.00%
[ 4/ 50] FashionMNIST Time 0.03246s Training Accuracy: 60.55% Test Accuracy: 53.12%
[ 5/ 50] MNIST Time 0.03494s Training Accuracy: 58.01% Test Accuracy: 43.75%
[ 5/ 50] FashionMNIST Time 0.03838s Training Accuracy: 66.02% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.03769s Training Accuracy: 62.89% Test Accuracy: 46.88%
[ 6/ 50] FashionMNIST Time 0.03613s Training Accuracy: 71.78% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.03435s Training Accuracy: 70.51% Test Accuracy: 50.00%
[ 7/ 50] FashionMNIST Time 0.03377s Training Accuracy: 76.37% Test Accuracy: 59.38%
[ 8/ 50] MNIST Time 0.03410s Training Accuracy: 76.17% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.03720s Training Accuracy: 79.69% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.04663s Training Accuracy: 79.49% Test Accuracy: 53.12%
[ 9/ 50] FashionMNIST Time 0.03418s Training Accuracy: 82.91% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.04292s Training Accuracy: 85.25% Test Accuracy: 43.75%
[ 10/ 50] FashionMNIST Time 0.03467s Training Accuracy: 85.94% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.03336s Training Accuracy: 89.55% Test Accuracy: 46.88%
[ 11/ 50] FashionMNIST Time 0.03361s Training Accuracy: 87.89% Test Accuracy: 59.38%
[ 12/ 50] MNIST Time 0.03349s Training Accuracy: 89.55% Test Accuracy: 46.88%
[ 12/ 50] FashionMNIST Time 0.05185s Training Accuracy: 90.43% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03251s Training Accuracy: 92.38% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.03311s Training Accuracy: 92.97% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.04378s Training Accuracy: 95.21% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.03325s Training Accuracy: 93.95% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.03277s Training Accuracy: 97.07% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.03380s Training Accuracy: 95.41% Test Accuracy: 65.62%
[ 16/ 50] MNIST Time 0.03270s Training Accuracy: 98.24% Test Accuracy: 59.38%
[ 16/ 50] FashionMNIST Time 0.04128s Training Accuracy: 96.39% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03232s Training Accuracy: 98.83% Test Accuracy: 53.12%
[ 17/ 50] FashionMNIST Time 0.03327s Training Accuracy: 97.75% Test Accuracy: 75.00%
[ 18/ 50] MNIST Time 0.04257s Training Accuracy: 99.32% Test Accuracy: 56.25%
[ 18/ 50] FashionMNIST Time 0.03422s Training Accuracy: 97.95% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03325s Training Accuracy: 99.61% Test Accuracy: 62.50%
[ 19/ 50] FashionMNIST Time 0.03438s Training Accuracy: 98.34% Test Accuracy: 75.00%
[ 20/ 50] MNIST Time 0.03402s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 20/ 50] FashionMNIST Time 0.04011s Training Accuracy: 99.12% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.03425s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.03350s Training Accuracy: 99.12% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.04300s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.03303s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.03452s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.04322s Training Accuracy: 99.90% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.03920s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.04076s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 25/ 50] MNIST Time 0.03256s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.03373s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.04086s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.03341s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 27/ 50] MNIST Time 0.03206s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] FashionMNIST Time 0.04449s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 28/ 50] MNIST Time 0.03732s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] FashionMNIST Time 0.04651s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.03596s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.03495s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.03957s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.03293s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.03332s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.04431s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 32/ 50] MNIST Time 0.04419s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] FashionMNIST Time 0.03834s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 33/ 50] MNIST Time 0.03549s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.03447s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.04177s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] FashionMNIST Time 0.03574s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.03798s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.04611s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 36/ 50] MNIST Time 0.03718s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] FashionMNIST Time 0.04562s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 37/ 50] MNIST Time 0.03355s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] FashionMNIST Time 0.03521s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.04263s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.03302s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.03432s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.04227s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.03469s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.04237s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.03344s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] FashionMNIST Time 0.03709s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.04165s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.03256s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.03509s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.04235s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] MNIST Time 0.03321s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.04133s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.03385s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.03415s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] MNIST Time 0.04073s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.03730s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.03849s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.05852s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] MNIST Time 0.04812s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.04774s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.03226s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.03401s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.04350s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.03333s Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.