Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
CUDA.allowscalar(false)
Loading Datasets
julia
"""
    load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}

Build a `(train_loader, eval_loader)` pair of `DataLoader`s from the dataset type `dset`.

The first `n_train` samples of `dset(:train)` and the first `n_eval` samples of
`dset(:test)` are used. Images are reshaped to `28 × 28 × 1 × n` and labels are one-hot
encoded over the classes `0:9`. Each loader's batch size is capped at the number of samples
it holds; only the training loader shuffles.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    train_imgs, train_labels = dset(:train)[1:n_train]
    train_data = (reshape(train_imgs, 28, 28, 1, n_train), onehotbatch(train_labels, 0:9))
    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)

    eval_imgs, eval_labels = dset(:test)[1:n_eval]
    eval_data = (reshape(eval_imgs, 28, 28, 1, n_eval), onehotbatch(eval_labels, 0:9))
    eval_loader = DataLoader(eval_data; batchsize=min(batchsize, n_eval), shuffle=false)

    return (train_loader, eval_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Call `load_dataset` for `MNIST` and then `FashionMNIST`, broadcasting the remaining
arguments, and return the results in that order.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    dataset_types = (MNIST, FashionMNIST)
    # `broadcast` (rather than `map`) keeps the original dot-call semantics for every argument
    return broadcast(load_dataset, dataset_types, n_train, n_eval, batchsize)
end
# load_datasets (generic function with 4 methods)
# Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a compact layer, tagged with `dispatch=:HyperNet`, that takes an input tuple `(x, y)`.

On each call `weight_generator(x)` is flattened with `vec` and wrapped in a
`ComponentArray` that reuses the axes of `core_network`'s parameters; `core_network` is then
applied to `y` with those generated parameters.

NOTE(review): this presumes the flattened generator output has exactly as many entries as
`core_network` has parameters — confirm at the call site.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # A throwaway initialisation: only the axes (layout) of the parameters are kept
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights and give them the core network's parameter layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
# HyperNet (generic function with 1 method)
# Defining functions on the CompactLuxLayer requires some understanding of how the layer is
# structured, as such we don't recommend doing it unless you are familiar with the
# internals. In this case, we simply write it to ignore the initialization of the
# core_network parameters.
julia
# Parameter initialisation for compact layers tagged `:HyperNet`: the returned NamedTuple
# holds parameters for `weight_generator` only, so none are created for `core_network`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end
# Create and Initialize the HyperNet
julia
"""
    create_model()

Assemble the `HyperNet` used in this tutorial: a core classifier
(`784 -> 256 -> 10` after flattening) whose parameters are produced by a generator made of
a two-entry `Embedding` followed by two `Dense` layers.
"""
function create_model()
    # The core network doesn't need to be an MLP; any Lux layer can be used
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))

    # The generator's last layer emits one value per core-network parameter
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))

    return HyperNet(weight_generator, core_network)
end
# create_model (generic function with 1 method)
# Define Utility Functions
julia
# Training objective shared by the whole script. Constructed with `logits=Val(true)`;
# presumably this means predictions are passed as unnormalised scores (the core network
# ends in a plain `Dense(256, 10)`) — confirm against the Lux `CrossEntropyLoss` docs.
const loss = CrossEntropyLoss(; logits=Val(true))
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the class predicted by `model`
matches the target class.

The states are switched to test mode first. For every batch `(x, y)` the model is called
on `(data_idx, x)`; classes are recovered from the first model output and from `y` with
`onecold`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    eval_st = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        truth = onecold(y)
        # The model returns a tuple; only its first element (the prediction) is needed
        scores = first(model((data_idx, x), ps, eval_st))
        guess = onecold(scores)
        n_correct += sum(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
# accuracy (generic function with 1 method)
# Training
julia
# Evaluate `accuracy` with the parameters and states currently held by `train_state` and
# convert the resulting fraction into a percentage rounded to two decimals.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    frac = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(frac * 100; digits=2)
end

"""
    train(; nepochs::Integer=50)

Train the model from `create_model()` on the datasets from `load_datasets()`, alternating
between dataset 1 (MNIST) and dataset 2 (FashionMNIST) within every epoch.

After each dataset's pass the elapsed time and the training/test accuracy are printed; a
final summary line per dataset is printed once training has finished.

# Keywords
- `nepochs`: number of epochs (default `50`, the value that was previously hard-coded).

