Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Zygote

CUDA.allowscalar(false)
Precompiling LuxComponentArraysExt...
1622.1 ms ✓ Lux → LuxComponentArraysExt
1 dependency successfully precompiled in 2 seconds. 113 already precompiled.
Precompiling LuxLibCUDAExt...
5438.1 ms ✓ LuxLib → LuxLibCUDAExt
1 dependency successfully precompiled in 6 seconds. 171 already precompiled.
Precompiling LuxLibcuDNNExt...
5541.7 ms ✓ LuxLib → LuxLibcuDNNExt
1 dependency successfully precompiled in 6 seconds. 176 already precompiled.
Precompiling LuxMLUtilsExt...
2282.3 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 167 already precompiled.
Precompiling LuxZygoteExt...
2900.0 ms ✓ Lux → LuxZygoteExt
1 dependency successfully precompiled in 3 seconds. 166 already precompiled.
Loading Datasets
julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train); batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader(
            (x_test, y_test); batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end

function load_datasets(batchsize=256)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
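As a quick sanity check (not part of the original tutorial, and assuming the functions above have been evaluated), you can peek at the first MNIST batch returned by these loaders; the exact number of samples depends on whether the CI subsampling branch is taken.
julia
dataloaders = load_datasets()
mnist_train, mnist_test = dataloaders[1]   # (train_loader, test_loader) for MNIST
x, y = first(mnist_train)
size(x)  # (28, 28, 1, batchsize) image batch
size(y)  # (10, batchsize) one-hot labels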
Implement a HyperNet Layer
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes

    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
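The key trick here is the ComponentArray axes round-trip: the weight generator emits one flat vector, and ComponentArray(vec, ca_axes) reinterprets it with the core network's parameter structure. A minimal illustrative sketch (not part of the tutorial) on a small Dense layer:
julia
layer = Dense(2, 3)
ps = Lux.initialparameters(Random.default_rng(), layer) |> ComponentArray
ca_axes = getaxes(ps)
flat = vec(ps)                         # flat vector, length == Lux.parameterlength(layer)
ps_rebuilt = ComponentArray(flat, ca_axes)
ps_rebuilt.weight                      # recovers the 3×2 weight matrix view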
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend doing this unless you are familiar with the internals. In this case, we simply override parameter initialization to skip the core_network parameters, since those are generated by the weight_generator at runtime.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
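With this override in place, Lux.setup should only allocate parameters for the weight_generator. A quick check along these lines (illustrative; assumes the HyperNet definition above has been evaluated):
julia
hn = HyperNet(
    Chain(Embedding(2 => 4), Dense(4, Lux.parameterlength(Dense(2, 2)))),
    Dense(2, 2)
)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)  # (:weight_generator,) with no :core_network entry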
Create and Initialize the HyperNet
julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
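The last Dense layer of the weight generator has to emit exactly Lux.parameterlength(core_network) values, i.e. one flat weight vector per dataset index. A rough size check (illustrative only, reusing the layer definitions from create_model):
julia
core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
Lux.parameterlength(core_network)  # 784*256 + 256 + 256*10 + 10 == 203_530

wg = Chain(Embedding(2 => 32), Dense(32, 64, relu),
    Dense(64, Lux.parameterlength(core_network)))
ps, st = Lux.setup(Xoshiro(0), wg)
size(first(wg([1], ps, st)))  # (203530, 1): one flat weight vector for dataset index 1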
Define Utility Functions
julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
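Note that the model is always called with a tuple (data_idx, x): the integer index selects the dataset embedding fed to the weight generator, while x is the image batch consumed by the generated core network. A minimal CPU forward-pass sketch (illustrative; shapes assume the model from create_model):
julia
model = create_model()
ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Float32, 28, 28, 1, 4)     # dummy batch of 4 images
logits, _ = model((1, x), ps, st)    # data_idx = 1 selects the MNIST weights
size(logits)                         # (10, 4)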
Training
julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
[ 1/ 50] MNIST Time 94.45098s Training Accuracy: 59.18% Test Accuracy: 46.88%
[ 1/ 50] FashionMNIST Time 0.03458s Training Accuracy: 50.98% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.03528s Training Accuracy: 67.38% Test Accuracy: 65.62%
[ 2/ 50] FashionMNIST Time 0.03218s Training Accuracy: 55.08% Test Accuracy: 37.50%
[ 3/ 50] MNIST Time 0.03225s Training Accuracy: 73.24% Test Accuracy: 62.50%
[ 3/ 50] FashionMNIST Time 0.03322s Training Accuracy: 67.77% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.03267s Training Accuracy: 79.20% Test Accuracy: 65.62%
[ 4/ 50] FashionMNIST Time 0.03358s Training Accuracy: 63.77% Test Accuracy: 50.00%
[ 5/ 50] MNIST Time 0.06424s Training Accuracy: 84.47% Test Accuracy: 62.50%
[ 5/ 50] FashionMNIST Time 0.03170s Training Accuracy: 71.78% Test Accuracy: 50.00%
[ 6/ 50] MNIST Time 0.02446s Training Accuracy: 88.09% Test Accuracy: 65.62%
[ 6/ 50] FashionMNIST Time 0.02443s Training Accuracy: 76.76% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.02429s Training Accuracy: 89.16% Test Accuracy: 75.00%
[ 7/ 50] FashionMNIST Time 0.02518s Training Accuracy: 78.42% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.02348s Training Accuracy: 91.99% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.03206s Training Accuracy: 79.79% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.02287s Training Accuracy: 95.02% Test Accuracy: 81.25%
[ 9/ 50] FashionMNIST Time 0.02440s Training Accuracy: 80.18% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.02371s Training Accuracy: 96.48% Test Accuracy: 81.25%
[ 10/ 50] FashionMNIST Time 0.02476s Training Accuracy: 81.05% Test Accuracy: 68.75%
[ 11/ 50] MNIST Time 0.02461s Training Accuracy: 97.85% Test Accuracy: 81.25%
[ 11/ 50] FashionMNIST Time 0.02367s Training Accuracy: 83.20% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.02466s Training Accuracy: 98.73% Test Accuracy: 81.25%
[ 12/ 50] FashionMNIST Time 0.02490s Training Accuracy: 84.18% Test Accuracy: 62.50%
[ 13/ 50] MNIST Time 0.02354s Training Accuracy: 98.54% Test Accuracy: 84.38%
[ 13/ 50] FashionMNIST Time 0.02442s Training Accuracy: 83.69% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.02491s Training Accuracy: 98.54% Test Accuracy: 81.25%
[ 14/ 50] FashionMNIST Time 0.02635s Training Accuracy: 85.25% Test Accuracy: 59.38%
[ 15/ 50] MNIST Time 0.02493s Training Accuracy: 99.41% Test Accuracy: 84.38%
[ 15/ 50] FashionMNIST Time 0.02355s Training Accuracy: 81.05% Test Accuracy: 65.62%
[ 16/ 50] MNIST Time 0.02565s Training Accuracy: 99.71% Test Accuracy: 84.38%
[ 16/ 50] FashionMNIST Time 0.02708s Training Accuracy: 82.71% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02731s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 17/ 50] FashionMNIST Time 0.02515s Training Accuracy: 83.89% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.02647s Training Accuracy: 99.71% Test Accuracy: 84.38%
[ 18/ 50] FashionMNIST Time 0.02631s Training Accuracy: 83.11% Test Accuracy: 62.50%
[ 19/ 50] MNIST Time 0.02501s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 19/ 50] FashionMNIST Time 0.02496s Training Accuracy: 86.13% Test Accuracy: 62.50%
[ 20/ 50] MNIST Time 0.02613s Training Accuracy: 99.90% Test Accuracy: 84.38%
[ 20/ 50] FashionMNIST Time 0.02391s Training Accuracy: 84.57% Test Accuracy: 59.38%
[ 21/ 50] MNIST Time 0.02554s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 21/ 50] FashionMNIST Time 0.02692s Training Accuracy: 85.74% Test Accuracy: 62.50%
[ 22/ 50] MNIST Time 0.02382s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 22/ 50] FashionMNIST Time 0.02456s Training Accuracy: 89.26% Test Accuracy: 59.38%
[ 23/ 50] MNIST Time 0.02372s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 23/ 50] FashionMNIST Time 0.02369s Training Accuracy: 89.26% Test Accuracy: 59.38%
[ 24/ 50] MNIST Time 0.02994s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 24/ 50] FashionMNIST Time 0.02286s Training Accuracy: 88.96% Test Accuracy: 62.50%
[ 25/ 50] MNIST Time 0.02266s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 25/ 50] FashionMNIST Time 0.02504s Training Accuracy: 91.60% Test Accuracy: 62.50%
[ 26/ 50] MNIST Time 0.02349s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 26/ 50] FashionMNIST Time 0.02268s Training Accuracy: 91.99% Test Accuracy: 65.62%
[ 27/ 50] MNIST Time 0.02337s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 27/ 50] FashionMNIST Time 0.02308s Training Accuracy: 92.97% Test Accuracy: 62.50%
[ 28/ 50] MNIST Time 0.03522s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 28/ 50] FashionMNIST Time 0.02333s Training Accuracy: 93.07% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.02246s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 29/ 50] FashionMNIST Time 0.02220s Training Accuracy: 93.26% Test Accuracy: 62.50%
[ 30/ 50] MNIST Time 0.02752s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 30/ 50] FashionMNIST Time 0.02433s Training Accuracy: 93.65% Test Accuracy: 62.50%
[ 31/ 50] MNIST Time 0.02347s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 31/ 50] FashionMNIST Time 0.02389s Training Accuracy: 93.65% Test Accuracy: 62.50%
[ 32/ 50] MNIST Time 0.02469s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 32/ 50] FashionMNIST Time 0.03397s Training Accuracy: 93.95% Test Accuracy: 62.50%
[ 33/ 50] MNIST Time 0.02350s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 33/ 50] FashionMNIST Time 0.02493s Training Accuracy: 94.63% Test Accuracy: 62.50%
[ 34/ 50] MNIST Time 0.02372s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 34/ 50] FashionMNIST Time 0.02424s Training Accuracy: 94.73% Test Accuracy: 62.50%
[ 35/ 50] MNIST Time 0.02444s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 35/ 50] FashionMNIST Time 0.02449s Training Accuracy: 94.73% Test Accuracy: 59.38%
[ 36/ 50] MNIST Time 0.02493s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 36/ 50] FashionMNIST Time 0.02314s Training Accuracy: 95.02% Test Accuracy: 62.50%
[ 37/ 50] MNIST Time 0.02339s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 37/ 50] FashionMNIST Time 0.02375s Training Accuracy: 95.12% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.02321s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 38/ 50] FashionMNIST Time 0.02251s Training Accuracy: 95.90% Test Accuracy: 62.50%
[ 39/ 50] MNIST Time 0.02317s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 39/ 50] FashionMNIST Time 0.02318s Training Accuracy: 95.80% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.02267s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 40/ 50] FashionMNIST Time 0.02352s Training Accuracy: 96.39% Test Accuracy: 65.62%
[ 41/ 50] MNIST Time 0.02336s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 41/ 50] FashionMNIST Time 0.02320s Training Accuracy: 95.80% Test Accuracy: 65.62%
[ 42/ 50] MNIST Time 0.02326s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 42/ 50] FashionMNIST Time 0.02928s Training Accuracy: 95.80% Test Accuracy: 65.62%
[ 43/ 50] MNIST Time 0.02673s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 43/ 50] FashionMNIST Time 0.02489s Training Accuracy: 95.61% Test Accuracy: 65.62%
[ 44/ 50] MNIST Time 0.02546s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 44/ 50] FashionMNIST Time 0.02335s Training Accuracy: 96.58% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.03225s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 45/ 50] FashionMNIST Time 0.02560s Training Accuracy: 96.29% Test Accuracy: 65.62%
[ 46/ 50] MNIST Time 0.02537s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 46/ 50] FashionMNIST Time 0.02629s Training Accuracy: 96.39% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.02619s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 47/ 50] FashionMNIST Time 0.02527s Training Accuracy: 96.78% Test Accuracy: 65.62%
[ 48/ 50] MNIST Time 0.03111s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 48/ 50] FashionMNIST Time 0.02471s Training Accuracy: 96.88% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.02520s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 49/ 50] FashionMNIST Time 0.02442s Training Accuracy: 97.07% Test Accuracy: 65.62%
[ 50/ 50] MNIST Time 0.02592s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 50/ 50] FashionMNIST Time 0.02452s Training Accuracy: 97.46% Test Accuracy: 65.62%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
[FINAL] FashionMNIST Training Accuracy: 97.46% Test Accuracy: 65.62%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.6.1
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.3
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.982 GiB / 4.750 GiB available)
This page was generated using Literate.jl.