# Training a HyperNetwork on MNIST and FashionMNIST
## Package Imports
```julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
```

## Loading Datasets
```julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
```
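The snippet below is not part of the tutorial, but sketches how these helpers are expected to behave: `load_datasets` returns one `(train, test)` pair of `DataLoader`s per dataset, and every batch is a `28 × 28 × 1 × batchsize` image array together with a `10 × batchsize` one-hot label matrix.

```julia
# Illustrative usage sketch (assumes the MNIST/FashionMNIST artifacts are already
# downloaded, or that the MLDatasets download prompt is accepted).
mnist_loaders, fashion_loaders = load_datasets(32)
train_loader, test_loader = mnist_loaders
x, y = first(train_loader)
size(x), size(y)  # expected: ((28, 28, 1, 32), (10, 32))
```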
## Implement a HyperNet Layer

```julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
```

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply override the parameter initialization so that it ignores the core_network parameters, since those are generated by the weight generator at runtime (see the quick check after the next code block).
```julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
```
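As a quick sanity check (not part of the original tutorial), we can confirm the effect of this override on a tiny, hypothetical HyperNet: `Lux.setup` should return parameters only for the weight generator, since the core network's weights are produced on the fly inside the forward pass.

```julia
# Hypothetical sanity check on a small HyperNet (plain CPU, no Reactant needed).
core = Dense(4 => 2)
gen = Chain(Embedding(2 => 8), Dense(8 => Lux.parameterlength(core)))
hn = HyperNet(gen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)  # expected: (:weight_generator,), with no :core_network entry
```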
## Create and Initialize the HyperNet

```julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
```
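Before wiring things into the training loop, it can help to sanity check the shapes with a plain CPU forward pass. The snippet below is an illustrative sketch, not part of the original tutorial: the first element of the input tuple is the dataset index (1 for MNIST, 2 for FashionMNIST) that feeds the `Embedding`, and the second element is the image batch.

```julia
# Hypothetical shape check on random data (plain CPU, no Reactant compilation).
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 28, 28, 1, 4)    # a fake batch of four "images"
logits, _ = model((1, x), ps, st)   # task index 1 selects the first embedding
size(logits)                        # expected: (10, 4)
```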
## Define Utility Functions

```julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
```

## Training
```julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end

test_acc_list = train()
```

```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760850060.273683 116446 service.cc:158] XLA service 0x424dbbd0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760850060.273753 116446 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1760850060.274729 116446 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760850060.274796 116446 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1760850060.274863 116446 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760850060.285008 116446 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 39.95069s Training Accuracy: 35.06% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04406s Training Accuracy: 31.93% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.03920s Training Accuracy: 37.50% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.03601s Training Accuracy: 46.58% Test Accuracy: 43.75%
[ 3/ 50] MNIST Time 0.03824s Training Accuracy: 41.31% Test Accuracy: 43.75%
[ 3/ 50] FashionMNIST Time 0.03510s Training Accuracy: 53.03% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.03051s Training Accuracy: 54.20% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.06470s Training Accuracy: 61.62% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.02971s Training Accuracy: 57.13% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.03132s Training Accuracy: 65.43% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.03102s Training Accuracy: 63.48% Test Accuracy: 43.75%
[ 6/ 50] FashionMNIST Time 0.04012s Training Accuracy: 71.58% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.03097s Training Accuracy: 71.88% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.04005s Training Accuracy: 76.27% Test Accuracy: 59.38%
[ 8/ 50] MNIST Time 0.03079s Training Accuracy: 76.07% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.03124s Training Accuracy: 80.47% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.03996s Training Accuracy: 79.30% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.03093s Training Accuracy: 83.20% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.03953s Training Accuracy: 83.59% Test Accuracy: 46.88%
[ 10/ 50] FashionMNIST Time 0.03084s Training Accuracy: 86.91% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.03095s Training Accuracy: 86.23% Test Accuracy: 43.75%
[ 11/ 50] FashionMNIST Time 0.04041s Training Accuracy: 89.06% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.03080s Training Accuracy: 89.36% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.03970s Training Accuracy: 89.65% Test Accuracy: 62.50%
[ 13/ 50] MNIST Time 0.03094s Training Accuracy: 92.09% Test Accuracy: 53.12%
[ 13/ 50] FashionMNIST Time 0.03138s Training Accuracy: 91.41% Test Accuracy: 59.38%
[ 14/ 50] MNIST Time 0.04120s Training Accuracy: 95.12% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.03188s Training Accuracy: 94.14% Test Accuracy: 62.50%
[ 15/ 50] MNIST Time 0.04004s Training Accuracy: 96.39% Test Accuracy: 59.38%
[ 15/ 50] FashionMNIST Time 0.03147s Training Accuracy: 94.34% Test Accuracy: 62.50%
[ 16/ 50] MNIST Time 0.03112s Training Accuracy: 97.46% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.03101s Training Accuracy: 95.31% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03126s Training Accuracy: 98.93% Test Accuracy: 62.50%
[ 17/ 50] FashionMNIST Time 0.04034s Training Accuracy: 96.58% Test Accuracy: 65.62%
[ 18/ 50] MNIST Time 0.03110s Training Accuracy: 99.61% Test Accuracy: 62.50%
[ 18/ 50] FashionMNIST Time 0.03131s Training Accuracy: 97.36% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.03121s Training Accuracy: 99.41% Test Accuracy: 62.50%
[ 19/ 50] FashionMNIST Time 0.03079s Training Accuracy: 98.24% Test Accuracy: 75.00%
[ 20/ 50] MNIST Time 0.03977s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.03098s Training Accuracy: 99.02% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.02982s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.03070s Training Accuracy: 99.12% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.03060s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.03983s Training Accuracy: 99.61% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.03058s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.03084s Training Accuracy: 99.80% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.03065s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.03107s Training Accuracy: 99.80% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.03961s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.03050s Training Accuracy: 99.90% Test Accuracy: 75.00%
[ 26/ 50] MNIST Time 0.03066s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.03064s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 27/ 50] MNIST Time 0.03653s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.03873s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 28/ 50] MNIST Time 0.03125s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.03958s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.03118s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.03122s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 30/ 50] MNIST Time 0.04025s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.03098s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 31/ 50] MNIST Time 0.03964s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.03121s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 32/ 50] MNIST Time 0.03125s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.04117s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 33/ 50] MNIST Time 0.03108s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.04014s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.03126s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.03101s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 35/ 50] MNIST Time 0.04072s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.03172s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 36/ 50] MNIST Time 0.04191s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.03122s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 37/ 50] MNIST Time 0.03088s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.04083s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 38/ 50] MNIST Time 0.03215s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.03974s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.03159s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.03140s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.03117s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.03145s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.04137s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.03099s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 42/ 50] MNIST Time 0.03069s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.03092s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.03042s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.04024s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.03076s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.04017s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.03076s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.03070s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.03103s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.03015s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.03944s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.03054s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.03893s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.03086s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.03066s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.03177s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.03105s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.04099s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
```

## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
```

This page was generated using Literate.jl.