Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the MLDatasets dataset type `dset` and wrap each
one in a `DataLoader`.

# Arguments
- `dset`: dataset type constructible as `dset(:train)` / `dset(:test)` (e.g. `MNIST`).
- `n_train`, `n_eval`: number of leading samples to keep from the train / test split, or
  `nothing` to keep the whole split. Values larger than the split are clamped to its size.
- `batchsize`: requested batch size; capped at the number of samples in each split.

# Returns
`(train_loader, test_loader)`:
- Images are reshaped to `28×28×1×N` and labels are one-hot encoded over `0:9`.
- The train loader shuffles and the test loader does not.
- Both use `partial=false`, so every batch has the same shape. The Reactant-compiled model
  in `train` is compiled for one fixed batch shape.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)
    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

# Load the first `n` samples of `partition` (`:train` / `:test`) of `dset`.
# `n === nothing` keeps all samples, and `n` is clamped to the split size.
# Returns `(images, onehot_labels)`.
function _load_split(
    ::Type{dset}, partition::Symbol, n::Union{Nothing,Int}
) where {dset}
    data = dset(partition)
    n_keep = n === nothing ? length(data) : min(n, length(data))
    (; features, targets) = data[1:n_keep]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end
"""
    load_datasets(batchsize=32)

Return `(mnist_loaders, fashion_mnist_loaders)`. Each element is the
`(train_loader, test_loader)` pair produced by `load_dataset`.

If the environment variable `CI` parses as `true`, only the first 1024 training and 32 test
samples of each dataset are used; otherwise the full splits are loaded. `CI` must be unset
or a string accepted by `parse(Bool, _)`.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer that is called with a tuple `(idx, input)`:

1. `weight_generator(idx)` is evaluated and flattened with `vec`.
2. The flat vector is viewed as a `ComponentArray` with the same axes as
   `core_network`'s parameters.
3. `core_network` is run on `input` using those generated parameters.

The layer is built with `@compact(...; dispatch=:HyperNet)`, so methods can be specialised
on `CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the *layout* (axes) of the core network's parameters is kept. The sampled
    # values are discarded because the real weights come from `weight_generator`.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))

    hypernet = @compact(;
        ca_axes, weight_generator, core_network, dispatch=:HyperNet
    ) do (idx, input)
        # Generate the core network's weights from `idx`, then apply it to `input`.
        generated_ps = ComponentArray(vec(weight_generator(idx)), ca_axes)
        @return core_network(input, generated_ps)
    end
    return hypernet
end
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
# Parameter initialization for layers built by `HyperNet`: only the weight generator gets
# parameters. The core network's weights are generated at call time by the weight
# generator, so no parameters are created for it here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Return a `HyperNet` with two parts:
- The core network is a small strided-convolution classifier. It takes single-channel
  images (28×28 in this tutorial) and produces 10 logits.
- The weight generator maps a dataset index (1 or 2, through a 2-entry `Embedding`) to the
  complete parameter vector of the core network.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's final layer emits exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction (in `[0, 1]`) of samples in `dataloader` whose predicted class equals
the target class.

- `model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode.
- Outputs and targets are moved to the CPU and decoded with `onecold` before comparison.
- Returns `NaN` (`0 / 0`) when `dataloader` yields no batches.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_eval = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (inputs, labels) in dataloader
        logits, _ = model((data_idx, inputs), ps, st_eval)
        truth = onecold(to_cpu(labels))
        preds = onecold(to_cpu(logits))
        n_correct += count(truth .== preds)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
Training
julia
# Dataset name for the index used throughout `train`. The order matches `load_datasets`
# (1 = MNIST, 2 = FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy of `model` on `dataloader`, using the parameters/states currently held in
# `train_state`. Returned as a percentage rounded to two decimals.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train()

Train the HyperNet on MNIST (index 1) and FashionMNIST (index 2) for 50 epochs. Each epoch
makes one pass over MNIST and then one over FashionMNIST, and prints the training time and
train/test accuracy after each pass.

Returns `[mnist_test_acc, fashion_mnist_test_acc]`, the final test accuracies in percent.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once with Reactant, using the first MNIST
    # training batch as the example input. It is reused for every evaluation below.
    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final report for both datasets, using the fully trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc

        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end
# Run the full training loop. `test_acc_list` receives the final test accuracies (in
# percent) for MNIST and FashionMNIST, in that order. The log printed by `train()` follows.
test_acc_list = train()
[ 1/ 50] MNIST Time 68.72875s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.12434s Training Accuracy: 32.62% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.08649s Training Accuracy: 36.72% Test Accuracy: 31.25%
[ 2/ 50] FashionMNIST Time 0.09092s Training Accuracy: 45.80% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.08911s Training Accuracy: 40.43% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.08639s Training Accuracy: 58.50% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.08698s Training Accuracy: 51.17% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.08677s Training Accuracy: 63.87% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.08822s Training Accuracy: 55.47% Test Accuracy: 40.62%
[ 5/ 50] FashionMNIST Time 0.08675s Training Accuracy: 70.02% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.08409s Training Accuracy: 63.48% Test Accuracy: 37.50%
[ 6/ 50] FashionMNIST Time 0.09311s Training Accuracy: 75.98% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.08945s Training Accuracy: 70.51% Test Accuracy: 37.50%
[ 7/ 50] FashionMNIST Time 0.09503s Training Accuracy: 75.49% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.08799s Training Accuracy: 76.17% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.08404s Training Accuracy: 81.05% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.08494s Training Accuracy: 80.57% Test Accuracy: 46.88%
[ 9/ 50] FashionMNIST Time 0.08603s Training Accuracy: 84.18% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.08633s Training Accuracy: 83.69% Test Accuracy: 46.88%
[ 10/ 50] FashionMNIST Time 0.08724s Training Accuracy: 87.21% Test Accuracy: 68.75%
[ 11/ 50] MNIST Time 0.08646s Training Accuracy: 87.79% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.09195s Training Accuracy: 90.23% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.08306s Training Accuracy: 89.55% Test Accuracy: 46.88%
[ 12/ 50] FashionMNIST Time 0.08523s Training Accuracy: 91.60% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.08326s Training Accuracy: 94.24% Test Accuracy: 56.25%
[ 13/ 50] FashionMNIST Time 0.08315s Training Accuracy: 94.04% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.08328s Training Accuracy: 96.00% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.09170s Training Accuracy: 95.21% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.08703s Training Accuracy: 96.48% Test Accuracy: 59.38%
[ 15/ 50] FashionMNIST Time 0.09121s Training Accuracy: 94.43% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.08589s Training Accuracy: 98.83% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.08559s Training Accuracy: 96.29% Test Accuracy: 71.88%
[ 17/ 50] MNIST Time 0.08442s Training Accuracy: 99.41% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.08644s Training Accuracy: 97.66% Test Accuracy: 75.00%
[ 18/ 50] MNIST Time 0.08491s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 18/ 50] FashionMNIST Time 0.08559s Training Accuracy: 97.46% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.08419s Training Accuracy: 99.80% Test Accuracy: 53.12%
[ 19/ 50] FashionMNIST Time 0.09212s Training Accuracy: 99.02% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.08440s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 20/ 50] FashionMNIST Time 0.08293s Training Accuracy: 99.12% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.08532s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 21/ 50] FashionMNIST Time 0.08527s Training Accuracy: 98.93% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.08473s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 22/ 50] FashionMNIST Time 0.08822s Training Accuracy: 99.90% Test Accuracy: 68.75%
[ 23/ 50] MNIST Time 0.08538s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.08972s Training Accuracy: 99.80% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.08490s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.08540s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.08551s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.08617s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.08659s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 26/ 50] FashionMNIST Time 0.08677s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] MNIST Time 0.08835s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 27/ 50] FashionMNIST Time 0.09214s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 28/ 50] MNIST Time 0.08716s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 28/ 50] FashionMNIST Time 0.08877s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.08409s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 29/ 50] FashionMNIST Time 0.08587s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] MNIST Time 0.08485s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.08606s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.08584s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.08891s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] MNIST Time 0.08274s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.08565s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] MNIST Time 0.08396s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.08705s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.08377s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.08302s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] MNIST Time 0.08660s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.08991s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.08514s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.08453s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] MNIST Time 0.08323s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.08485s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.08358s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.08349s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] MNIST Time 0.08356s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.08842s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] MNIST Time 0.08444s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.08350s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.08428s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.08493s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.08426s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.08644s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] MNIST Time 0.08478s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.09260s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] MNIST Time 0.08699s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.08912s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.08433s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.08562s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.08557s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.08293s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.08335s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.09154s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.08399s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.08367s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] MNIST Time 0.08380s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.08435s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] MNIST Time 0.08577s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.08270s Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
Appendix
julia
# Print Julia version, platform and environment information.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices is loaded and the backend package
# is loaded and functional. The `@isdefined` checks come first in each condition, so names
# of packages that were never loaded are not evaluated.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_DEBUG = Literate
LD_LIBRARY_PATH =
JULIA_NUM_THREADS = 4
JULIA_CPU_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.