Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
1588.6 ms ✓ Lux → LuxComponentArraysExt
1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling LuxMLUtilsExt...
2155.7 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 164 already precompiled.
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the MLDatasets dataset type `dset`
(e.g. `MNIST` or `FashionMNIST`).

Each split keeps only its first `n_train` / `n_eval` samples. `nothing` keeps the whole
split, and counts larger than the split are clamped to its length. Images are reshaped
to `28 × 28 × 1 × N` arrays, and labels are one-hot encoded over `0:9`. The training
loader shuffles and the test loader does not. Both use `partial=false`, so a trailing
incomplete batch is dropped. The batch size is capped at the number of samples in the
split.

Throws `ArgumentError` if `n_train` or `n_eval` is an integer smaller than 1.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)
    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end

# Load one split (`:train` or `:test`) of `dset`, keeping its first `n` samples (all of
# them when `n === nothing`), and return `(images, onehot_labels)`.
function _load_split(dset, split::Symbol, n::Union{Nothing,Int})
    # A non-positive count would yield an empty split and a zero batch size downstream,
    # which fails with a much less helpful error inside DataLoader.
    if n !== nothing && n < 1
        throw(ArgumentError("number of samples must be positive, got $n"))
    end
    data = dset(split)
    nsamples = n === nothing ? length(data) : min(n, length(data))
    (; features, targets) = data[1:nsamples]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end
"""
    load_datasets(batchsize=32)

Return the `(train_loader, test_loader)` pairs for MNIST and FashionMNIST, in that order.
When the `CI` environment variable parses as `true`, each training split is limited to
1024 samples and each test split to 32. Otherwise the full splits are used.
"""
function load_datasets(batchsize=32)
    running_on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Wrap `core_network` so that its parameters are produced at call time by
`weight_generator`.

The returned layer is called with a tuple `(x, y)`:
- `x` is passed to `weight_generator`.
- The generator's output is flattened with `vec` and reinterpreted as a
  `ComponentArray` laid out like `core_network`'s parameters.
- `core_network` is then run on `y` with those generated parameters.

The layer is tagged with `dispatch=:HyperNet`, so its type is
`CompactLuxLayer{:HyperNet}` and methods can be specialised on it. The generator's
output length presumably must equal `Lux.parameterlength(core_network)`; this is not
checked here.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the *layout* (axes) of core_network's parameters is needed, so initialise them
    # with a throwaway RNG rather than drawing from (and advancing) the global default RNG.
    ca_axes = getaxes(ComponentArray(Lux.initialparameters(Xoshiro(0), core_network)))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we override `initialparameters` so that it skips initializing the `core_network` parameters. Those parameters are produced at call time by the `weight_generator`, so only the generator needs trainable parameters.
julia
# Initial parameters for a HyperNet: only the weight generator is trainable. The core
# network's weights are produced by the generator on every forward pass, so no entry is
# created for them here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Build the HyperNet used in this tutorial.

The core network is a small classifier with 10 outputs: three stride-2 3×3 ReLU
convolutions (1 → 16 → 32 → 64 channels), global mean pooling, flattening, and a dense
layer. Its weights come from a generator made of an `Embedding(2 => 32)` (one entry
per dataset index) followed by two dense layers. The last dense layer emits exactly
`Lux.parameterlength(core_network)` values.
"""
function create_model()
    classifier = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    n_core_params = Lux.parameterlength(classifier)
    generator = Chain(Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))
    return HyperNet(generator, classifier)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class matches the
target class.

- The model is evaluated in test mode as `model((data_idx, x), ps, st)`.
- Both the output and the one-hot targets `y` are moved to the CPU.
- Both are reduced with `onecold` before being compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_eval = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (xb, yb) in dataloader
        labels = onecold(to_cpu(yb))
        logits = first(model((data_idx, xb), ps, st_eval))
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
# Accuracy of `model_compiled` on `dataloader`, using the current parameters and states
# held in `train_state`, as a percentage rounded to 2 decimal places.
function _accuracy_percent(model_compiled, train_state, dataloader, data_idx)
    acc = accuracy(
        model_compiled, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

"""
    train()

