Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
Loading Datasets
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the `:train` and `:test` splits of the dataset type `dset` (e.g. `MNIST`) and return
a `(train_loader, test_loader)` tuple of `DataLoader`s over `(images, onehot_labels)`.

`n_train` / `n_eval` keep only the first `n` observations of the corresponding split;
`nothing` keeps the whole split. Images are reshaped to `28 × 28 × 1 × N` and labels are
one-hot encoded over `0:9`. Each loader's batch size is capped at the number of
observations in its split. The training loader shuffles; the test loader does not.
"""
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)
    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false)
    )
end

# Load split `split` of `dset`, keeping the first `n` observations (all of them when `n`
# is `nothing`), and return `(images, onehot_labels)` shaped for the model.
function _load_split(::Type{dset}, split::Symbol, n::Union{Nothing, Int}) where {dset}
    data = dset(split)
    # Index with an explicit range so we always get a `(features, targets)` pair; plain
    # destructuring of the dataset object would iterate over individual observations.
    imgs, labels = data[1:something(n, length(data))]
    # `:` lets the sample count be inferred, so this also works when `n === nothing`.
    return reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)
end
"""
    load_datasets(batchsize=256)

Return `(mnist_loaders, fashion_mnist_loaders)`, each being the `(train, test)` pair
produced by `load_dataset`. When the `CI` environment variable parses as `true`, only
1024 training and 32 evaluation samples are loaded per dataset; otherwise the full
splits are used.
"""
function load_datasets(batchsize=256)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return map((MNIST, FashionMNIST)) do dataset
        load_dataset(dataset, n_train, n_eval, batchsize)
    end
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer that is called with a tuple `(x, y)`. `weight_generator`
maps `x` to a flat vector holding every parameter of `core_network`. That vector is
viewed with the core network's parameter structure and used to run `core_network` on
`y`. The layer is tagged `dispatch=:HyperNet`, so methods can target
`CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Template parameters of the core network, used only for their ComponentArray axes so
    # the generator's flat output can be re-viewed with the same nested layout.
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the core network's weights from `x`
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on a `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here we override `Lux.initialparameters` only to skip initializing `core_network`. Its weights are generated by `weight_generator` on every call, so only the generator's parameters need to be stored.
# Initial parameters for a HyperNet: only the weight generator owns trainable parameters.
# The core network's weights are produced by the generator at call time, so none are
# created for it here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (weight_generator=generator_ps,)
end
Create and Initialize the HyperNet
"""
    create_model()