# Returns
A two-element vector with the final test accuracy, in percent, for dataset 1 and
dataset 2 respectively.
"""
function train(; nepochs::Integer=50)
    model = create_model()
    raw_dataloaders = load_datasets()
    dev = gpu_device()
    # Wrap every (train, test) loader pair with the device once, instead of on each epoch
    dataloaders = map(loaders -> loaders .|> dev, raw_dataloaders)
    dataset_names = ("MNIST", "FashionMNIST")

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Lets train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]

        # NOTE(review): with a GPU device, work may still be in flight when the loop ends,
        # so this is the wall-clock time of the loop — confirm if exact timings matter.
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)
        data_name = dataset_names[data_idx]

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)
        data_name = dataset_names[data_idx]

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 69.55687s Training Accuracy: 60.06% Test Accuracy: 59.38%
[ 1/ 50] FashionMNIST Time 0.02189s Training Accuracy: 53.03% Test Accuracy: 50.00%
[ 2/ 50] MNIST Time 0.02383s Training Accuracy: 65.23% Test Accuracy: 59.38%
[ 2/ 50] FashionMNIST Time 0.02378s Training Accuracy: 57.42% Test Accuracy: 56.25%
[ 3/ 50] MNIST Time 0.02428s Training Accuracy: 75.29% Test Accuracy: 59.38%
[ 3/ 50] FashionMNIST Time 0.02304s Training Accuracy: 64.75% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.02243s Training Accuracy: 79.98% Test Accuracy: 65.62%
[ 4/ 50] FashionMNIST Time 0.02277s Training Accuracy: 59.86% Test Accuracy: 62.50%
[ 5/ 50] MNIST Time 0.02267s Training Accuracy: 82.42% Test Accuracy: 68.75%
[ 5/ 50] FashionMNIST Time 0.05914s Training Accuracy: 71.68% Test Accuracy: 71.88%
[ 6/ 50] MNIST Time 0.02185s Training Accuracy: 86.72% Test Accuracy: 75.00%
[ 6/ 50] FashionMNIST Time 0.02095s Training Accuracy: 75.39% Test Accuracy: 62.50%
[ 7/ 50] MNIST Time 0.02061s Training Accuracy: 87.99% Test Accuracy: 81.25%
[ 7/ 50] FashionMNIST Time 0.02085s Training Accuracy: 73.63% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.02074s Training Accuracy: 89.36% Test Accuracy: 78.12%
[ 8/ 50] FashionMNIST Time 0.02045s Training Accuracy: 75.20% Test Accuracy: 68.75%
[ 9/ 50] MNIST Time 0.02115s Training Accuracy: 93.16% Test Accuracy: 84.38%
[ 9/ 50] FashionMNIST Time 0.02190s Training Accuracy: 77.73% Test Accuracy: 71.88%
[ 10/ 50] MNIST Time 0.02052s Training Accuracy: 95.02% Test Accuracy: 84.38%
[ 10/ 50] FashionMNIST Time 0.02083s Training Accuracy: 80.37% Test Accuracy: 71.88%
[ 11/ 50] MNIST Time 0.02098s Training Accuracy: 96.19% Test Accuracy: 84.38%
[ 11/ 50] FashionMNIST Time 0.02052s Training Accuracy: 80.57% Test Accuracy: 75.00%
[ 12/ 50] MNIST Time 0.02090s Training Accuracy: 97.75% Test Accuracy: 84.38%
[ 12/ 50] FashionMNIST Time 0.02112s Training Accuracy: 83.20% Test Accuracy: 78.12%
[ 13/ 50] MNIST Time 0.02071s Training Accuracy: 98.05% Test Accuracy: 81.25%
[ 13/ 50] FashionMNIST Time 0.02096s Training Accuracy: 82.71% Test Accuracy: 75.00%
[ 14/ 50] MNIST Time 0.02879s Training Accuracy: 99.22% Test Accuracy: 78.12%
[ 14/ 50] FashionMNIST Time 0.02107s Training Accuracy: 83.79% Test Accuracy: 68.75%
[ 15/ 50] MNIST Time 0.02113s Training Accuracy: 99.41% Test Accuracy: 81.25%
[ 15/ 50] FashionMNIST Time 0.02070s Training Accuracy: 84.96% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.02017s Training Accuracy: 99.51% Test Accuracy: 81.25%
[ 16/ 50] FashionMNIST Time 0.02235s Training Accuracy: 85.94% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02146s Training Accuracy: 99.71% Test Accuracy: 81.25%
[ 17/ 50] FashionMNIST Time 0.02093s Training Accuracy: 86.04% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.02058s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 18/ 50] FashionMNIST Time 0.02758s Training Accuracy: 87.79% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.02107s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 19/ 50] FashionMNIST Time 0.02044s Training Accuracy: 88.77% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.02072s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 20/ 50] FashionMNIST Time 0.02261s Training Accuracy: 89.26% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.02064s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 21/ 50] FashionMNIST Time 0.02116s Training Accuracy: 88.96% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.02121s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 22/ 50] FashionMNIST Time 0.02070s Training Accuracy: 89.55% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.02843s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 23/ 50] FashionMNIST Time 0.02109s Training Accuracy: 88.57% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.02119s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 24/ 50] FashionMNIST Time 0.02109s Training Accuracy: 90.33% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.02037s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 25/ 50] FashionMNIST Time 0.02032s Training Accuracy: 90.23% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.02075s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 26/ 50] FashionMNIST Time 0.02086s Training Accuracy: 91.21% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.02047s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 27/ 50] FashionMNIST Time 0.02619s Training Accuracy: 90.92% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.02145s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 28/ 50] FashionMNIST Time 0.02185s Training Accuracy: 91.31% Test Accuracy: 75.00%
[ 29/ 50] MNIST Time 0.02313s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 29/ 50] FashionMNIST Time 0.02258s Training Accuracy: 92.09% Test Accuracy: 75.00%
[ 30/ 50] MNIST Time 0.02250s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 30/ 50] FashionMNIST Time 0.02208s Training Accuracy: 92.38% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.02188s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 31/ 50] FashionMNIST Time 0.03016s Training Accuracy: 93.07% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.02225s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 32/ 50] FashionMNIST Time 0.02214s Training Accuracy: 93.75% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.02163s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 33/ 50] FashionMNIST Time 0.02214s Training Accuracy: 93.36% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.02208s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 34/ 50] FashionMNIST Time 0.02359s Training Accuracy: 94.04% Test Accuracy: 75.00%
[ 35/ 50] MNIST Time 0.02124s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 35/ 50] FashionMNIST Time 0.02145s Training Accuracy: 94.43% Test Accuracy: 75.00%
[ 36/ 50] MNIST Time 0.02950s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 36/ 50] FashionMNIST Time 0.02093s Training Accuracy: 94.34% Test Accuracy: 78.12%
[ 37/ 50] MNIST Time 0.02137s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 37/ 50] FashionMNIST Time 0.02169s Training Accuracy: 94.43% Test Accuracy: 75.00%
[ 38/ 50] MNIST Time 0.02153s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 38/ 50] FashionMNIST Time 0.02083s Training Accuracy: 94.14% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.02079s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 39/ 50] FashionMNIST Time 0.02058s Training Accuracy: 94.82% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.02087s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 40/ 50] FashionMNIST Time 0.02789s Training Accuracy: 95.51% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.02113s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 41/ 50] FashionMNIST Time 0.02380s Training Accuracy: 95.31% Test Accuracy: 75.00%
[ 42/ 50] MNIST Time 0.02108s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 42/ 50] FashionMNIST Time 0.02216s Training Accuracy: 95.41% Test Accuracy: 78.12%
[ 43/ 50] MNIST Time 0.02263s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 43/ 50] FashionMNIST Time 0.02486s Training Accuracy: 95.80% Test Accuracy: 75.00%
[ 44/ 50] MNIST Time 0.02157s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 44/ 50] FashionMNIST Time 0.02187s Training Accuracy: 95.90% Test Accuracy: 78.12%
[ 45/ 50] MNIST Time 0.02092s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 45/ 50] FashionMNIST Time 0.02178s Training Accuracy: 95.80% Test Accuracy: 78.12%
[ 46/ 50] MNIST Time 0.02175s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 46/ 50] FashionMNIST Time 0.02131s Training Accuracy: 96.29% Test Accuracy: 75.00%
[ 47/ 50] MNIST Time 0.02115s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 47/ 50] FashionMNIST Time 0.02080s Training Accuracy: 95.61% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.02166s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 48/ 50] FashionMNIST Time 0.02175s Training Accuracy: 96.19% Test Accuracy: 75.00%
[ 49/ 50] MNIST Time 0.02912s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 49/ 50] FashionMNIST Time 0.02256s Training Accuracy: 96.19% Test Accuracy: 81.25%
[ 50/ 50] MNIST Time 0.02252s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 50/ 50] FashionMNIST Time 0.02148s Training Accuracy: 96.88% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 81.25%
[FINAL] FashionMNIST Training Accuracy: 96.88% Test Accuracy: 75.00%
Appendix
julia
using InteractiveUtils

# Report the Julia version and platform information.
InteractiveUtils.versioninfo()

# Backend details are printed only when `MLDataDevices` is defined, and then only for each
# backend whose module is defined and which `MLDataDevices.functional` reports as usable.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA)
        if MLDataDevices.functional(CUDADevice)
            println()
            CUDA.versioninfo()
        end
    end

    if @isdefined(AMDGPU)
        if MLDataDevices.functional(AMDGPUDevice)
            println()
            AMDGPU.versioninfo()
        end
    end
end
# Captured output of the statements above begins here and continues on the lines below:
# Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)
This page was generated using Literate.jl.