# Training a HyperNetwork on MNIST and FashionMNIST
## Package Imports

```julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
```
## Loading Datasets

```julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
```
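As a quick sanity check, the hypothetical snippet below (not part of the tutorial) builds the loaders and inspects a single batch; it assumes the MNIST and FashionMNIST artifacts are cached locally or can be downloaded.

```julia
# Usage sketch: each call to load_datasets returns (train_loader, test_loader) per dataset.
mnist_loaders, fashion_loaders = load_datasets(32)

x, y = first(mnist_loaders[1])  # first training batch of MNIST
size(x)                         # (28, 28, 1, 32): WHCN image batch
size(y)                         # (10, 32): one-hot encoded labels
```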
## Implement a HyperNet Layer

```julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
```
Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing this unless you are familiar with the internals. Here we simply override the parameter initialization to skip the `core_network` parameters, since those are generated on the fly by the `weight_generator`.
```julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
```
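To see how the pieces fit together, here is a small CPU-only sketch with made-up layer sizes (not the tutorial's model): the first element of the input tuple is the task index consumed by the weight generator, and the second is the data batch consumed by the core network.

```julia
# Illustrative sizes only; relies on the HyperNet constructor defined above.
core = Dense(4, 2)
hn = HyperNet(Chain(Embedding(2 => 8), Dense(8, Lux.parameterlength(core))), core)
ps, st = Lux.setup(Random.default_rng(), hn)

x = rand(Float32, 4, 3)    # a batch of 3 four-dimensional inputs
y, _ = hn((1, x), ps, st)  # task index 1 picks the first embedding
size(y)                    # (2, 3)
```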
## Create and Initialize the HyperNet

```julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
```
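The output width of the weight generator's final `Dense` layer must match `Lux.parameterlength(core_network)` exactly, because the generated vector is reinterpreted as the core network's parameters. A quick check of that number (a sketch using the same core network as above):

```julia
# Count the parameters the weight generator has to produce for each task.
core = Chain(
    Conv((3, 3), 1 => 16, relu; stride=2),
    Conv((3, 3), 16 => 32, relu; stride=2),
    Conv((3, 3), 32 => 64, relu; stride=2),
    GlobalMeanPool(),
    FlattenLayer(),
    Dense(64, 10),
)
Lux.parameterlength(core)  # 23946 weights and biases generated per task
```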
## Define Utility Functions

```julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
```
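`onecold` is the inverse of `onehotbatch`: called without a label set it returns the 1-based position of the largest entry in each column, so the decoded targets and the decoded logits can be compared directly. A small illustration with made-up values:

```julia
using OneHotArrays

y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(y)                       # 1-based positions: [1, 4, 10]
onecold(y, 0:9)                  # original labels:   [0, 3, 9]
```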
## Training

```julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
%3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
```
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757982770.905553 397673 service.cc:163] XLA service 0x4fc35170 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757982770.905664 397673 service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1757982770.906523 397673 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757982770.906567 397673 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757982770.906608 397673 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757982770.917039 397673 cuda_dnn.cc:463] Loaded cuDNN version 91200
[ 1/ 50] MNIST Time 45.73280s Training Accuracy: 35.06% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04383s Training Accuracy: 31.74% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.09866s Training Accuracy: 37.30% Test Accuracy: 37.50%
[ 2/ 50] FashionMNIST Time 0.04291s Training Accuracy: 46.48% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.03741s Training Accuracy: 41.31% Test Accuracy: 31.25%
[ 3/ 50] FashionMNIST Time 0.03844s Training Accuracy: 52.44% Test Accuracy: 65.62%
[ 4/ 50] MNIST Time 0.03706s Training Accuracy: 52.73% Test Accuracy: 46.88%
[ 4/ 50] FashionMNIST Time 0.03569s Training Accuracy: 60.16% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.03270s Training Accuracy: 57.03% Test Accuracy: 40.62%
[ 5/ 50] FashionMNIST Time 0.03367s Training Accuracy: 66.80% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.03534s Training Accuracy: 62.99% Test Accuracy: 43.75%
[ 6/ 50] FashionMNIST Time 0.03597s Training Accuracy: 72.46% Test Accuracy: 62.50%
[ 7/ 50] MNIST Time 0.03686s Training Accuracy: 67.97% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.03561s Training Accuracy: 76.95% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.03910s Training Accuracy: 76.46% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.03026s Training Accuracy: 80.08% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.03277s Training Accuracy: 79.39% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.03475s Training Accuracy: 82.81% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.04554s Training Accuracy: 85.25% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.03596s Training Accuracy: 84.96% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.04675s Training Accuracy: 87.70% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.03698s Training Accuracy: 88.28% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.04450s Training Accuracy: 89.36% Test Accuracy: 46.88%
[ 12/ 50] FashionMNIST Time 0.03663s Training Accuracy: 90.43% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.04551s Training Accuracy: 93.07% Test Accuracy: 53.12%
[ 13/ 50] FashionMNIST Time 0.03676s Training Accuracy: 93.95% Test Accuracy: 71.88%
[ 14/ 50] MNIST Time 0.04751s Training Accuracy: 95.21% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.03738s Training Accuracy: 94.14% Test Accuracy: 71.88%
[ 15/ 50] MNIST Time 0.04317s Training Accuracy: 96.39% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.03865s Training Accuracy: 94.63% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.03648s Training Accuracy: 97.27% Test Accuracy: 59.38%
[ 16/ 50] FashionMNIST Time 0.03671s Training Accuracy: 96.19% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03672s Training Accuracy: 98.63% Test Accuracy: 53.12%
[ 17/ 50] FashionMNIST Time 0.03813s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.03509s Training Accuracy: 99.32% Test Accuracy: 62.50%
[ 18/ 50] FashionMNIST Time 0.03330s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.03331s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 19/ 50] FashionMNIST Time 0.03773s Training Accuracy: 98.73% Test Accuracy: 75.00%
[ 20/ 50] MNIST Time 0.02998s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 20/ 50] FashionMNIST Time 0.03590s Training Accuracy: 98.93% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.03812s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.03214s Training Accuracy: 99.22% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.03273s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.04566s Training Accuracy: 99.32% Test Accuracy: 78.12%
[ 23/ 50] MNIST Time 0.03777s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.05631s Training Accuracy: 99.41% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.03261s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 24/ 50] FashionMNIST Time 0.03945s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03472s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.04427s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.03489s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.04371s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.03467s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] FashionMNIST Time 0.04498s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] MNIST Time 0.03462s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] FashionMNIST Time 0.03528s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.03329s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.03257s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] MNIST Time 0.03707s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.03367s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.03293s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.03029s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.03259s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.03022s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.03112s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.03275s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.03213s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] FashionMNIST Time 0.03291s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] MNIST Time 0.04353s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.03524s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.04556s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.03399s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] MNIST Time 0.04881s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.03861s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.04608s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.03373s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] MNIST Time 0.04214s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.03395s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.05148s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.03158s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] MNIST Time 0.04108s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.03350s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] MNIST Time 0.03334s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.03446s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] MNIST Time 0.03504s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.03389s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.03332s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.03354s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.03374s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.03396s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.03387s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.03324s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.03382s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.04302s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.03295s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.04318s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.03377s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.04508s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] MNIST Time 0.03440s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.04388s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
```
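A note on the `ConcreteRNumber` used above: wrapping the dataset index keeps it as a traced runtime value, so a single compiled forward pass serves both MNIST and FashionMNIST instead of baking the index into the IR as a constant. A minimal sketch, assuming `model_compiled`, `ps`, `st`, and `x` are set up as inside `train`:

```julia
idx_mnist = ConcreteRNumber(1)
idx_fashion = ConcreteRNumber(2)

# The same compiled function handles both tasks; only the traced index changes.
model_compiled((idx_mnist, x), ps, Lux.testmode(st))
model_compiled((idx_fashion, x), ps, Lux.testmode(st))
```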
## Appendix

```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```
```
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
```
This page was generated using Literate.jl.