Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` tuple of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type; it is called as `dset(:train)` and `dset(:test)`.
- `n_train`: number of leading training samples to keep, or `nothing` for the whole split.
- `n_eval`: number of leading test samples to keep, or `nothing` for the whole split.
- `batchsize`: requested batch size; it is capped at the number of samples in each split.

Features are reshaped to `28 × 28 × 1 × N` and targets are one-hot encoded over `0:9`.
Only the training loader shuffles, and both loaders are built with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Training split: keep everything when no sample count was requested.
    train_split = dset(:train)
    train_count = isnothing(n_train) ? length(train_split) : n_train
    train_data = train_split[1:train_count]
    x_train = reshape(train_data.features, 28, 28, 1, :)
    y_train = onehotbatch(train_data.targets, 0:9)

    # Evaluation split, handled the same way.
    test_split = dset(:test)
    test_count = isnothing(n_eval) ? length(test_split) : n_eval
    test_data = test_split[1:test_count]
    x_test = reshape(test_data.features, 28, 28, 1, :)
    y_test = onehotbatch(test_data.targets, 0:9)

    # Never ask for a batch larger than the number of available samples.
    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST through `load_dataset` and return one result per dataset,
in that order.

When the `CI` environment variable parses to `true`, only 1024 training and 32 evaluation
samples are requested per dataset; otherwise the full splits are used.
"""
function load_datasets(batchsize=32)
    running_in_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_in_ci ? (1024, 32) : (nothing, nothing)
    # Broadcast over the tuple of dataset types; the remaining arguments act as scalars.
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
"""
    HyperNet(weight_generator, core_network)

Build a compact layer, tagged with `dispatch=:HyperNet`, that takes an input tuple
`(x, y)`: `weight_generator(x)` produces a flat parameter vector, which is given the
parameter layout of `core_network` and then used to evaluate `core_network` on `y`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Initialize the core network once, only to record the axes (layout) of its
    # parameters as a ComponentArray.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights as a flat vector ...
        flat_weights = vec(weight_generator(x))
        # ... and view them through the recorded parameter layout.
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload `Lux.initialparameters` so that it skips the initialization of the core_network parameters, since those are produced by the weight generator on every forward pass.
# Parameter initialization for layers built with `dispatch=:HyperNet`: only the weight
# generator's parameters are initialized and returned. No `core_network` entry is
# created here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
"""
    create_model()

Construct the `HyperNet` used in this tutorial: a small convolutional classifier with 10
outputs as the core network, and an embedding + MLP weight generator whose final layer
has exactly as many outputs as the core network has parameters.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class matches the target
class.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its result is taken as the prediction scores. Targets and scores
are moved to the CPU before `onecold` decoding.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        scores = first(model((data_idx, x), ps, eval_st))
        predictions = onecold(to_cpu(scores))
        n_correct += count(labels .== predictions)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
"""
    train()

Train the hypernetwork from `create_model` on MNIST and FashionMNIST for 50 epochs,
visiting both datasets in every epoch, and print per-epoch timing and accuracies followed
by a final summary.

Returns a two-element vector of final test accuracies in percent, ordered
`[MNIST, FashionMNIST]`.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using the first training batch of the
    # first dataset as the example input.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    # Accuracy in percent, rounded to two digits. The train state is passed in explicitly
    # because `train_state` is reassigned on every training step.
    percent_accuracy(ts, loader, idx) = round(
        accuracy(model_compiled, ts.parameters, ts.states, loader, idx) * 100; digits=2
    )
    dataset_name(idx) = idx == 1 ? "MNIST" : "FashionMNIST"

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for data_idx in 1:2
            train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

            ### This allows us to trace the data index, else it will be embedded as a
            ### constant in the IR
            concrete_data_idx = ConcreteRNumber(data_idx)

            started_at = time()
            for (x, y) in train_dataloader
                _, _, _, train_state = Training.single_train_step!(
                    AutoEnzyme(),
                    loss_fn,
                    ((concrete_data_idx, x), y),
                    train_state;
                    return_gradients=Val(false),
                )
            end
            ttime = time() - started_at

            train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
            test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

            data_name = dataset_name(data_idx)

            @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                     Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
        end
    end

    println()

    # Final evaluation pass over both datasets.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop. `train` returns the final test accuracies in percent,
# with index 1 for MNIST and index 2 for FashionMNIST.
test_acc_list = train()
2025-05-23 22:16:49.242547: I external/xla/xla/service/service.cc:152] XLA service 0xb2dd090 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-23 22:16:49.242592: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748038609.243417 691513 se_gpu_pjrt_client.cc:1026] Using BFC allocator.
I0000 00:00:1748038609.243494 691513 gpu_helpers.cc:136] XLA backend allocating 12527321088 bytes on device 0 for BFCAllocator.
I0000 00:00:1748038609.243539 691513 gpu_helpers.cc:177] XLA backend will use up to 4175773696 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1748038609.255071 691513 cuda_dnn.cc:529] Loaded cuDNN version 90400
[ 1/ 50] MNIST Time 32.08024s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.03595s Training Accuracy: 32.52% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.02783s Training Accuracy: 36.33% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.02969s Training Accuracy: 46.09% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.08189s Training Accuracy: 41.41% Test Accuracy: 31.25%
[ 3/ 50] FashionMNIST Time 0.02853s Training Accuracy: 57.32% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.02572s Training Accuracy: 51.76% Test Accuracy: 37.50%
[ 4/ 50] FashionMNIST Time 0.02870s Training Accuracy: 64.26% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.02710s Training Accuracy: 58.01% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.02782s Training Accuracy: 70.90% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.02819s Training Accuracy: 62.50% Test Accuracy: 34.38%
[ 6/ 50] FashionMNIST Time 0.04260s Training Accuracy: 75.59% Test Accuracy: 53.12%
[ 7/ 50] MNIST Time 0.02871s Training Accuracy: 68.65% Test Accuracy: 37.50%
[ 7/ 50] FashionMNIST Time 0.02890s Training Accuracy: 76.17% Test Accuracy: 53.12%
[ 8/ 50] MNIST Time 0.02827s Training Accuracy: 75.59% Test Accuracy: 40.62%
[ 8/ 50] FashionMNIST Time 0.02826s Training Accuracy: 81.15% Test Accuracy: 68.75%
[ 9/ 50] MNIST Time 0.02862s Training Accuracy: 79.49% Test Accuracy: 46.88%
[ 9/ 50] FashionMNIST Time 0.03942s Training Accuracy: 82.62% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.02774s Training Accuracy: 82.23% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.03785s Training Accuracy: 87.01% Test Accuracy: 53.12%
[ 11/ 50] MNIST Time 0.02743s Training Accuracy: 87.99% Test Accuracy: 56.25%
[ 11/ 50] FashionMNIST Time 0.02790s Training Accuracy: 89.84% Test Accuracy: 65.62%
[ 12/ 50] MNIST Time 0.02756s Training Accuracy: 90.23% Test Accuracy: 53.12%
[ 12/ 50] FashionMNIST Time 0.02814s Training Accuracy: 90.92% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03810s Training Accuracy: 94.43% Test Accuracy: 62.50%
[ 13/ 50] FashionMNIST Time 0.02821s Training Accuracy: 92.58% Test Accuracy: 71.88%
[ 14/ 50] MNIST Time 0.03819s Training Accuracy: 94.53% Test Accuracy: 68.75%
[ 14/ 50] FashionMNIST Time 0.02572s Training Accuracy: 94.43% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.02406s Training Accuracy: 95.90% Test Accuracy: 68.75%
[ 15/ 50] FashionMNIST Time 0.03694s Training Accuracy: 94.53% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.02931s Training Accuracy: 98.34% Test Accuracy: 65.62%
[ 16/ 50] FashionMNIST Time 0.03475s Training Accuracy: 94.73% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02785s Training Accuracy: 99.51% Test Accuracy: 68.75%
[ 17/ 50] FashionMNIST Time 0.02741s Training Accuracy: 97.27% Test Accuracy: 75.00%
[ 18/ 50] MNIST Time 0.02832s Training Accuracy: 99.41% Test Accuracy: 65.62%
[ 18/ 50] FashionMNIST Time 0.02587s Training Accuracy: 96.68% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03045s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 19/ 50] FashionMNIST Time 0.02800s Training Accuracy: 98.34% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.04251s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 20/ 50] FashionMNIST Time 0.02937s Training Accuracy: 98.34% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.02413s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 21/ 50] FashionMNIST Time 0.02610s Training Accuracy: 98.34% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.03518s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 22/ 50] FashionMNIST Time 0.03803s Training Accuracy: 99.51% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.02993s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 23/ 50] FashionMNIST Time 0.03643s Training Accuracy: 99.41% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.02417s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.02424s Training Accuracy: 99.61% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03291s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.02487s Training Accuracy: 99.80% Test Accuracy: 75.00%
[ 26/ 50] MNIST Time 0.03400s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.02359s Training Accuracy: 99.90% Test Accuracy: 75.00%
[ 27/ 50] MNIST Time 0.02423s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] FashionMNIST Time 0.02527s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.02535s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] FashionMNIST Time 0.03752s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.02791s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.03506s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 30/ 50] MNIST Time 0.02722s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.02799s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 31/ 50] MNIST Time 0.02788s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] FashionMNIST Time 0.02829s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 32/ 50] MNIST Time 0.03607s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] FashionMNIST Time 0.02770s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 33/ 50] MNIST Time 0.02767s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.02776s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.02729s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] FashionMNIST Time 0.03773s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 35/ 50] MNIST Time 0.02755s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.03549s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 36/ 50] MNIST Time 0.02509s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.02437s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.02467s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] FashionMNIST Time 0.02402s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.03380s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.02363s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.03350s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.02377s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.02312s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.02884s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.02530s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] FashionMNIST Time 0.03879s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.02852s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.02867s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 43/ 50] MNIST Time 0.02872s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.02898s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.03926s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.02440s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 45/ 50] MNIST Time 0.03108s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.02536s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 46/ 50] MNIST Time 0.02559s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.02612s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 47/ 50] MNIST Time 0.02663s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.03483s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.02572s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.03517s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 49/ 50] MNIST Time 0.02648s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.02635s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.03640s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.02595s Training Accuracy: 100.00% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 75.00%
Appendix
# Print the Julia version/platform information, followed by GPU backend details for
# whichever of CUDA / AMDGPU is both loaded and reported functional by MLDataDevices.
using InteractiveUtils

InteractiveUtils.versioninfo()

has_data_devices = @isdefined(MLDataDevices)

if has_data_devices && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if has_data_devices && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.