Train the HyperNet on MNIST (data index 1) and FashionMNIST (data index 2).

Training runs for 50 epochs with `Adam(0.0003f0)` on the Reactant device. Each epoch
visits both datasets in turn and prints the training time and the training/test
accuracy. Returns a 2-element vector with the final test accuracy (in %) for each
dataset.
"""
function train()
    dataset_names = ("MNIST", "FashionMNIST")
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))
    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once from a sample batch. The data index is a traced
    # `ConcreteRNumber`, so the same compiled function serves both datasets.
    x_sample = first(first(dataloaders[1][1]))
    compile_idx = ConcreteRNumber(1)
    model_compiled = @compile model((compile_idx, x_sample), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    # Constructed once; every iteration passes a value of the same type, so the
    # compiled training step is reused.
    loss_fn = CrossEntropyLoss(; logits=Val(true))
    for epoch in 1:nepochs, data_idx in eachindex(dataset_names)
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                loss_fn,
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        @printf(
            "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
             Accuracy: %3.2f%%\n",
            epoch, nepochs, dataset_names[data_idx], ttime, train_acc, test_acc
        )
    end
    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in eachindex(dataset_names)
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        @printf(
            "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
             %3.2f%%\n",
            dataset_names[data_idx], train_acc, test_acc
        )
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop; `train` returns the final per-dataset test accuracies (in %),
# ordered MNIST, FashionMNIST.
test_acc_list = train()
2025-05-23 23:07:11.234361: I external/xla/xla/service/service.cc:152] XLA service 0x2c38c500 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-23 23:07:11.234398: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
2025-05-23 23:07:11.234405: I external/xla/xla/service/service.cc:160] StreamExecutor device (1): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748041631.239268 846235 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1748041631.239345 846235 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 0 for BFCAllocator.
I0000 00:00:1748041631.239391 846235 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 1 for BFCAllocator.
I0000 00:00:1748041631.239410 846235 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1748041631.239428 846235 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1748041631.254557 846235 cuda_dnn.cc:529] Loaded cuDNN version 90400
[ 1/ 50] MNIST Time 34.55751s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.02479s Training Accuracy: 32.62% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.02557s Training Accuracy: 36.91% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.07014s Training Accuracy: 46.00% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.02439s Training Accuracy: 41.60% Test Accuracy: 34.38%
[ 3/ 50] FashionMNIST Time 0.02883s Training Accuracy: 53.32% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.02958s Training Accuracy: 52.25% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.02947s Training Accuracy: 61.33% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.02905s Training Accuracy: 58.69% Test Accuracy: 43.75%
[ 5/ 50] FashionMNIST Time 0.02855s Training Accuracy: 68.85% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.02970s Training Accuracy: 62.89% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.03771s Training Accuracy: 73.73% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.02992s Training Accuracy: 68.85% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.05951s Training Accuracy: 75.68% Test Accuracy: 53.12%
[ 8/ 50] MNIST Time 0.02622s Training Accuracy: 77.64% Test Accuracy: 40.62%
[ 8/ 50] FashionMNIST Time 0.03957s Training Accuracy: 79.88% Test Accuracy: 59.38%
[ 9/ 50] MNIST Time 0.03195s Training Accuracy: 80.66% Test Accuracy: 53.12%
[ 9/ 50] FashionMNIST Time 0.02470s Training Accuracy: 83.69% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.02585s Training Accuracy: 83.69% Test Accuracy: 43.75%
[ 10/ 50] FashionMNIST Time 0.03771s Training Accuracy: 87.60% Test Accuracy: 56.25%
[ 11/ 50] MNIST Time 0.03004s Training Accuracy: 88.67% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.04026s Training Accuracy: 89.45% Test Accuracy: 56.25%
[ 12/ 50] MNIST Time 0.03049s Training Accuracy: 90.43% Test Accuracy: 56.25%
[ 12/ 50] FashionMNIST Time 0.03039s Training Accuracy: 91.89% Test Accuracy: 59.38%
[ 13/ 50] MNIST Time 0.04104s Training Accuracy: 92.48% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.03036s Training Accuracy: 93.46% Test Accuracy: 62.50%
[ 14/ 50] MNIST Time 0.02906s Training Accuracy: 95.61% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.03097s Training Accuracy: 94.92% Test Accuracy: 59.38%
[ 15/ 50] MNIST Time 0.03053s Training Accuracy: 96.78% Test Accuracy: 65.62%
[ 15/ 50] FashionMNIST Time 0.03917s Training Accuracy: 95.51% Test Accuracy: 65.62%
[ 16/ 50] MNIST Time 0.02944s Training Accuracy: 97.85% Test Accuracy: 59.38%
[ 16/ 50] FashionMNIST Time 0.03343s Training Accuracy: 96.19% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03745s Training Accuracy: 99.02% Test Accuracy: 53.12%
[ 17/ 50] FashionMNIST Time 0.02951s Training Accuracy: 97.27% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.04196s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 18/ 50] FashionMNIST Time 0.02592s Training Accuracy: 97.27% Test Accuracy: 62.50%
[ 19/ 50] MNIST Time 0.02518s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.03722s Training Accuracy: 99.12% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.02847s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.02841s Training Accuracy: 99.32% Test Accuracy: 65.62%
[ 21/ 50] MNIST Time 0.02773s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 21/ 50] FashionMNIST Time 0.02903s Training Accuracy: 99.41% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.03598s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 22/ 50] FashionMNIST Time 0.02797s Training Accuracy: 99.61% Test Accuracy: 65.62%
[ 23/ 50] MNIST Time 0.02877s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 23/ 50] FashionMNIST Time 0.03658s Training Accuracy: 99.61% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.02395s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 24/ 50] FashionMNIST Time 0.03672s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.02609s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 25/ 50] FashionMNIST Time 0.03857s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.03572s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 26/ 50] FashionMNIST Time 0.02930s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] MNIST Time 0.03031s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 27/ 50] FashionMNIST Time 0.02976s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] MNIST Time 0.02959s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 28/ 50] FashionMNIST Time 0.03953s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.02737s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 29/ 50] FashionMNIST Time 0.04044s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] MNIST Time 0.03484s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 30/ 50] FashionMNIST Time 0.03259s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.02419s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 31/ 50] FashionMNIST Time 0.03003s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] MNIST Time 0.03117s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 32/ 50] FashionMNIST Time 0.04124s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] MNIST Time 0.03089s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 33/ 50] FashionMNIST Time 0.03005s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.04107s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 34/ 50] FashionMNIST Time 0.03090s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] MNIST Time 0.03894s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 35/ 50] FashionMNIST Time 0.03089s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.03043s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 36/ 50] FashionMNIST Time 0.03887s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] MNIST Time 0.03014s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 37/ 50] FashionMNIST Time 0.03003s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.02927s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 38/ 50] FashionMNIST Time 0.02976s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] MNIST Time 0.03921s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 39/ 50] FashionMNIST Time 0.02982s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.03057s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 40/ 50] FashionMNIST Time 0.03996s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] MNIST Time 0.02957s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 41/ 50] FashionMNIST Time 0.03859s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] MNIST Time 0.03112s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 42/ 50] FashionMNIST Time 0.03009s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] MNIST Time 0.03947s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 43/ 50] FashionMNIST Time 0.02931s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] MNIST Time 0.03843s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 44/ 50] FashionMNIST Time 0.04902s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.03064s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 45/ 50] FashionMNIST Time 0.04207s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.03110s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.04098s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] MNIST Time 0.03015s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.03457s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.02820s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 48/ 50] FashionMNIST Time 0.02453s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.02659s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.02330s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] MNIST Time 0.02833s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.02822s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Also print GPU toolkit versions, but only for backends whose package is loaded in this
# session and which MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.