Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the dataset type `dset`.

# Arguments
- `dset`: dataset type; it is called as `dset(:train)` and `dset(:test)`.
- `n_train`, `n_eval`: number of leading samples to keep from the train / test split, or
  `nothing` to keep every sample of that split.
- `batchsize`: requested batch size; it is capped at the number of samples kept.

# Returns
A tuple of two `DataLoader`s. Features are reshaped to `28 × 28 × 1 × N` and targets are
one-hot encoded over the labels `0:9`. Only the training loader shuffles; both loaders
are built with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Read the first `n` samples of `split` (every sample when `n` is `nothing`) and
    # convert them to the array layout used by the loaders.
    function read_split(split::Symbol, n)
        data = dset(split)
        last_idx = isnothing(n) ? length(data) : n
        (; features, targets) = data[1:last_idx]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # The batch size is capped by the sample count (4th array dimension).
    function make_loader(x, y; shuffle::Bool)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle=shuffle, partial=false
        )
    end

    x_train, y_train = read_split(:train, n_train)
    x_test, y_test = read_split(:test, n_eval)
    return (
        make_loader(x_train, y_train; shuffle=true),
        make_loader(x_test, y_test; shuffle=false),
    )
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST via `load_dataset` and return one
`(train_loader, test_loader)` pair per dataset, in that order.

When the environment variable `CI` parses to `true`, only the first 1024 training and
32 evaluation samples of each dataset are used; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a layer in which `weight_generator` produces the parameters of `core_network`.

The layer takes a tuple `(x, y)`: `x` is fed to `weight_generator`, whose output is
flattened and reinterpreted with the parameter layout of `core_network`; `y` is then
passed through `core_network` using those generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the layout (axes) of the core network's parameters is kept; the freshly
    # initialised values themselves are discarded.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Flatten the generator output and view it through the core network's layout.
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
Defining methods on a `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload `Lux.initialparameters` so that it skips initializing the `core_network`
parameters, since those are produced by the weight generator instead.
julia
# Parameter initialisation for the `:HyperNet` compact layer: only the weight generator
# gets parameters. The core network is deliberately left out because its parameters are
# generated inside the layer's forward pass.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial: a small convolutional classifier whose
parameters are generated from a dataset-index embedding.
"""
function create_model()
    # Core classifier: three strided 3×3 convolutions, global mean pooling, and a dense
    # layer with 10 outputs.
    channels = (1 => 16, 16 => 32, 32 => 64)
    conv_layers = map(ch -> Conv((3, 3), ch, relu; stride=2), channels)
    core_network = Chain(conv_layers..., GlobalMeanPool(), FlattenLayer(), Dense(64, 10))

    # The generator embeds one of two indices and emits one value per parameter of the
    # core network.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode; the
first element of its result is compared against the targets after both have been moved
to the CPU and decoded with `onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        labels = onecold(cdev(y))
        output = first(model((data_idx, x), ps, eval_st))
        predictions = onecold(cdev(output))
        n_correct += sum(labels .== predictions)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
Training
julia
# Display name of the dataset selected by `data_idx` (1 is MNIST, anything else is
# FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy of `model` on `dataloader` using the parameters and states stored in
# `train_state`, expressed as a percentage rounded to two decimal places.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

