Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling ComponentArraysReactantExt...
17777.8 ms ✓ ComponentArrays → ComponentArraysReactantExt
1 dependency successfully precompiled in 18 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
17356.2 ms ✓ Reactant → ReactantOneHotArraysExt
1 dependency successfully precompiled in 18 seconds. 104 already precompiled.
Loading Datasets
julia
function load_dataset(
::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
(; features, targets) = if n_train === nothing
tmp = dset(:train)
tmp[1:length(tmp)]
else
dset(:train)[1:n_train]
end
x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
(; features, targets) = if n_eval === nothing
tmp = dset(:test)
tmp[1:length(tmp)]
else
dset(:test)[1:n_eval]
end
x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
return (
DataLoader(
(x_train, y_train);
batchsize=min(batchsize, size(x_train, 4)),
shuffle=true,
partial=false,
),
DataLoader(
(x_test, y_test);
batchsize=min(batchsize, size(x_test, 4)),
shuffle=false,
partial=false,
),
)
end
function load_datasets(batchsize=32)
n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
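As a quick sanity check (an illustrative sketch, not part of the training script), the returned loaders can be inspected directly; with the defaults above, each MNIST training batch is a 28×28×1×32 image array paired with a 10×32 one-hot label matrix.
julia
mnist_loaders, fashion_loaders = load_datasets()
x, y = first(mnist_loaders[1])   # first training batch of MNIST
size(x)                          # expected: (28, 28, 1, 32)
size(y)                          # expected: (10, 32)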
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
ca_axes = getaxes(
ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
)
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
# Generate the weights
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
@return core_network(y, ps_new)
end
end
HyperNet (generic function with 1 method)
Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing this unless you are familiar with the internals. In this case, we simply write it to skip the initialization of the core_network parameters, since those are produced by the weight_generator at run time.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
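A quick way to see the effect of this override is to inspect which parameters the HyperNet actually exposes. The tiny Dense layers below are hypothetical stand-ins chosen only for illustration, not the networks used later in this tutorial.
julia
hn = HyperNet(Dense(2, Lux.parameterlength(Dense(4, 4))), Dense(4, 4))
ps = Lux.initialparameters(Random.default_rng(), hn)
keys(ps)   # (:weight_generator,): the core network's parameters are generated, never stored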
Create and Initialize the HyperNet
julia
function create_model()
core_network = Chain(
Conv((3, 3), 1 => 16, relu; stride=2),
Conv((3, 3), 16 => 32, relu; stride=2),
Conv((3, 3), 32 => 64, relu; stride=2),
GlobalMeanPool(),
FlattenLayer(),
Dense(64, 10),
)
return HyperNet(
Chain(
Embedding(2 => 32),
Dense(32, 64, relu),
Dense(64, Lux.parameterlength(core_network)),
),
core_network,
)
end
create_model (generic function with 1 method)
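Before compiling anything with Reactant, it can be helpful to run the model eagerly on the CPU. The snippet below is a minimal sketch assuming a random batch of four 28×28 grayscale images; the leading 1 is the data index that selects the first embedding (MNIST).
julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 28, 28, 1, 4)   # hypothetical dummy batch
logits, _ = model((1, x), ps, st)
size(logits)                       # expected: (10, 4)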
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
accuracy (generic function with 1 method)
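Note that onecold without an explicit label set returns 1-based argmax indices, which is all the comparison above needs; passing the label set recovers the original labels. A small illustrative check:
julia
yoh = onehotbatch([3, 7], 0:9)
onecold(yoh, 0:9)   # [3, 7] (original labels)
onecold(yoh)        # [4, 8] (1-based indices into 0:9)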
Training
julia
function train()
dev = reactant_device(; force=true)
model = create_model()
dataloaders = dev(load_datasets())
Random.seed!(1234)
ps, st = dev(Lux.setup(Random.default_rng(), model))
train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))
x = first(first(dataloaders[1][1]))
data_idx = ConcreteRNumber(1)
model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile model((data_idx, x), ps, Lux.testmode(st))
end
### Let's train the model
nepochs = 50
for epoch in 1:nepochs, data_idx in 1:2
train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
### This allows us to trace the data index, else it will be embedded as a constant
### in the IR
concrete_data_idx = ConcreteRNumber(data_idx)
stime = time()
for (x, y) in train_dataloader
(_, _, _, train_state) = Training.single_train_step!(
AutoEnzyme(),
CrossEntropyLoss(; logits=Val(true)),
((concrete_data_idx, x), y),
train_state;
return_gradients=Val(false),
)
end
ttime = time() - stime
train_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
train_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
test_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
test_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
@printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
end
println()
test_acc_list = [0.0, 0.0]
for data_idx in 1:2
train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
concrete_data_idx = ConcreteRNumber(data_idx)
train_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
train_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
test_acc = round(
accuracy(
model_compiled,
train_state.parameters,
train_state.states,
test_dataloader,
concrete_data_idx,
) * 100;
digits=2,
)
data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
@printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
%3.2f%%\n" data_name train_acc test_acc
test_acc_list[data_idx] = test_acc
end
return test_acc_list
end
test_acc_list = train()
2025-07-09 04:19:37.074460: I external/xla/xla/service/service.cc:153] XLA service 0x37d40e50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-09 04:19:37.074504: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752034777.075523 1223409 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752034777.075620 1223409 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752034777.075682 1223409 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752034777.090291 1223409 cuda_dnn.cc:471] Loaded cuDNN version 90800
[ 1/ 50] MNIST Time 57.33054s Training Accuracy: 35.25% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04579s Training Accuracy: 32.81% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.10136s Training Accuracy: 36.72% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.04550s Training Accuracy: 48.24% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.04622s Training Accuracy: 40.72% Test Accuracy: 37.50%
[ 3/ 50] FashionMNIST Time 0.07281s Training Accuracy: 54.10% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.03742s Training Accuracy: 52.93% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.04111s Training Accuracy: 61.82% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.07351s Training Accuracy: 55.37% Test Accuracy: 40.62%
[ 5/ 50] FashionMNIST Time 0.04199s Training Accuracy: 66.41% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.04080s Training Accuracy: 61.52% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.03549s Training Accuracy: 71.29% Test Accuracy: 62.50%
[ 7/ 50] MNIST Time 0.04441s Training Accuracy: 67.97% Test Accuracy: 40.62%
[ 7/ 50] FashionMNIST Time 0.03976s Training Accuracy: 75.59% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.05213s Training Accuracy: 76.66% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.03696s Training Accuracy: 79.69% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.04874s Training Accuracy: 80.27% Test Accuracy: 53.12%
[ 9/ 50] FashionMNIST Time 0.03583s Training Accuracy: 81.54% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.04598s Training Accuracy: 83.89% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.03729s Training Accuracy: 86.04% Test Accuracy: 68.75%
[ 11/ 50] MNIST Time 0.05240s Training Accuracy: 87.40% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.03900s Training Accuracy: 88.28% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.04762s Training Accuracy: 89.75% Test Accuracy: 53.12%
[ 12/ 50] FashionMNIST Time 0.03764s Training Accuracy: 89.36% Test Accuracy: 71.88%
[ 13/ 50] MNIST Time 0.04787s Training Accuracy: 92.29% Test Accuracy: 59.38%
[ 13/ 50] FashionMNIST Time 0.03656s Training Accuracy: 91.99% Test Accuracy: 71.88%
[ 14/ 50] MNIST Time 0.03615s Training Accuracy: 95.02% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.03541s Training Accuracy: 93.65% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.03463s Training Accuracy: 96.29% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.03565s Training Accuracy: 94.63% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03485s Training Accuracy: 98.34% Test Accuracy: 56.25%
[ 16/ 50] FashionMNIST Time 0.03470s Training Accuracy: 96.09% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03618s Training Accuracy: 99.22% Test Accuracy: 56.25%
[ 17/ 50] FashionMNIST Time 0.03676s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.03503s Training Accuracy: 99.71% Test Accuracy: 62.50%
[ 18/ 50] FashionMNIST Time 0.03550s Training Accuracy: 96.00% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.03957s Training Accuracy: 99.61% Test Accuracy: 56.25%
[ 19/ 50] FashionMNIST Time 0.03919s Training Accuracy: 98.24% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.03769s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 20/ 50] FashionMNIST Time 0.04939s Training Accuracy: 98.73% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.03725s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 21/ 50] FashionMNIST Time 0.04380s Training Accuracy: 98.93% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.03601s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.04651s Training Accuracy: 99.41% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.03458s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.04530s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.03364s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 24/ 50] FashionMNIST Time 0.04542s Training Accuracy: 99.71% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.03633s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.04904s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.03446s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.04674s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.03522s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.05190s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.03707s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.03490s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.03558s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.03559s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.03769s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.03677s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.03631s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.03688s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.03451s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] FashionMNIST Time 0.03478s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.03598s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.03562s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.03684s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.03753s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.04911s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.03721s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.04746s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] FashionMNIST Time 0.03472s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.04412s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.03388s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.04484s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.03401s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.04487s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.03149s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.04462s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.03384s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.04389s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.03421s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.03341s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.03376s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.03422s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.03456s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.03370s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.03409s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.03412s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.03500s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.03401s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.03981s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.03384s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.03333s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.03302s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.03311s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.03456s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.04666s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.03608s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.04766s Training Accuracy: 100.00% Test Accuracy: 71.88%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 71.88%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.