Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
# Packages used throughout the tutorial: Lux (model/training API), MLDatasets + MLUtils
# (MNIST/FashionMNIST and DataLoader), OneHotArrays (onehotbatch/onecold), Optimisers (Adam),
# Zygote (AD backend selected via ADTypes' AutoZygote), plus Printf/Random stdlibs.
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
# Turn off scalar indexing of CUDA arrays (CUDA.allowscalar) so accidental per-element GPU access fails loudly.
CUDA.allowscalar(false)Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the first `n_train` training samples and the first `n_eval` test samples of the
dataset type `dset` (e.g. `MNIST`). Images are reshaped to `28×28×1×n` arrays and labels
are one-hot encoded over `0:9`. Returns a `(train_loader, test_loader)` pair of
`DataLoader`s; only the training loader shuffles, and each batch size is capped at the
size of its split.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Fetch the first `n` samples of one split as (reshaped images, one-hot labels).
    function take_split(part::Symbol, n::Int)
        features, targets = dset(part)[1:n]
        return reshape(features, 28, 28, 1, n), onehotbatch(targets, 0:9)
    end

    train_data = take_split(:train, n_train)
    test_data = take_split(:test, n_eval)

    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)
    test_loader = DataLoader(test_data; batchsize=min(batchsize, n_eval), shuffle=false)
    return train_loader, test_loader
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return `((mnist_train, mnist_test), (fashion_train, fashion_test))`: one
`load_dataset` loader pair per dataset, in the order MNIST, FashionMNIST.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return map(dset -> load_dataset(dset, n_train, n_eval, batchsize), (MNIST, FashionMNIST))
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer that is called on a tuple `(x, y)`. `weight_generator(x)`
produces a flat vector, which is reshaped into `core_network`'s parameter layout and then
used as the parameters for evaluating `core_network` on `y`.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Only the parameter *layout* of the core network is needed; the freshly sampled
    # values are thrown away because real weights come from `weight_generator`.
    core_template = ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    ca_axes = getaxes(core_template)
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (gen_input, core_input)
        # Generate the core network's weights from the generator input
        core_ps = ComponentArray(vec(weight_generator(gen_input)), ca_axes)
        @return core_network(core_input, core_ps)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here we define it only to skip initializing the core_network parameters, since those weights are produced by the weight generator at call time.
julia
# Only `weight_generator` receives trainable parameters: the core network's weights are
# generated on every forward pass, so they are deliberately left out of the parameter tree.
Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet}) =
    (weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
Create and Initialize the HyperNet
julia
"""
    create_model()

Return a `HyperNet` whose core network is a 784 → 256 → 10 MLP over flattened 28×28
images. Its weight generator maps a dataset index (via `Embedding(2 => 32)`) to every
parameter of that core network.
"""
function create_model()
    # The core doesn't need to be an MLP; any Lux layer works here
    core = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core)
    generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))
    return HyperNet(generator, core)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class (`onecold` of the
model output) equals the one-hot target's class. Each batch `x` is fed to the model as
`(data_idx, x)`, and the model runs with `Lux.testmode(st)`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for batch in dataloader
        x, y = batch
        labels = onecold(y)
        logits = first(model((data_idx, x), ps, eval_st))
        n_correct += sum(onecold(logits) .== labels)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
# Evaluate `model` (with the current `train_state` parameters/states) on one dataset's
# train and test loaders; return both accuracies in percent, rounded to 2 decimals.
function _train_test_accuracy(
        model, train_state, train_dataloader, test_dataloader, data_idx)
    percent(dataloader) = round(
        accuracy(model, train_state.parameters, train_state.states, dataloader,
            data_idx) * 100;
        digits=2)
    return percent(train_dataloader), percent(test_dataloader)
end

