Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type; it is called as `dset(:train)` and `dset(:test)`.
- `n_train`, `n_eval`: number of leading samples to keep from the train / test split,
  or `nothing` to keep the whole split.
- `batchsize`: requested batch size; it is capped at the number of samples kept.

# Returns
A 2-tuple of `DataLoader`s over `(x, y)`, where `x` is the feature array reshaped to
`28 × 28 × 1 × N` and `y` is the one-hot encoding of the targets over the labels `0:9`.
The training loader shuffles, the test loader does not. Both are created with
`partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load one split, keep its first `n` samples (all of them when `n` is `nothing`),
    # and convert it to the array layout / label encoding used by the model.
    function prepare(split::Symbol, n)
        data = dset(split)
        last_index = isnothing(n) ? length(data) : n
        (; features, targets) = data[1:last_index]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Wrap an `(x, y)` pair in a loader whose batch size never exceeds the sample count.
    function make_loader(x, y; shuffle)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle=shuffle, partial=false
        )
    end

    x_train, y_train = prepare(:train, n_train)
    x_test, y_test = prepare(:test, n_eval)

    train_loader = make_loader(x_train, y_train; shuffle=true)
    test_loader = make_loader(x_test, y_test; shuffle=false)
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST through `load_dataset` and return a 2-tuple with one
`(train_loader, test_loader)` pair per dataset, in that order.

When the environment variable `CI` parses to `true`, only the first 1024 training
samples and the first 32 test samples of each dataset are used; otherwise the full
splits are loaded.
"""
function load_datasets(batchsize=32)
    # Read the CI flag once; both subset sizes depend on the same switch.
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    # Broadcasting over the tuple of dataset types yields a tuple of loader pairs;
    # the remaining arguments are treated as scalars.
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a compact Lux layer in which `weight_generator` produces the parameters that
`core_network` is evaluated with.

The resulting layer is called with a tuple `(x, y)`:
- `x` is passed to `weight_generator`;
- the generator output is flattened with `vec` and wrapped in a `ComponentArray` that
  reuses the axes of `core_network`'s own initial parameters;
- `core_network` is then applied to `y` with those generated parameters.

The generator output is assumed to contain exactly as many elements as `core_network`
has parameters (`create_model` sizes it with `Lux.parameterlength`).
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the axes (the parameter layout) are needed here; the parameter values
    # drawn from the default RNG are discarded.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
"""
    Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})

Initialise the trainable parameters of a layer built with `dispatch=:HyperNet`.

Only the `weight_generator` sub-layer is initialised; no parameters are created for
`core_network`, whose parameters are produced by the generator on every call.
Returns a named tuple with the single field `weight_generator`.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

- The core network is a small CNN: three 3×3, stride-2 convolutions with `relu`
  (1 → 16 → 32 → 64 channels), global mean pooling, flattening and a `Dense(64, 10)`
  output layer.
- The weight generator is `Embedding(2 => 32)` followed by `Dense(32, 64, relu)` and a
  final `Dense` layer whose output width is `Lux.parameterlength(core_network)`, i.e.
  one output per parameter of the core network.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    # The last layer must emit one value per core-network parameter so that the
    # generated vector can be reinterpreted as that network's parameter set.
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network)),
    )
    return HyperNet(weight_generator, core_network)
end
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction (between 0 and 1) of samples in `dataloader` that `model`
classifies correctly.

# Arguments
- `model`: callable invoked as `model((data_idx, x), ps, st)`; the first element of its
  return value is used as the prediction.
- `ps`, `st`: parameters and states; `st` is switched to test mode before evaluation.
- `dataloader`: iterable of `(x, y)` batches, with `y` decodable by `onecold`.
- `data_idx`: dataset index forwarded to the model with every batch.

If `dataloader` yields no batches the result is `0 / 0`, i.e. `NaN`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        ŷ, _ = model((data_idx, x), ps, st)
        # Move both label arrays to the CPU before decoding the class indices.
        target_class = y |> cdev |> onecold
        predicted_class = ŷ |> cdev |> onecold
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
Training
julia
"""
    train()

Train the hypernetwork from `create_model` on the datasets from `load_datasets`,
alternating between MNIST (index 1) and FashionMNIST (index 2) within every epoch.

After each dataset pass, the elapsed training time and the train / test accuracy are
printed; a final summary is printed once all 50 epochs are done.

Returns a 2-element `Vector{Float64}` holding the final test accuracy in percent
(rounded to two digits) for MNIST and FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)
    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))
    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference call once, using one training batch as the example input.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    # Accuracy of the compiled model on `loader`, in percent rounded to two digits.
    # The train state is passed explicitly because it is rebound on every step.
    function percent_accuracy(state, loader, idx)
        fraction = accuracy(model_compiled, state.parameters, state.states, loader, idx)
        return round(fraction * 100; digits=2)
    end

    dataset_names = ("MNIST", "FashionMNIST")

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for dataset_id in 1:2
            train_loader, test_loader = dev.(dataloaders[dataset_id])

            ### This allows us to trace the data index, else it will be embedded as a
            ### constant in the IR
            traced_id = ConcreteRNumber(dataset_id)

            started_at = time()
            for (xb, yb) in train_loader
                _, _, _, train_state = Training.single_train_step!(
                    AutoEnzyme(),
                    CrossEntropyLoss(; logits=Val(true)),
                    ((traced_id, xb), yb),
                    train_state;
                    return_gradients=Val(false),
                )
            end
            ttime = time() - started_at

            train_acc = percent_accuracy(train_state, train_loader, traced_id)
            test_acc = percent_accuracy(train_state, test_loader, traced_id)

            data_name = dataset_names[dataset_id]
            @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                     Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
        end
    end

    println()

    # Final evaluation pass over both datasets.
    test_acc_list = zeros(2)
    for dataset_id in 1:2
        train_loader, test_loader = dev.(dataloaders[dataset_id])
        traced_id = ConcreteRNumber(dataset_id)

        train_acc = percent_accuracy(train_state, train_loader, traced_id)
        test_acc = percent_accuracy(train_state, test_loader, traced_id)

        data_name = dataset_names[dataset_id]
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[dataset_id] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 68.66474s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.09102s Training Accuracy: 32.62% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.08948s Training Accuracy: 36.72% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.09169s Training Accuracy: 45.61% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.08911s Training Accuracy: 41.02% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.08720s Training Accuracy: 57.03% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.08925s Training Accuracy: 52.25% Test Accuracy: 40.62%
[ 4/ 50] FashionMNIST Time 0.08731s Training Accuracy: 63.96% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.08615s Training Accuracy: 57.32% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.08757s Training Accuracy: 70.51% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.08490s Training Accuracy: 63.67% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.08741s Training Accuracy: 75.49% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.09414s Training Accuracy: 71.00% Test Accuracy: 34.38%
[ 7/ 50] FashionMNIST Time 0.08686s Training Accuracy: 76.66% Test Accuracy: 53.12%
[ 8/ 50] MNIST Time 0.08670s Training Accuracy: 75.68% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.08567s Training Accuracy: 81.84% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.08507s Training Accuracy: 80.08% Test Accuracy: 43.75%
[ 9/ 50] FashionMNIST Time 0.08905s Training Accuracy: 83.89% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.08665s Training Accuracy: 84.96% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.09309s Training Accuracy: 88.57% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.08814s Training Accuracy: 89.45% Test Accuracy: 53.12%
[ 11/ 50] FashionMNIST Time 0.08566s Training Accuracy: 90.04% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.08749s Training Accuracy: 91.11% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.09397s Training Accuracy: 91.11% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.09211s Training Accuracy: 93.46% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.08720s Training Accuracy: 91.89% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.08715s Training Accuracy: 95.51% Test Accuracy: 53.12%
[ 14/ 50] FashionMNIST Time 0.08619s Training Accuracy: 94.53% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.08572s Training Accuracy: 96.58% Test Accuracy: 56.25%
[ 15/ 50] FashionMNIST Time 0.09332s Training Accuracy: 93.85% Test Accuracy: 71.88%
[ 16/ 50] MNIST Time 0.08934s Training Accuracy: 98.73% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.08669s Training Accuracy: 96.29% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.08797s Training Accuracy: 99.12% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.08551s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.09255s Training Accuracy: 99.51% Test Accuracy: 56.25%
[ 18/ 50] FashionMNIST Time 0.08512s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.08724s Training Accuracy: 99.80% Test Accuracy: 59.38%
[ 19/ 50] FashionMNIST Time 0.08602s Training Accuracy: 99.32% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.08387s Training Accuracy: 99.80% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.09208s Training Accuracy: 98.73% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.08656s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.08497s Training Accuracy: 98.73% Test Accuracy: 65.62%
[ 22/ 50] MNIST Time 0.08437s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.08568s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.09146s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 23/ 50] FashionMNIST Time 0.08779s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 24/ 50] MNIST Time 0.08777s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.08658s Training Accuracy: 99.51% Test Accuracy: 65.62%
[ 25/ 50] MNIST Time 0.08698s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.09203s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 26/ 50] MNIST Time 0.08523s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] FashionMNIST Time 0.08612s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] MNIST Time 0.08553s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.08586s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] MNIST Time 0.09116s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.08752s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.08628s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] FashionMNIST Time 0.08644s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.08574s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] FashionMNIST Time 0.09200s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.08859s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] FashionMNIST Time 0.09048s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] MNIST Time 0.08794s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.08589s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.09320s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] FashionMNIST Time 0.08834s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.09556s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.08478s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.08691s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] FashionMNIST Time 0.09204s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.08601s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] FashionMNIST Time 0.08650s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.08538s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] FashionMNIST Time 0.08416s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.09231s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] FashionMNIST Time 0.08680s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.08670s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] FashionMNIST Time 0.08650s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.08637s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] FashionMNIST Time 0.09415s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.08683s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] FashionMNIST Time 0.08654s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.08687s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] FashionMNIST Time 0.08670s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.09671s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] FashionMNIST Time 0.08837s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.08730s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.09048s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.08823s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] FashionMNIST Time 0.09215s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] MNIST Time 0.08551s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] FashionMNIST Time 0.08634s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.08809s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] FashionMNIST Time 0.08601s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.09253s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] FashionMNIST Time 0.08888s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.08632s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] FashionMNIST Time 0.08487s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.08543s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] FashionMNIST Time 0.09527s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
Appendix
julia
using InteractiveUtils

# Print the Julia version, platform and environment details.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices is defined in this session,
# the backend package itself is defined, and MLDataDevices reports it as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
This page was generated using Literate.jl.