"""
    train()

Train the HyperNet on MNIST and FashionMNIST, alternating between the two datasets within
every epoch, and print the training / test accuracy after each pass.

# Returns
A 2-element vector with the final test accuracy (in percent) for MNIST and FashionMNIST,
in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, using a sample batch and a sample dataset index.
    # These are named separately from the loop variables below to avoid shadowing.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        # NOTE(review): `dataloaders` was already passed through `dev` above; confirm
        # whether this second transfer is required.
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation pass over both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop and keep the final test accuracies in percent
# (index 1 is MNIST, index 2 is FashionMNIST).
test_acc_list = train()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757733871.383174 558333 service.cc:163] XLA service 0x33e94510 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757733871.383222 558333 service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1757733871.383765 558333 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757733871.383801 558333 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757733871.383851 558333 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757733871.393709 558333 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 44.56980s Training Accuracy: 35.45% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.03309s Training Accuracy: 31.74% Test Accuracy: 40.62%
[ 2/ 50] MNIST Time 0.03185s Training Accuracy: 36.33% Test Accuracy: 31.25%
[ 2/ 50] FashionMNIST Time 0.03238s Training Accuracy: 46.00% Test Accuracy: 56.25%
[ 3/ 50] MNIST Time 0.04148s Training Accuracy: 41.89% Test Accuracy: 34.38%
[ 3/ 50] FashionMNIST Time 0.03182s Training Accuracy: 51.46% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.03533s Training Accuracy: 47.85% Test Accuracy: 34.38%
[ 4/ 50] FashionMNIST Time 0.03467s Training Accuracy: 60.74% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.03430s Training Accuracy: 54.00% Test Accuracy: 56.25%
[ 5/ 50] FashionMNIST Time 0.03263s Training Accuracy: 65.92% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.03622s Training Accuracy: 59.18% Test Accuracy: 46.88%
[ 6/ 50] FashionMNIST Time 0.03244s Training Accuracy: 70.12% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.03042s Training Accuracy: 66.80% Test Accuracy: 50.00%
[ 7/ 50] FashionMNIST Time 0.03231s Training Accuracy: 75.00% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.03201s Training Accuracy: 72.36% Test Accuracy: 53.12%
[ 8/ 50] FashionMNIST Time 0.03329s Training Accuracy: 79.79% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.03002s Training Accuracy: 76.27% Test Accuracy: 46.88%
[ 9/ 50] FashionMNIST Time 0.03267s Training Accuracy: 83.59% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.03257s Training Accuracy: 81.45% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.04474s Training Accuracy: 86.72% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.03266s Training Accuracy: 85.94% Test Accuracy: 46.88%
[ 11/ 50] FashionMNIST Time 0.04199s Training Accuracy: 88.38% Test Accuracy: 65.62%
[ 12/ 50] MNIST Time 0.03297s Training Accuracy: 88.18% Test Accuracy: 53.12%
[ 12/ 50] FashionMNIST Time 0.04250s Training Accuracy: 91.21% Test Accuracy: 62.50%
[ 13/ 50] MNIST Time 0.03014s Training Accuracy: 90.82% Test Accuracy: 43.75%
[ 13/ 50] FashionMNIST Time 0.04495s Training Accuracy: 92.09% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.03248s Training Accuracy: 93.46% Test Accuracy: 46.88%
[ 14/ 50] FashionMNIST Time 0.04335s Training Accuracy: 94.14% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.03416s Training Accuracy: 95.41% Test Accuracy: 46.88%
[ 15/ 50] FashionMNIST Time 0.03782s Training Accuracy: 95.31% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03223s Training Accuracy: 97.17% Test Accuracy: 50.00%
[ 16/ 50] FashionMNIST Time 0.04253s Training Accuracy: 96.88% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03075s Training Accuracy: 97.85% Test Accuracy: 46.88%
[ 17/ 50] FashionMNIST Time 0.04167s Training Accuracy: 97.85% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.03219s Training Accuracy: 98.73% Test Accuracy: 56.25%
[ 18/ 50] FashionMNIST Time 0.03212s Training Accuracy: 97.85% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03253s Training Accuracy: 99.41% Test Accuracy: 46.88%
[ 19/ 50] FashionMNIST Time 0.03444s Training Accuracy: 98.14% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.03319s Training Accuracy: 99.71% Test Accuracy: 46.88%
[ 20/ 50] FashionMNIST Time 0.03239s Training Accuracy: 98.54% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.03199s Training Accuracy: 99.80% Test Accuracy: 46.88%
[ 21/ 50] FashionMNIST Time 0.03267s Training Accuracy: 99.12% Test Accuracy: 65.62%
[ 22/ 50] MNIST Time 0.03260s Training Accuracy: 99.80% Test Accuracy: 50.00%
[ 22/ 50] FashionMNIST Time 0.03097s Training Accuracy: 99.51% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.03227s Training Accuracy: 99.90% Test Accuracy: 50.00%
[ 23/ 50] FashionMNIST Time 0.03374s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.03311s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 24/ 50] FashionMNIST Time 0.03419s Training Accuracy: 99.80% Test Accuracy: 65.62%
[ 25/ 50] MNIST Time 0.03391s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 25/ 50] FashionMNIST Time 0.03300s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 26/ 50] MNIST Time 0.03345s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 26/ 50] FashionMNIST Time 0.03474s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.04702s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 27/ 50] FashionMNIST Time 0.03609s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 28/ 50] MNIST Time 0.04288s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 28/ 50] FashionMNIST Time 0.03339s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.04576s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 29/ 50] FashionMNIST Time 0.03295s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.04665s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 30/ 50] FashionMNIST Time 0.03221s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.04267s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 31/ 50] FashionMNIST Time 0.03172s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.04374s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 32/ 50] FashionMNIST Time 0.03238s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.04248s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 33/ 50] FashionMNIST Time 0.03403s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.05058s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 34/ 50] FashionMNIST Time 0.03103s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.03344s Training Accuracy: 100.00% Test Accuracy: 50.00%
[ 35/ 50] FashionMNIST Time 0.03267s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] MNIST Time 0.03040s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 36/ 50] FashionMNIST Time 0.03182s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.03037s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 37/ 50] FashionMNIST Time 0.03275s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.03337s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 38/ 50] FashionMNIST Time 0.03348s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.03353s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 39/ 50] FashionMNIST Time 0.03195s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] MNIST Time 0.03242s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 40/ 50] FashionMNIST Time 0.03370s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.03344s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 41/ 50] FashionMNIST Time 0.03274s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.03316s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 42/ 50] FashionMNIST Time 0.03977s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.03278s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 43/ 50] FashionMNIST Time 0.04367s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.03240s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 44/ 50] FashionMNIST Time 0.04206s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.03392s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 45/ 50] FashionMNIST Time 0.04145s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.03313s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 46/ 50] FashionMNIST Time 0.04258s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] MNIST Time 0.03340s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 47/ 50] FashionMNIST Time 0.04240s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.03493s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 48/ 50] FashionMNIST Time 0.04580s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.03367s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 49/ 50] FashionMNIST Time 0.04849s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.03276s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 50/ 50] FashionMNIST Time 0.03555s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 59.38%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
Appendix
julia
using InteractiveUtils
# Print the Julia version and platform details.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# both defined in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.