"""
    train()

Train one HyperNet jointly on MNIST (`data_idx = 1`) and FashionMNIST (`data_idx = 2`).
Each epoch makes one pass over each dataset in turn, and per-epoch and final accuracies
are printed.

Returns a 2-element `Vector{Float64}` with the final test accuracy (percent) for MNIST
and FashionMNIST, in that order.
"""
function train()
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    # Wrap every (train, test) loader pair for the device once, up front, rather than
    # re-wrapping the same loaders on every epoch.
    device_dataloaders = map(pair -> pair .|> dev, dataloaders)
    dataset_names = ("MNIST", "FashionMNIST")
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = device_dataloaders[data_idx]
        stime = time()
        for (x, y) in train_dataloader
            # `data_idx` is the weight generator's input, so each dataset gets its own
            # generated core-network weights.
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime
        train_acc, test_acc = _train_test_accuracy(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs dataset_names[data_idx] ttime train_acc test_acc
    end

    println()
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = device_dataloaders[data_idx]
        train_acc, test_acc = _train_test_accuracy(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" dataset_names[data_idx] train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop; `test_acc_list` receives the final MNIST and FashionMNIST test accuracies (percent).
test_acc_list = train()[ 1/ 50] MNIST Time 71.03980s Training Accuracy: 59.77% Test Accuracy: 53.12%
[ 1/ 50] FashionMNIST Time 0.02848s Training Accuracy: 49.22% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.02817s Training Accuracy: 70.12% Test Accuracy: 68.75%
[ 2/ 50] FashionMNIST Time 0.02862s Training Accuracy: 60.45% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.02804s Training Accuracy: 76.95% Test Accuracy: 65.62%
[ 3/ 50] FashionMNIST Time 0.02236s Training Accuracy: 63.38% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.02292s Training Accuracy: 79.20% Test Accuracy: 65.62%
[ 4/ 50] FashionMNIST Time 0.02286s Training Accuracy: 59.18% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.02302s Training Accuracy: 82.42% Test Accuracy: 68.75%
[ 5/ 50] FashionMNIST Time 0.04113s Training Accuracy: 73.14% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.02088s Training Accuracy: 88.09% Test Accuracy: 68.75%
[ 6/ 50] FashionMNIST Time 0.02097s Training Accuracy: 72.17% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.02118s Training Accuracy: 91.41% Test Accuracy: 75.00%
[ 7/ 50] FashionMNIST Time 0.02044s Training Accuracy: 75.59% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.02076s Training Accuracy: 92.19% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.02037s Training Accuracy: 78.32% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.02041s Training Accuracy: 93.75% Test Accuracy: 81.25%
[ 9/ 50] FashionMNIST Time 0.02058s Training Accuracy: 77.64% Test Accuracy: 68.75%
[ 10/ 50] MNIST Time 0.03276s Training Accuracy: 96.29% Test Accuracy: 84.38%
[ 10/ 50] FashionMNIST Time 0.02028s Training Accuracy: 77.34% Test Accuracy: 68.75%
[ 11/ 50] MNIST Time 0.02036s Training Accuracy: 97.27% Test Accuracy: 84.38%
[ 11/ 50] FashionMNIST Time 0.02073s Training Accuracy: 79.39% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.02104s Training Accuracy: 97.56% Test Accuracy: 84.38%
[ 12/ 50] FashionMNIST Time 0.02038s Training Accuracy: 79.69% Test Accuracy: 71.88%
[ 13/ 50] MNIST Time 0.02077s Training Accuracy: 98.73% Test Accuracy: 84.38%
[ 13/ 50] FashionMNIST Time 0.02030s Training Accuracy: 82.71% Test Accuracy: 75.00%
[ 14/ 50] MNIST Time 0.02066s Training Accuracy: 99.41% Test Accuracy: 84.38%
[ 14/ 50] FashionMNIST Time 0.02658s Training Accuracy: 82.91% Test Accuracy: 71.88%
[ 15/ 50] MNIST Time 0.02095s Training Accuracy: 99.51% Test Accuracy: 84.38%
[ 15/ 50] FashionMNIST Time 0.02054s Training Accuracy: 82.91% Test Accuracy: 75.00%
[ 16/ 50] MNIST Time 0.02063s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 16/ 50] FashionMNIST Time 0.02108s Training Accuracy: 83.98% Test Accuracy: 62.50%
[ 17/ 50] MNIST Time 0.02091s Training Accuracy: 99.71% Test Accuracy: 84.38%
[ 17/ 50] FashionMNIST Time 0.02103s Training Accuracy: 86.43% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.02066s Training Accuracy: 99.90% Test Accuracy: 84.38%
[ 18/ 50] FashionMNIST Time 0.02035s Training Accuracy: 86.62% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.02030s Training Accuracy: 99.90% Test Accuracy: 84.38%
[ 19/ 50] FashionMNIST Time 0.02081s Training Accuracy: 88.67% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.02087s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 20/ 50] FashionMNIST Time 0.02054s Training Accuracy: 89.36% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.02052s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 21/ 50] FashionMNIST Time 0.02101s Training Accuracy: 89.65% Test Accuracy: 84.38%
[ 22/ 50] MNIST Time 0.02074s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 22/ 50] FashionMNIST Time 0.02047s Training Accuracy: 90.43% Test Accuracy: 78.12%
[ 23/ 50] MNIST Time 0.02776s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 23/ 50] FashionMNIST Time 0.02046s Training Accuracy: 90.72% Test Accuracy: 78.12%
[ 24/ 50] MNIST Time 0.02053s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 24/ 50] FashionMNIST Time 0.02026s Training Accuracy: 91.31% Test Accuracy: 78.12%
[ 25/ 50] MNIST Time 0.02207s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 25/ 50] FashionMNIST Time 0.02159s Training Accuracy: 91.89% Test Accuracy: 84.38%
[ 26/ 50] MNIST Time 0.02081s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 26/ 50] FashionMNIST Time 0.02090s Training Accuracy: 91.80% Test Accuracy: 81.25%
[ 27/ 50] MNIST Time 0.02054s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 27/ 50] FashionMNIST Time 0.02756s Training Accuracy: 92.58% Test Accuracy: 81.25%
[ 28/ 50] MNIST Time 0.02042s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 28/ 50] FashionMNIST Time 0.02061s Training Accuracy: 92.77% Test Accuracy: 75.00%
[ 29/ 50] MNIST Time 0.02075s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 29/ 50] FashionMNIST Time 0.02078s Training Accuracy: 92.38% Test Accuracy: 75.00%
[ 30/ 50] MNIST Time 0.02057s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 30/ 50] FashionMNIST Time 0.02072s Training Accuracy: 92.77% Test Accuracy: 81.25%
[ 31/ 50] MNIST Time 0.02089s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 31/ 50] FashionMNIST Time 0.02029s Training Accuracy: 93.26% Test Accuracy: 84.38%
[ 32/ 50] MNIST Time 0.02640s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 32/ 50] FashionMNIST Time 0.02032s Training Accuracy: 94.04% Test Accuracy: 81.25%
[ 33/ 50] MNIST Time 0.02031s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 33/ 50] FashionMNIST Time 0.02044s Training Accuracy: 94.53% Test Accuracy: 81.25%
[ 34/ 50] MNIST Time 0.02108s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 34/ 50] FashionMNIST Time 0.02044s Training Accuracy: 94.82% Test Accuracy: 81.25%
[ 35/ 50] MNIST Time 0.02080s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 35/ 50] FashionMNIST Time 0.02075s Training Accuracy: 95.31% Test Accuracy: 78.12%
[ 36/ 50] MNIST Time 0.02122s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 36/ 50] FashionMNIST Time 0.02006s Training Accuracy: 95.12% Test Accuracy: 81.25%
[ 37/ 50] MNIST Time 0.02076s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 37/ 50] FashionMNIST Time 0.02073s Training Accuracy: 95.80% Test Accuracy: 84.38%
[ 38/ 50] MNIST Time 0.02059s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 38/ 50] FashionMNIST Time 0.02081s Training Accuracy: 95.12% Test Accuracy: 81.25%
[ 39/ 50] MNIST Time 0.02079s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 39/ 50] FashionMNIST Time 0.02320s Training Accuracy: 95.70% Test Accuracy: 84.38%
[ 40/ 50] MNIST Time 0.02078s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 40/ 50] FashionMNIST Time 0.02696s Training Accuracy: 96.09% Test Accuracy: 84.38%
[ 41/ 50] MNIST Time 0.02039s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 41/ 50] FashionMNIST Time 0.02076s Training Accuracy: 96.00% Test Accuracy: 84.38%
[ 42/ 50] MNIST Time 0.02078s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 42/ 50] FashionMNIST Time 0.02069s Training Accuracy: 96.19% Test Accuracy: 81.25%
[ 43/ 50] MNIST Time 0.02081s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 43/ 50] FashionMNIST Time 0.02025s Training Accuracy: 96.09% Test Accuracy: 84.38%
[ 44/ 50] MNIST Time 0.02177s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 44/ 50] FashionMNIST Time 0.03603s Training Accuracy: 96.68% Test Accuracy: 84.38%
[ 45/ 50] MNIST Time 0.02670s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 45/ 50] FashionMNIST Time 0.01978s Training Accuracy: 97.27% Test Accuracy: 84.38%
[ 46/ 50] MNIST Time 0.01968s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 46/ 50] FashionMNIST Time 0.02102s Training Accuracy: 97.46% Test Accuracy: 84.38%
[ 47/ 50] MNIST Time 0.01969s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 47/ 50] FashionMNIST Time 0.01991s Training Accuracy: 97.27% Test Accuracy: 84.38%
[ 48/ 50] MNIST Time 0.01988s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 48/ 50] FashionMNIST Time 0.01967s Training Accuracy: 97.27% Test Accuracy: 84.38%
[ 49/ 50] MNIST Time 0.01967s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 49/ 50] FashionMNIST Time 0.02651s Training Accuracy: 97.56% Test Accuracy: 84.38%
[ 50/ 50] MNIST Time 0.01965s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 50/ 50] FashionMNIST Time 0.01965s Training Accuracy: 97.46% Test Accuracy: 84.38%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
[FINAL] FashionMNIST Training Accuracy: 97.46% Test Accuracy: 84.38%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Append GPU toolchain details for each backend that is both loaded in this session and
# reported functional by MLDataDevices.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)
This page was generated using Literate.jl.