Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
1546.5 ms ✓ Lux → LuxComponentArraysExt
1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling MLDatasets...
2102.5 ms ✓ HDF5_jll
3320.0 ms ✓ DataDeps
7469.4 ms ✓ HDF5
2376.6 ms ✓ MAT
10840.0 ms ✓ MLDatasets
5 dependencies successfully precompiled in 23 seconds. 198 already precompiled.
Precompiling LuxMLUtilsExt...
2231.3 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 3 seconds. 164 already precompiled.
Precompiling ComponentArraysReactantExt...
17647.2 ms ✓ ComponentArrays → ComponentArraysReactantExt
1 dependency successfully precompiled in 18 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
17463.5 ms ✓ Reactant → ReactantOneHotArraysExt
1 dependency successfully precompiled in 18 seconds. 104 already precompiled.
Loading Datasets
julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
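As a quick sanity check (not part of the training loop), we can build the loaders and inspect the shape of one batch. This is a minimal sketch assuming the default batch size of 32; the variable names are purely illustrative. On the first run, MLDatasets downloads the data via DataDeps, which may prompt for confirmation unless DATADEPS_ALWAYS_ACCEPT=true is set.
julia
mnist_loaders, fashionmnist_loaders = load_datasets(32)
x, y = first(mnist_loaders[1])   # first training batch of MNIST
size(x), size(y)                 # expected: ((28, 28, 1, 32), (10, 32))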
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters, since those are produced by the weight_generator during the forward pass.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
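As a sanity check of this override, we can set up a toy HyperNet and confirm that only the weight generator receives trainable parameters. The tiny layers below are made up purely for illustration and are unrelated to the model defined later.
julia
core = Dense(2 => 3)
wgen = Chain(Embedding(2 => 8), Dense(8 => Lux.parameterlength(core)))
hn = HyperNet(wgen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)   # expected: (:weight_generator,); the core network has no direct parameters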
Create and Initialize the HyperNet
julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
create_model (generic function with 1 method)
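Before compiling anything, it can be useful to run the model once on the CPU with dummy inputs. The following is a minimal smoke test; the random batch of four images and the task index 1 are made up for illustration, and no Reactant compilation is involved.
julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x_dummy = randn(Float32, 28, 28, 1, 4)    # four fake 28×28 grayscale images
y_pred, _ = model((1, x_dummy), ps, st)   # task index 1 selects the first embedding
size(y_pred)                              # expected: (10, 4)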
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
Training
julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
2025-07-03 16:08:53.312374: I external/xla/xla/service/service.cc:153] XLA service 0xbf761c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-03 16:08:53.312416: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1751558933.313242 251317 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1751558933.313302 251317 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1751558933.313346 251317 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1751558933.324739 251317 cuda_dnn.cc:471] Loaded cuDNN version 90800
[ 1/ 50] MNIST Time 56.37556s Training Accuracy: 34.86% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.05224s Training Accuracy: 31.25% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.03793s Training Accuracy: 36.62% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.02990s Training Accuracy: 46.48% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.03149s Training Accuracy: 39.75% Test Accuracy: 40.62%
[ 3/ 50] FashionMNIST Time 0.03857s Training Accuracy: 53.03% Test Accuracy: 65.62%
[ 4/ 50] MNIST Time 0.03015s Training Accuracy: 52.44% Test Accuracy: 46.88%
[ 4/ 50] FashionMNIST Time 0.03397s Training Accuracy: 62.89% Test Accuracy: 62.50%
[ 5/ 50] MNIST Time 0.05155s Training Accuracy: 56.35% Test Accuracy: 46.88%
[ 5/ 50] FashionMNIST Time 0.03870s Training Accuracy: 66.11% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.03693s Training Accuracy: 62.99% Test Accuracy: 50.00%
[ 6/ 50] FashionMNIST Time 0.03518s Training Accuracy: 73.44% Test Accuracy: 62.50%
[ 7/ 50] MNIST Time 0.04435s Training Accuracy: 70.31% Test Accuracy: 53.12%
[ 7/ 50] FashionMNIST Time 0.03614s Training Accuracy: 75.98% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.04210s Training Accuracy: 77.15% Test Accuracy: 53.12%
[ 8/ 50] FashionMNIST Time 0.03797s Training Accuracy: 80.47% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.03622s Training Accuracy: 80.27% Test Accuracy: 56.25%
[ 9/ 50] FashionMNIST Time 0.03619s Training Accuracy: 83.69% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.03446s Training Accuracy: 83.69% Test Accuracy: 56.25%
[ 10/ 50] FashionMNIST Time 0.03388s Training Accuracy: 86.13% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.03490s Training Accuracy: 86.91% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.04449s Training Accuracy: 88.28% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.03477s Training Accuracy: 88.77% Test Accuracy: 59.38%
[ 12/ 50] FashionMNIST Time 0.04674s Training Accuracy: 89.55% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03791s Training Accuracy: 94.53% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.03539s Training Accuracy: 92.58% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.03613s Training Accuracy: 96.00% Test Accuracy: 62.50%
[ 14/ 50] FashionMNIST Time 0.03434s Training Accuracy: 93.55% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.04809s Training Accuracy: 95.41% Test Accuracy: 62.50%
[ 15/ 50] FashionMNIST Time 0.03416s Training Accuracy: 94.04% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.04854s Training Accuracy: 97.95% Test Accuracy: 56.25%
[ 16/ 50] FashionMNIST Time 0.03932s Training Accuracy: 96.00% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03546s Training Accuracy: 99.22% Test Accuracy: 65.62%
[ 17/ 50] FashionMNIST Time 0.03662s Training Accuracy: 95.51% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.03472s Training Accuracy: 99.61% Test Accuracy: 65.62%
[ 18/ 50] FashionMNIST Time 0.04578s Training Accuracy: 96.00% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03517s Training Accuracy: 99.80% Test Accuracy: 65.62%
[ 19/ 50] FashionMNIST Time 0.04614s Training Accuracy: 98.05% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.03462s Training Accuracy: 99.80% Test Accuracy: 71.88%
[ 20/ 50] FashionMNIST Time 0.03546s Training Accuracy: 99.12% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.03503s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 21/ 50] FashionMNIST Time 0.03509s Training Accuracy: 98.54% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.04682s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 22/ 50] FashionMNIST Time 0.03287s Training Accuracy: 99.22% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.04147s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 23/ 50] FashionMNIST Time 0.03267s Training Accuracy: 99.71% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.03231s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 24/ 50] FashionMNIST Time 0.03555s Training Accuracy: 99.51% Test Accuracy: 75.00%
[ 25/ 50] MNIST Time 0.03580s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 25/ 50] FashionMNIST Time 0.04585s Training Accuracy: 99.90% Test Accuracy: 75.00%
[ 26/ 50] MNIST Time 0.03391s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.04533s Training Accuracy: 99.80% Test Accuracy: 75.00%
[ 27/ 50] MNIST Time 0.03375s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] FashionMNIST Time 0.04852s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.03477s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] FashionMNIST Time 0.03027s Training Accuracy: 99.80% Test Accuracy: 75.00%
[ 29/ 50] MNIST Time 0.03222s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.03440s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 30/ 50] MNIST Time 0.04315s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.03578s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 31/ 50] MNIST Time 0.04885s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] FashionMNIST Time 0.03527s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 32/ 50] MNIST Time 0.03635s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] FashionMNIST Time 0.03479s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 33/ 50] MNIST Time 0.03558s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] FashionMNIST Time 0.04637s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.03406s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] FashionMNIST Time 0.04475s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 35/ 50] MNIST Time 0.03414s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] FashionMNIST Time 0.03459s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 36/ 50] MNIST Time 0.03495s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] FashionMNIST Time 0.03181s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 37/ 50] MNIST Time 0.04173s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] FashionMNIST Time 0.03558s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 38/ 50] MNIST Time 0.04769s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.03643s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.03438s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] FashionMNIST Time 0.03534s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.03658s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] FashionMNIST Time 0.04670s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.03441s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] FashionMNIST Time 0.04477s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 42/ 50] MNIST Time 0.03587s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.04644s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.03574s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] FashionMNIST Time 0.03477s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.03458s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] FashionMNIST Time 0.03578s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 45/ 50] MNIST Time 0.04510s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.03415s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 46/ 50] MNIST Time 0.04566s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] FashionMNIST Time 0.03365s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 47/ 50] MNIST Time 0.03363s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] FashionMNIST Time 0.03349s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.03515s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.04409s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 49/ 50] MNIST Time 0.03473s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] FashionMNIST Time 0.04580s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.04756s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.03439s Training Accuracy: 100.00% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 75.00%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.