Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
    Reactant
Precompiling ComponentArrays...
957.5 ms ✓ ComponentArrays
1 dependency successfully precompiled in 1 seconds. 23 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
622.8 ms ✓ MLDataDevices → MLDataDevicesComponentArraysExt
1 dependency successfully precompiled in 1 seconds. 26 already precompiled.
Precompiling LuxComponentArraysExt...
552.8 ms ✓ ComponentArrays → ComponentArraysOptimisersExt
1443.4 ms ✓ Lux → LuxComponentArraysExt
2027.7 ms ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
3 dependencies successfully precompiled in 2 seconds. 107 already precompiled.
Precompiling MLDatasets...
454.2 ms ✓ DelimitedFiles
620.9 ms ✓ ZipFile
692.5 ms ✓ GZip
475.3 ms ✓ PooledArrays
1084.3 ms ✓ SplittablesBase
960.4 ms ✓ MLCore
2620.1 ms ✓ Accessors
2215.5 ms ✓ AtomsBase
1869.6 ms ✓ ImageShow
2017.6 ms ✓ HDF5_jll
9277.2 ms ✓ JSON3
3160.7 ms ✓ DataDeps
2279.5 ms ✓ Pickle
1722.8 ms ✓ NPZ
21331.3 ms ✓ PrettyTables
867.0 ms ✓ Accessors → LinearAlgebraExt
696.5 ms ✓ Accessors → TestExt
648.1 ms ✓ Accessors → UnitfulExt
747.9 ms ✓ Accessors → StaticArraysExt
2427.2 ms ✓ Chemfiles
7384.5 ms ✓ HDF5
19518.4 ms ✓ CSV
777.0 ms ✓ BangBang
34301.2 ms ✓ JLD2
2461.1 ms ✓ MAT
510.7 ms ✓ BangBang → BangBangChainRulesCoreExt
724.8 ms ✓ BangBang → BangBangStaticArraysExt
497.8 ms ✓ BangBang → BangBangTablesExt
916.7 ms ✓ MicroCollections
2775.9 ms ✓ Transducers
728.4 ms ✓ Transducers → TransducersAdaptExt
5387.4 ms ✓ FLoops
6385.8 ms ✓ MLUtils
49592.0 ms ✓ DataFrames
1404.6 ms ✓ Transducers → TransducersDataFramesExt
1623.5 ms ✓ BangBang → BangBangDataFramesExt
9831.6 ms ✓ MLDatasets
37 dependencies successfully precompiled in 97 seconds. 166 already precompiled.
Precompiling MLDataDevicesMLUtilsExt...
1525.4 ms ✓ MLDataDevices → MLDataDevicesMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 101 already precompiled.
Precompiling LuxMLUtilsExt...
2142.9 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 3 seconds. 164 already precompiled.
Precompiling OneHotArrays...
976.8 ms ✓ OneHotArrays
1 dependency successfully precompiled in 1 seconds. 31 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
759.8 ms ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
1 dependency successfully precompiled in 1 seconds. 38 already precompiled.
Precompiling ComponentArraysReactantExt...
17650.2 ms ✓ ComponentArrays → ComponentArraysReactantExt
1 dependency successfully precompiled in 18 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
17829.6 ms ✓ Reactant → ReactantOneHotArraysExt
1 dependency successfully precompiled in 18 seconds. 104 already precompiled.
Loading Datasets
julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
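As a quick sanity check (a minimal sketch, not part of the generated output, assuming the default batchsize of 32): each element of the returned tuple is a (train, test) pair of DataLoaders yielding 28×28×1×batchsize image batches and 10×batchsize one-hot label batches.
julia
mnist_loaders, fashionmnist_loaders = load_datasets()
x, y = first(mnist_loaders[1])  # first MNIST training batch
size(x), size(y)                # ((28, 28, 1, 32), (10, 32))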
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply override parameter initialization so that it skips the core_network parameters, which are instead produced by the weight generator at run time.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
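With this method in place, Lux.setup only materializes parameters for the weight generator; the core network's parameters are produced inside the forward pass. A minimal sketch with toy layer sizes (not from the tutorial) illustrating the effect:
julia
core = Dense(4, 2)
gen = Chain(Embedding(2 => 8), Dense(8, Lux.parameterlength(core)))
hn = HyperNet(gen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)                                       # (:weight_generator,)
out, _ = hn((1, randn(Float32, 4, 3)), ps, st)
size(out)                                      # (2, 3)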
Create and Initialize the HyperNet
julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
create_model (generic function with 1 method)
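Before setting up training, the model can be sanity-checked on the CPU. The first element of the input tuple is the task index (1 for MNIST, 2 for FashionMNIST) fed to the Embedding layer, and the second element is the image batch passed to the generated core network. A minimal sketch (the batch size of 4 is arbitrary and not part of the tutorial):
julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 28, 28, 1, 4)
logits, _ = model((1, x), ps, st)
size(logits)  # (10, 4)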
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
Training
julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, otherwise it will be embedded
        ### as a constant in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
2025-07-03 15:51:53.508007: I external/xla/xla/service/service.cc:153] XLA service 0x1ad15cb0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-03 15:51:53.508051: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1751557913.508974 104805 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1751557913.509044 104805 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1751557913.509088 104805 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1751557913.520706 104805 cuda_dnn.cc:471] Loaded cuDNN version 90800
[ 1/ 50] MNIST Time 54.32921s Training Accuracy: 35.16% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.10638s Training Accuracy: 32.42% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.04448s Training Accuracy: 36.91% Test Accuracy: 37.50%
[ 2/ 50] FashionMNIST Time 0.04382s Training Accuracy: 46.48% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.07007s Training Accuracy: 43.26% Test Accuracy: 31.25%
[ 3/ 50] FashionMNIST Time 0.03865s Training Accuracy: 54.98% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.03893s Training Accuracy: 51.56% Test Accuracy: 34.38%
[ 4/ 50] FashionMNIST Time 0.05394s Training Accuracy: 60.84% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.03551s Training Accuracy: 56.05% Test Accuracy: 40.62%
[ 5/ 50] FashionMNIST Time 0.04700s Training Accuracy: 68.26% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.03472s Training Accuracy: 59.96% Test Accuracy: 50.00%
[ 6/ 50] FashionMNIST Time 0.04470s Training Accuracy: 73.54% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.03534s Training Accuracy: 70.80% Test Accuracy: 46.88%
[ 7/ 50] FashionMNIST Time 0.04770s Training Accuracy: 76.76% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.03743s Training Accuracy: 73.83% Test Accuracy: 50.00%
[ 8/ 50] FashionMNIST Time 0.03442s Training Accuracy: 81.15% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.03507s Training Accuracy: 79.88% Test Accuracy: 56.25%
[ 9/ 50] FashionMNIST Time 0.03425s Training Accuracy: 83.89% Test Accuracy: 68.75%
[ 10/ 50] MNIST Time 0.03559s Training Accuracy: 84.18% Test Accuracy: 53.12%
[ 10/ 50] FashionMNIST Time 0.03756s Training Accuracy: 87.11% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.03477s Training Accuracy: 86.82% Test Accuracy: 56.25%
[ 11/ 50] FashionMNIST Time 0.03414s Training Accuracy: 89.55% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.04325s Training Accuracy: 90.33% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.03376s Training Accuracy: 90.04% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.04556s Training Accuracy: 91.80% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.03435s Training Accuracy: 92.68% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.04492s Training Accuracy: 94.43% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.04967s Training Accuracy: 94.73% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.06062s Training Accuracy: 96.97% Test Accuracy: 65.62%
[ 15/ 50] FashionMNIST Time 0.03358s Training Accuracy: 95.12% Test Accuracy: 65.62%
[ 16/ 50] MNIST Time 0.03587s Training Accuracy: 98.14% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.03273s Training Accuracy: 96.78% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03699s Training Accuracy: 99.02% Test Accuracy: 65.62%
[ 17/ 50] FashionMNIST Time 0.03346s Training Accuracy: 97.46% Test Accuracy: 65.62%
[ 18/ 50] MNIST Time 0.03535s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 18/ 50] FashionMNIST Time 0.03624s Training Accuracy: 97.56% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.03545s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 19/ 50] FashionMNIST Time 0.04608s Training Accuracy: 98.54% Test Accuracy: 65.62%
[ 20/ 50] MNIST Time 0.03666s Training Accuracy: 99.80% Test Accuracy: 65.62%
[ 20/ 50] FashionMNIST Time 0.04336s Training Accuracy: 99.12% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.03396s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 21/ 50] FashionMNIST Time 0.04471s Training Accuracy: 99.12% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.03290s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 22/ 50] FashionMNIST Time 0.04392s Training Accuracy: 99.61% Test Accuracy: 68.75%
[ 23/ 50] MNIST Time 0.03400s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 23/ 50] FashionMNIST Time 0.03401s Training Accuracy: 99.61% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.03420s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 24/ 50] FashionMNIST Time 0.03680s Training Accuracy: 99.80% Test Accuracy: 65.62%
[ 25/ 50] MNIST Time 0.03205s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.03386s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.03440s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.03452s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.04889s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] FashionMNIST Time 0.03746s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 28/ 50] MNIST Time 0.04775s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 28/ 50] FashionMNIST Time 0.03594s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.04474s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.03323s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.04435s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] FashionMNIST Time 0.03357s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.03355s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] FashionMNIST Time 0.03380s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.03296s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] FashionMNIST Time 0.03311s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.03396s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] FashionMNIST Time 0.03329s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.03593s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] FashionMNIST Time 0.04342s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.03190s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] FashionMNIST Time 0.04248s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] MNIST Time 0.03180s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] FashionMNIST Time 0.04237s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 37/ 50] MNIST Time 0.03368s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] FashionMNIST Time 0.04247s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.03309s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] FashionMNIST Time 0.03246s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.03323s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] FashionMNIST Time 0.03331s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.03275s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] FashionMNIST Time 0.03498s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.03176s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] FashionMNIST Time 0.03284s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.04612s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] FashionMNIST Time 0.03330s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 43/ 50] MNIST Time 0.04355s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] FashionMNIST Time 0.03321s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.04515s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] FashionMNIST Time 0.03293s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 45/ 50] MNIST Time 0.04484s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] FashionMNIST Time 0.03318s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.03382s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] FashionMNIST Time 0.03388s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.03331s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] FashionMNIST Time 0.03299s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.03436s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] FashionMNIST Time 0.03379s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.03496s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] FashionMNIST Time 0.03422s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.03351s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] FashionMNIST Time 0.04691s Training Accuracy: 100.00% Test Accuracy: 71.88%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 71.88%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.