Construct the tutorial's HyperNet. The core network is a 784 → 256 → 10 MLP. Its
parameters are generated by a small network that embeds a dataset index (one of 2) and
projects it to `Lux.parameterlength(core)` values.
"""
function create_model()
    # The core network does not have to be an MLP; any Lux layer works.
    core = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core)
    generator = Chain(Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))
    return HyperNet(generator, core)
end
create_model (generic function with 1 method)
Define Utility Functions
# Cross-entropy loss that treats the model output as unnormalized logits
# (`logits=Val(true)`); the core network's final `Dense(256, 10)` applies no activation.
const loss = CrossEntropyLoss(; logits=Val(true))
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose arg-max prediction matches the
one-hot target. `model` is run on `(data_idx, x)` with `st` switched to test mode.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    st_test = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (inputs, targets) in dataloader
        logits, _ = model((data_idx, inputs), ps, st_test)
        truth = onecold(targets)
        n_correct += sum(truth .== onecold(logits))
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
# Dataset name shown in the log for `data_idx` (1 = MNIST, otherwise FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy of `model` using the parameters/states held in `train_state`, expressed as a
# percentage rounded to two decimal places.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train()

Train a single HyperNet on both MNIST (`data_idx = 1`) and FashionMNIST (`data_idx = 2`).
Each epoch trains on MNIST, then FashionMNIST, on the device from `gpu_device()`.
Per-epoch timings and accuracies are printed. Returns the final test accuracies (percent,
two decimals) as `[mnist, fashion_mnist]`.
"""
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Lets train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)

        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)

        data_name = _dataset_name(data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training; `test_acc_list` holds the final test accuracies (%) for MNIST
# and FashionMNIST, in that order.
test_acc_list = train()
[ 1/ 50] MNIST Time 86.41144s Training Accuracy: 54.00% Test Accuracy: 53.12%
[ 1/ 50] FashionMNIST Time 0.02907s Training Accuracy: 36.04% Test Accuracy: 31.25%
[ 2/ 50] MNIST Time 0.02835s Training Accuracy: 64.94% Test Accuracy: 62.50%
[ 2/ 50] FashionMNIST Time 0.02838s Training Accuracy: 51.76% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.05963s Training Accuracy: 73.24% Test Accuracy: 53.12%
[ 3/ 50] FashionMNIST Time 0.02768s Training Accuracy: 59.08% Test Accuracy: 50.00%
[ 4/ 50] MNIST Time 0.02138s Training Accuracy: 75.49% Test Accuracy: 59.38%
[ 4/ 50] FashionMNIST Time 0.02150s Training Accuracy: 68.65% Test Accuracy: 50.00%
[ 5/ 50] MNIST Time 0.02270s Training Accuracy: 81.64% Test Accuracy: 62.50%
[ 5/ 50] FashionMNIST Time 0.09385s Training Accuracy: 67.09% Test Accuracy: 65.62%
[ 6/ 50] MNIST Time 0.02166s Training Accuracy: 85.64% Test Accuracy: 68.75%
[ 6/ 50] FashionMNIST Time 0.02131s Training Accuracy: 71.78% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.02086s Training Accuracy: 87.60% Test Accuracy: 71.88%
[ 7/ 50] FashionMNIST Time 0.02124s Training Accuracy: 76.37% Test Accuracy: 75.00%
[ 8/ 50] MNIST Time 0.02026s Training Accuracy: 89.94% Test Accuracy: 75.00%
[ 8/ 50] FashionMNIST Time 0.02188s Training Accuracy: 80.96% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.02065s Training Accuracy: 91.70% Test Accuracy: 71.88%
[ 9/ 50] FashionMNIST Time 0.02066s Training Accuracy: 78.52% Test Accuracy: 71.88%
[ 10/ 50] MNIST Time 0.02248s Training Accuracy: 94.92% Test Accuracy: 71.88%
[ 10/ 50] FashionMNIST Time 0.02141s Training Accuracy: 80.76% Test Accuracy: 75.00%
[ 11/ 50] MNIST Time 0.02352s Training Accuracy: 95.51% Test Accuracy: 75.00%
[ 11/ 50] FashionMNIST Time 0.03260s Training Accuracy: 81.84% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.02039s Training Accuracy: 95.90% Test Accuracy: 75.00%
[ 12/ 50] FashionMNIST Time 0.02034s Training Accuracy: 82.23% Test Accuracy: 78.12%
[ 13/ 50] MNIST Time 0.02214s Training Accuracy: 97.66% Test Accuracy: 75.00%
[ 13/ 50] FashionMNIST Time 0.02058s Training Accuracy: 82.03% Test Accuracy: 78.12%
[ 14/ 50] MNIST Time 0.02051s Training Accuracy: 98.05% Test Accuracy: 78.12%
[ 14/ 50] FashionMNIST Time 0.02128s Training Accuracy: 83.50% Test Accuracy: 84.38%
[ 15/ 50] MNIST Time 0.02040s Training Accuracy: 98.93% Test Accuracy: 81.25%
[ 15/ 50] FashionMNIST Time 0.02037s Training Accuracy: 85.06% Test Accuracy: 81.25%
[ 16/ 50] MNIST Time 0.02180s Training Accuracy: 99.32% Test Accuracy: 81.25%
[ 16/ 50] FashionMNIST Time 0.02008s Training Accuracy: 85.55% Test Accuracy: 75.00%
[ 17/ 50] MNIST Time 0.02034s Training Accuracy: 99.61% Test Accuracy: 81.25%
[ 17/ 50] FashionMNIST Time 0.02047s Training Accuracy: 85.94% Test Accuracy: 78.12%
[ 18/ 50] MNIST Time 0.02575s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 18/ 50] FashionMNIST Time 0.02010s Training Accuracy: 86.62% Test Accuracy: 78.12%
[ 19/ 50] MNIST Time 0.02034s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 19/ 50] FashionMNIST Time 0.02190s Training Accuracy: 87.60% Test Accuracy: 78.12%
[ 20/ 50] MNIST Time 0.02036s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 20/ 50] FashionMNIST Time 0.02037s Training Accuracy: 88.57% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.02174s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 21/ 50] FashionMNIST Time 0.02068s Training Accuracy: 89.84% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.02085s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 22/ 50] FashionMNIST Time 0.02283s Training Accuracy: 89.75% Test Accuracy: 78.12%
[ 23/ 50] MNIST Time 0.02082s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 23/ 50] FashionMNIST Time 0.02086s Training Accuracy: 90.53% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.02338s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 24/ 50] FashionMNIST Time 0.02136s Training Accuracy: 90.92% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.02099s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 25/ 50] FashionMNIST Time 0.02270s Training Accuracy: 90.62% Test Accuracy: 78.12%
[ 26/ 50] MNIST Time 0.02105s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 26/ 50] FashionMNIST Time 0.02280s Training Accuracy: 90.14% Test Accuracy: 78.12%
[ 27/ 50] MNIST Time 0.02436s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 27/ 50] FashionMNIST Time 0.02136s Training Accuracy: 91.02% Test Accuracy: 78.12%
[ 28/ 50] MNIST Time 0.02088s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 28/ 50] FashionMNIST Time 0.02328s Training Accuracy: 90.62% Test Accuracy: 78.12%
[ 29/ 50] MNIST Time 0.02103s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 29/ 50] FashionMNIST Time 0.02093s Training Accuracy: 91.50% Test Accuracy: 78.12%
[ 30/ 50] MNIST Time 0.02463s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 30/ 50] FashionMNIST Time 0.02127s Training Accuracy: 91.60% Test Accuracy: 78.12%
[ 31/ 50] MNIST Time 0.02127s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 31/ 50] FashionMNIST Time 0.02264s Training Accuracy: 92.38% Test Accuracy: 75.00%
[ 32/ 50] MNIST Time 0.02173s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 32/ 50] FashionMNIST Time 0.02072s Training Accuracy: 92.58% Test Accuracy: 78.12%
[ 33/ 50] MNIST Time 0.02226s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 33/ 50] FashionMNIST Time 0.02081s Training Accuracy: 93.07% Test Accuracy: 78.12%
[ 34/ 50] MNIST Time 0.02189s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 34/ 50] FashionMNIST Time 0.02241s Training Accuracy: 93.36% Test Accuracy: 78.12%
[ 35/ 50] MNIST Time 0.02044s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 35/ 50] FashionMNIST Time 0.02066s Training Accuracy: 93.36% Test Accuracy: 78.12%
[ 36/ 50] MNIST Time 0.02227s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 36/ 50] FashionMNIST Time 0.02411s Training Accuracy: 93.95% Test Accuracy: 75.00%
[ 37/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 37/ 50] FashionMNIST Time 0.02723s Training Accuracy: 94.43% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.02085s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 38/ 50] FashionMNIST Time 0.02177s Training Accuracy: 94.34% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.02244s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 39/ 50] FashionMNIST Time 0.02153s Training Accuracy: 94.73% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.02167s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 40/ 50] FashionMNIST Time 0.02313s Training Accuracy: 95.02% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.02104s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 41/ 50] FashionMNIST Time 0.02070s Training Accuracy: 95.21% Test Accuracy: 78.12%
[ 42/ 50] MNIST Time 0.02538s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 42/ 50] FashionMNIST Time 0.02063s Training Accuracy: 95.41% Test Accuracy: 78.12%
[ 43/ 50] MNIST Time 0.02206s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 43/ 50] FashionMNIST Time 0.02219s Training Accuracy: 94.73% Test Accuracy: 78.12%
[ 44/ 50] MNIST Time 0.02107s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 44/ 50] FashionMNIST Time 0.02247s Training Accuracy: 95.31% Test Accuracy: 78.12%
[ 45/ 50] MNIST Time 0.02237s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 45/ 50] FashionMNIST Time 0.02071s Training Accuracy: 95.70% Test Accuracy: 78.12%
[ 46/ 50] MNIST Time 0.02070s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 46/ 50] FashionMNIST Time 0.02237s Training Accuracy: 95.12% Test Accuracy: 81.25%
[ 47/ 50] MNIST Time 0.02072s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 47/ 50] FashionMNIST Time 0.02157s Training Accuracy: 95.41% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.02487s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 48/ 50] FashionMNIST Time 0.02112s Training Accuracy: 95.61% Test Accuracy: 78.12%
[ 49/ 50] MNIST Time 0.02190s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 49/ 50] FashionMNIST Time 0.02341s Training Accuracy: 96.39% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.02079s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 50/ 50] FashionMNIST Time 0.02124s Training Accuracy: 96.68% Test Accuracy: 78.12%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 78.12%
[FINAL] FashionMNIST Training Accuracy: 96.68% Test Accuracy: 78.12%
# Report the Julia environment and, when a functional GPU backend is loaded, the version
# information of its driver and libraries.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
- Julia: 1.11.2
- LLVM: 16.0.6
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.576 GiB / 4.750 GiB available)
