Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
CUDA.allowscalar(false)
Loading Datasets
julia
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    imgs, labels = dset(:train)[1:n_train]
    x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)
    imgs, labels = dset(:test)[1:n_eval]
    x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)
    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false)
    )
end
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
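The layer defined in this section takes the flat output of one network and reinterprets it as the structured parameters of another. As a rough illustration of that step only, here is a plain-Julia sketch that does not use Lux or ComponentArrays. The names `unflatten` and `shapes` are hypothetical and exist only for this example; the tutorial itself relies on `ComponentArray` with stored axes.

```julia
# Plain-Julia sketch: reinterpret a flat vector as named, shaped parameters.
# `unflatten` and `shapes` are hypothetical names, not part of Lux.
function unflatten(flat::AbstractVector, shapes::NamedTuple)
    total = sum(prod, values(shapes))
    length(flat) == total ||
        throw(DimensionMismatch("expected $total values, got $(length(flat))"))
    arrays = Any[]
    offset = 0
    for shape in values(shapes)
        n = prod(shape)
        # Take the next `n` entries without copying and give them their shape
        push!(arrays, reshape(view(flat, (offset + 1):(offset + n)), shape))
        offset += n
    end
    return NamedTuple{keys(shapes)}(Tuple(arrays))
end

shapes = (weight=(3, 2), bias=(3,))
flat = collect(1.0:9.0)   # stand-in for the weight generator's output
ps = unflatten(flat, shapes)

# Evaluate a tiny dense layer with the "generated" parameters
x = [1.0, 1.0]
y = ps.weight * x .+ ps.bias
println(y)
```

Because the pieces are views into `flat`, nothing is copied; the same vector is simply read with a different layout.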
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here, we overload Lux.initialparameters so that it skips initializing the core_network parameters, since those are produced by the weight generator.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
Create and Initialize the HyperNet
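The layer sizes used in this section fix how many values the weight generator must emit and how many parameters are actually trained. The arithmetic can be checked with a short plain-Julia sketch. It assumes each `Dense` layer carries a bias vector (the usual default) and that `FlattenLayer` has no parameters. The helpers `dense_params` and `embedding_params` are hypothetical names for this example.

```julia
# Parameter count of a dense layer with bias: weight matrix plus bias vector.
# (Assumption: bias is enabled, as is the usual default.)
dense_params(n_in, n_out) = n_in * n_out + n_out
# An embedding table stores one vector of length `dim` per vocabulary entry.
embedding_params(vocab, dim) = vocab * dim

# Core network: Dense(784, 256) followed by Dense(256, 10)
core_total = dense_params(784, 256) + dense_params(256, 10)

# Weight generator: Embedding(2 => 32), Dense(32, 64), Dense(64, core_total)
generator_total = embedding_params(2, 32) + dense_params(32, 64) +
                  dense_params(64, core_total)

println("core network parameters (generated):   ", core_total)
println("weight generator parameters (trained): ", generator_total)
```

Under these assumptions the generator's final layer dominates the count, because its output width equals the full parameter length of the core network.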
julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )
    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
Define Utility Functions
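The accuracy helper in this section compares the most likely class per sample with the one-hot target. As a minimal sketch of that comparison in plain Julia (no Lux or OneHotArrays), the hypothetical helpers `predicted_index` and `batch_accuracy` below take the arg-max of each column, which is what `onecold` does for the default label set.

```julia
# Plain-Julia sketch of the accuracy computation.
# `predicted_index` and `batch_accuracy` are hypothetical helper names.
predicted_index(scores::AbstractMatrix) = [argmax(col) for col in eachcol(scores)]

function batch_accuracy(logits::AbstractMatrix, onehot::AbstractMatrix)
    # Fraction of columns whose arg-max matches the position of the one-hot entry
    return count(predicted_index(logits) .== predicted_index(onehot)) / size(logits, 2)
end

# Three samples (columns), three classes (rows)
logits = [2.0 0.1 0.3;
          0.5 1.5 0.2;
          0.1 0.2 0.1]
onehot = Bool[1 0 0;
              0 1 0;
              0 0 1]

acc = batch_accuracy(logits, onehot)
println(acc)
```

Since only the arg-max matters, the logits do not need to be passed through a softmax before computing accuracy.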
julia
const loss = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
Training
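With the defaults of `load_datasets` (1024 training samples, 32 evaluation samples, batch size 256), the training loop takes only a few optimizer steps per pass, and test accuracy can move only in coarse increments. The short calculation below is a sketch of that arithmetic; the variable names are chosen just for this example.

```julia
n_train, n_eval, batchsize = 1024, 32, 256   # defaults of load_datasets
nepochs, ndatasets = 50, 2

# Batches per dataset per epoch (the last batch may be partial in general)
steps_per_pass = cld(n_train, min(batchsize, n_train))
total_steps = steps_per_pass * nepochs * ndatasets

# Each evaluation sample is worth this many percentage points of test accuracy
accuracy_step = 100 / n_eval

println("optimizer steps per dataset per epoch: ", steps_per_pass)
println("total optimizer steps: ", total_steps)
println("test accuracy resolution: ", accuracy_step, " points")
```

This is why the reported test accuracies fall on multiples of 3.125 points: 81.25% corresponds to 26 of 32 samples correct. It also means single-sample changes look large in the log.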
julia
function train()
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    # Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end
    println()
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 83.64228s Training Accuracy: 55.96% Test Accuracy: 46.88%
[ 1/ 50] FashionMNIST Time 0.04220s Training Accuracy: 39.36% Test Accuracy: 31.25%
[ 2/ 50] MNIST Time 0.02845s Training Accuracy: 68.75% Test Accuracy: 56.25%
[ 2/ 50] FashionMNIST Time 0.02756s Training Accuracy: 63.18% Test Accuracy: 59.38%
[ 3/ 50] MNIST Time 0.02854s Training Accuracy: 78.42% Test Accuracy: 62.50%
[ 3/ 50] FashionMNIST Time 0.02976s Training Accuracy: 68.55% Test Accuracy: 65.62%
[ 4/ 50] MNIST Time 0.02679s Training Accuracy: 80.18% Test Accuracy: 65.62%
[ 4/ 50] FashionMNIST Time 0.02406s Training Accuracy: 61.43% Test Accuracy: 65.62%
[ 5/ 50] MNIST Time 0.02184s Training Accuracy: 82.42% Test Accuracy: 65.62%
[ 5/ 50] FashionMNIST Time 0.02261s Training Accuracy: 72.56% Test Accuracy: 75.00%
[ 6/ 50] MNIST Time 0.02117s Training Accuracy: 86.91% Test Accuracy: 75.00%
[ 6/ 50] FashionMNIST Time 0.02097s Training Accuracy: 75.68% Test Accuracy: 65.62%
[ 7/ 50] MNIST Time 0.12138s Training Accuracy: 89.36% Test Accuracy: 78.12%
[ 7/ 50] FashionMNIST Time 0.02130s Training Accuracy: 73.73% Test Accuracy: 68.75%
[ 8/ 50] MNIST Time 0.02080s Training Accuracy: 91.21% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.02233s Training Accuracy: 75.59% Test Accuracy: 71.88%
[ 9/ 50] MNIST Time 0.02139s Training Accuracy: 93.36% Test Accuracy: 78.12%
[ 9/ 50] FashionMNIST Time 0.02181s Training Accuracy: 78.03% Test Accuracy: 75.00%
[ 10/ 50] MNIST Time 0.02317s Training Accuracy: 94.53% Test Accuracy: 78.12%
[ 10/ 50] FashionMNIST Time 0.02376s Training Accuracy: 77.15% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.02234s Training Accuracy: 97.17% Test Accuracy: 84.38%
[ 11/ 50] FashionMNIST Time 0.02063s Training Accuracy: 80.18% Test Accuracy: 71.88%
[ 12/ 50] MNIST Time 0.02080s Training Accuracy: 98.14% Test Accuracy: 87.50%
[ 12/ 50] FashionMNIST Time 0.03034s Training Accuracy: 81.25% Test Accuracy: 81.25%
[ 13/ 50] MNIST Time 0.02292s Training Accuracy: 98.83% Test Accuracy: 84.38%
[ 13/ 50] FashionMNIST Time 0.02190s Training Accuracy: 81.84% Test Accuracy: 78.12%
[ 14/ 50] MNIST Time 0.02274s Training Accuracy: 99.02% Test Accuracy: 81.25%
[ 14/ 50] FashionMNIST Time 0.02165s Training Accuracy: 83.20% Test Accuracy: 81.25%
[ 15/ 50] MNIST Time 0.02197s Training Accuracy: 99.41% Test Accuracy: 81.25%
[ 15/ 50] FashionMNIST Time 0.02208s Training Accuracy: 84.18% Test Accuracy: 78.12%
[ 16/ 50] MNIST Time 0.02363s Training Accuracy: 99.71% Test Accuracy: 81.25%
[ 16/ 50] FashionMNIST Time 0.02594s Training Accuracy: 84.77% Test Accuracy: 75.00%
[ 17/ 50] MNIST Time 0.02194s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 17/ 50] FashionMNIST Time 0.02081s Training Accuracy: 85.45% Test Accuracy: 78.12%
[ 18/ 50] MNIST Time 0.02280s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 18/ 50] FashionMNIST Time 0.02157s Training Accuracy: 86.43% Test Accuracy: 78.12%
[ 19/ 50] MNIST Time 0.02196s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 19/ 50] FashionMNIST Time 0.02154s Training Accuracy: 87.40% Test Accuracy: 78.12%
[ 20/ 50] MNIST Time 0.02838s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 20/ 50] FashionMNIST Time 0.02275s Training Accuracy: 87.60% Test Accuracy: 78.12%
[ 21/ 50] MNIST Time 0.02219s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 21/ 50] FashionMNIST Time 0.02292s Training Accuracy: 88.48% Test Accuracy: 81.25%
[ 22/ 50] MNIST Time 0.02135s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 22/ 50] FashionMNIST Time 0.02119s Training Accuracy: 88.77% Test Accuracy: 81.25%
[ 23/ 50] MNIST Time 0.02117s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 23/ 50] FashionMNIST Time 0.02120s Training Accuracy: 89.36% Test Accuracy: 78.12%
[ 24/ 50] MNIST Time 0.02519s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 24/ 50] FashionMNIST Time 0.02117s Training Accuracy: 89.94% Test Accuracy: 81.25%
[ 25/ 50] MNIST Time 0.02107s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 25/ 50] FashionMNIST Time 0.02220s Training Accuracy: 90.14% Test Accuracy: 81.25%
[ 26/ 50] MNIST Time 0.02096s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 26/ 50] FashionMNIST Time 0.02091s Training Accuracy: 90.72% Test Accuracy: 81.25%
[ 27/ 50] MNIST Time 0.02084s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 27/ 50] FashionMNIST Time 0.02111s Training Accuracy: 91.02% Test Accuracy: 81.25%
[ 28/ 50] MNIST Time 0.02565s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 28/ 50] FashionMNIST Time 0.02350s Training Accuracy: 90.92% Test Accuracy: 81.25%
[ 29/ 50] MNIST Time 0.02434s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 29/ 50] FashionMNIST Time 0.02745s Training Accuracy: 91.11% Test Accuracy: 81.25%
[ 30/ 50] MNIST Time 0.02276s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 30/ 50] FashionMNIST Time 0.02369s Training Accuracy: 92.09% Test Accuracy: 78.12%
[ 31/ 50] MNIST Time 0.02933s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 31/ 50] FashionMNIST Time 0.01989s Training Accuracy: 91.02% Test Accuracy: 81.25%
[ 32/ 50] MNIST Time 0.02080s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 32/ 50] FashionMNIST Time 0.02125s Training Accuracy: 91.41% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.02128s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 33/ 50] FashionMNIST Time 0.02177s Training Accuracy: 90.33% Test Accuracy: 78.12%
[ 34/ 50] MNIST Time 0.02062s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 34/ 50] FashionMNIST Time 0.02130s Training Accuracy: 91.50% Test Accuracy: 75.00%
[ 35/ 50] MNIST Time 0.02073s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 35/ 50] FashionMNIST Time 0.02038s Training Accuracy: 91.89% Test Accuracy: 78.12%
[ 36/ 50] MNIST Time 0.02041s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 36/ 50] FashionMNIST Time 0.02201s Training Accuracy: 92.77% Test Accuracy: 78.12%
[ 37/ 50] MNIST Time 0.02278s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 37/ 50] FashionMNIST Time 0.02282s Training Accuracy: 92.68% Test Accuracy: 81.25%
[ 38/ 50] MNIST Time 0.02760s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 38/ 50] FashionMNIST Time 0.02224s Training Accuracy: 93.26% Test Accuracy: 78.12%
[ 39/ 50] MNIST Time 0.02176s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 39/ 50] FashionMNIST Time 0.02131s Training Accuracy: 93.46% Test Accuracy: 81.25%
[ 40/ 50] MNIST Time 0.02280s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 40/ 50] FashionMNIST Time 0.02084s Training Accuracy: 93.85% Test Accuracy: 78.12%
[ 41/ 50] MNIST Time 0.02146s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 41/ 50] FashionMNIST Time 0.02138s Training Accuracy: 94.34% Test Accuracy: 81.25%
[ 42/ 50] MNIST Time 0.02611s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 42/ 50] FashionMNIST Time 0.02165s Training Accuracy: 94.04% Test Accuracy: 78.12%
[ 43/ 50] MNIST Time 0.02243s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 43/ 50] FashionMNIST Time 0.02275s Training Accuracy: 94.82% Test Accuracy: 78.12%
[ 44/ 50] MNIST Time 0.02336s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 44/ 50] FashionMNIST Time 0.02289s Training Accuracy: 94.73% Test Accuracy: 78.12%
[ 45/ 50] MNIST Time 0.02144s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 45/ 50] FashionMNIST Time 0.02455s Training Accuracy: 94.92% Test Accuracy: 81.25%
[ 46/ 50] MNIST Time 0.02101s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 46/ 50] FashionMNIST Time 0.02202s Training Accuracy: 94.73% Test Accuracy: 81.25%
[ 47/ 50] MNIST Time 0.02197s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 47/ 50] FashionMNIST Time 0.02190s Training Accuracy: 94.63% Test Accuracy: 78.12%
[ 48/ 50] MNIST Time 0.02162s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 48/ 50] FashionMNIST Time 0.02333s Training Accuracy: 94.34% Test Accuracy: 81.25%
[ 49/ 50] MNIST Time 0.02363s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 49/ 50] FashionMNIST Time 0.02219s Training Accuracy: 94.53% Test Accuracy: 78.12%
[ 50/ 50] MNIST Time 0.02235s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 50/ 50] FashionMNIST Time 0.02864s Training Accuracy: 95.31% Test Accuracy: 78.12%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 81.25%
[FINAL] FashionMNIST Training Accuracy: 95.31% Test Accuracy: 78.12%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.357 GiB / 4.750 GiB available)
This page was generated using Literate.jl.