Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
1547.1 ms ✓ Lux → LuxComponentArraysExt
1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling ComponentArraysReactantExt...
17145.5 ms ✓ ComponentArrays → ComponentArraysReactantExt
1 dependency successfully precompiled in 17 seconds. 95 already precompiled.
Precompiling ReactantOneHotArraysExt...
17390.7 ms ✓ Reactant → ReactantOneHotArraysExt
1 dependency successfully precompiled in 18 seconds. 104 already precompiled.
Loading Datasets
julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
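As a quick sanity check (an aside, not part of the original tutorial, and assuming the default batchsize of 32 defined above), each training batch should be a 28×28×1×32 image tensor paired with a 10×32 one-hot label matrix:
julia
# Illustrative sketch: inspect the shape of the first MNIST training batch.
mnist_loaders, fashion_loaders = load_datasets()
x_batch, y_batch = first(mnist_loaders[1])   # first training batch of MNIST
@assert size(x_batch) == (28, 28, 1, 32)     # WHCN image tensor
@assert size(y_batch) == (10, 32)            # one-hot targets over the digits 0:9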
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
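The key trick above is the ComponentArray round-trip: getaxes records the named structure of the core network's parameters, and ComponentArray(vec(...), ca_axes) later reinterprets the flat vector emitted by the weight generator as a structured parameter set. A minimal sketch of the same round-trip on a standalone Dense layer (illustrative, not part of the tutorial):
julia
# Sketch: flatten a layer's parameters and rebuild the named structure from the
# flat vector using the saved axes -- the same mechanism HyperNet relies on.
layer = Dense(2, 3)
ps = Lux.initialparameters(Random.default_rng(), layer)
ca = ComponentArray(ps)                        # named view over a flat vector
flat = collect(ca)                             # plain vector of length 2*3 + 3 = 9
ps_rebuilt = ComponentArray(flat, getaxes(ca)) # restore the (weight, bias) structure
@assert ps_rebuilt.weight == ca.weight && ps_rebuilt.bias == ca.bias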
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we overload Lux.initialparameters so that the core_network parameters are never initialized or stored; they are generated by the weight generator at runtime.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
Create and Initialize the HyperNet
julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
create_model (generic function with 1 method)
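Because of the Lux.initialparameters overload above, Lux.setup returns trainable parameters only for the weight generator; the core network's weights are produced on the fly from the task embedding. A small CPU-only check of this (a sketch, assuming an eager forward pass is acceptable for debugging; training below uses the compiled Reactant path):
julia
# Sketch: only the weight generator holds stored parameters, and a forward pass
# with task index 1 (MNIST) yields 10 logits per image.
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
@assert keys(ps) == (:weight_generator,)   # no stored core_network parameters
x_dummy = randn(Float32, 28, 28, 1, 4)     # hypothetical batch of 4 images
logits, _ = model((1, x_dummy), ps, st)
@assert size(logits) == (10, 4)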
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
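Note that onecold is applied to both the one-hot targets and the raw logits: with no label set it returns 1-based argmax indices, so both sides live in the same index space and can be compared directly. For illustration (an aside, not from the tutorial):
julia
# Sketch of onecold's behaviour on a one-hot matrix built over the labels 0:9.
labels = onehotbatch([0, 3, 9], 0:9)   # 10×3 one-hot matrix
onecold(labels)                        # 1-based argmax indices: [1, 4, 10]
onecold(labels, 0:9)                   # mapped back to the original labels: [0, 3, 9]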
Training
julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end
test_acc_list = train()
2025-07-14 00:14:02.399863: I external/xla/xla/service/service.cc:153] XLA service 0x44f4cf00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-14 00:14:02.399906: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752452042.400753 2780017 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752452042.400826 2780017 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752452042.400899 2780017 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752452042.414801 2780017 cuda_dnn.cc:471] Loaded cuDNN version 90800
[ 1/ 50] MNIST Time 55.35623s Training Accuracy: 34.77% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04084s Training Accuracy: 32.42% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.04593s Training Accuracy: 36.43% Test Accuracy: 31.25%
[ 2/ 50] FashionMNIST Time 0.10024s Training Accuracy: 45.51% Test Accuracy: 43.75%
[ 3/ 50] MNIST Time 0.03791s Training Accuracy: 39.26% Test Accuracy: 40.62%
[ 3/ 50] FashionMNIST Time 0.03511s Training Accuracy: 51.27% Test Accuracy: 62.50%
[ 4/ 50] MNIST Time 0.05604s Training Accuracy: 49.51% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.03575s Training Accuracy: 60.25% Test Accuracy: 53.12%
[ 5/ 50] MNIST Time 0.03147s Training Accuracy: 56.35% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.03641s Training Accuracy: 65.43% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.03436s Training Accuracy: 61.82% Test Accuracy: 46.88%
[ 6/ 50] FashionMNIST Time 0.03308s Training Accuracy: 69.24% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.04164s Training Accuracy: 67.87% Test Accuracy: 46.88%
[ 7/ 50] FashionMNIST Time 0.03288s Training Accuracy: 75.00% Test Accuracy: 59.38%
[ 8/ 50] MNIST Time 0.04463s Training Accuracy: 74.90% Test Accuracy: 50.00%
[ 8/ 50] FashionMNIST Time 0.03537s Training Accuracy: 78.32% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.04308s Training Accuracy: 78.52% Test Accuracy: 53.12%
[ 9/ 50] FashionMNIST Time 0.03216s Training Accuracy: 79.49% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.04361s Training Accuracy: 83.11% Test Accuracy: 46.88%
[ 10/ 50] FashionMNIST Time 0.03323s Training Accuracy: 84.77% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.04545s Training Accuracy: 86.62% Test Accuracy: 56.25%
[ 11/ 50] FashionMNIST Time 0.03466s Training Accuracy: 86.82% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.03893s Training Accuracy: 87.99% Test Accuracy: 56.25%
[ 12/ 50] FashionMNIST Time 0.02956s Training Accuracy: 89.06% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.03331s Training Accuracy: 91.80% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.03299s Training Accuracy: 92.29% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.03315s Training Accuracy: 94.73% Test Accuracy: 65.62%
[ 14/ 50] FashionMNIST Time 0.04797s Training Accuracy: 94.04% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.03137s Training Accuracy: 96.48% Test Accuracy: 65.62%
[ 15/ 50] FashionMNIST Time 0.03372s Training Accuracy: 95.02% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.03325s Training Accuracy: 96.97% Test Accuracy: 65.62%
[ 16/ 50] FashionMNIST Time 0.03351s Training Accuracy: 95.61% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03279s Training Accuracy: 98.54% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.03353s Training Accuracy: 96.58% Test Accuracy: 65.62%
[ 18/ 50] MNIST Time 0.03309s Training Accuracy: 98.63% Test Accuracy: 65.62%
[ 18/ 50] FashionMNIST Time 0.04281s Training Accuracy: 97.56% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.03279s Training Accuracy: 99.61% Test Accuracy: 65.62%
[ 19/ 50] FashionMNIST Time 0.04275s Training Accuracy: 98.63% Test Accuracy: 65.62%
[ 20/ 50] MNIST Time 0.03207s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 20/ 50] FashionMNIST Time 0.04248s Training Accuracy: 98.73% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.03217s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.04091s Training Accuracy: 98.73% Test Accuracy: 65.62%
[ 22/ 50] MNIST Time 0.03037s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.03914s Training Accuracy: 99.22% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.03207s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.04149s Training Accuracy: 99.61% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.03298s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 24/ 50] FashionMNIST Time 0.03368s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03343s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.03209s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.03289s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.03258s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.03287s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] FashionMNIST Time 0.03252s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 28/ 50] MNIST Time 0.03398s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] FashionMNIST Time 0.03327s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.03219s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.03230s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] MNIST Time 0.04004s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.03319s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.04118s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] FashionMNIST Time 0.03222s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.04338s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] FashionMNIST Time 0.03341s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.04328s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] FashionMNIST Time 0.03280s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.03890s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] FashionMNIST Time 0.03373s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.04283s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] FashionMNIST Time 0.03416s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.03327s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] FashionMNIST Time 0.03296s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.03300s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] FashionMNIST Time 0.03536s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.03277s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] FashionMNIST Time 0.03588s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.03306s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] FashionMNIST Time 0.03356s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] MNIST Time 0.03386s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] FashionMNIST Time 0.03312s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] MNIST Time 0.03102s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] FashionMNIST Time 0.04152s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.03327s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] FashionMNIST Time 0.04359s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.03270s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] FashionMNIST Time 0.03975s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.03293s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] FashionMNIST Time 0.04130s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.03165s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] FashionMNIST Time 0.04076s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.03385s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] FashionMNIST Time 0.04526s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.03352s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] FashionMNIST Time 0.03405s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] MNIST Time 0.03415s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] FashionMNIST Time 0.03307s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] MNIST Time 0.03261s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] FashionMNIST Time 0.03317s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] MNIST Time 0.03306s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] FashionMNIST Time 0.04050s Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.