Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
# Packages used throughout the tutorial: Lux (model + training API), LuxCUDA (CUDA backend),
# MLDatasets/MLUtils/OneHotArrays (data loading and label encoding), Optimisers (`Adam`),
# ComponentArrays (flat parameter vectors with named axes) and Zygote (via `AutoZygote`).
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
# Turn accidental scalar indexing of GPU arrays into an error instead of a slow fallback.
CUDA.allowscalar(false)
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the MLDatasets dataset type `dset`.

The first `n_train` training samples and the first `n_eval` test samples are used. Images
are reshaped to `28 × 28 × 1 × N` arrays and labels are one-hot encoded over the classes
`0:9`. Each loader's batch size is capped at its number of samples, and only the training
loader shuffles.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Read the first `n` samples of `split` and wrap them in a DataLoader.
    function make_loader(split, n, shuffle)
        features, targets = dset(split)[1:n]
        data = (reshape(features, 28, 28, 1, n), onehotbatch(targets, 0:9))
        return DataLoader(data; batchsize=min(batchsize, n), shuffle=shuffle)
    end
    train_loader = make_loader(:train, n_train, true)
    test_loader = make_loader(:test, n_eval, false)
    return (train_loader, test_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return a 2-tuple of `(train_loader, test_loader)` pairs: MNIST first, FashionMNIST second.
All arguments are forwarded unchanged to `load_dataset`.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    build(dataset_type) = load_dataset(dataset_type, n_train, n_eval, batchsize)
    return map(build, (MNIST, FashionMNIST))
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a `@compact` layer that is called on a tuple `(x, y)`.

`weight_generator` is applied to `x`. Its output is flattened with `vec` and reinterpreted,
via `ComponentArray` axes, as a full parameter set for `core_network`. `core_network` is
then run on `y` with those generated parameters.
"""
function HyperNet(
weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
# Only the layout (axes) of the core network's parameters is kept; the initialized values
# themselves are discarded. Note this draws from the global default RNG.
ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
ComponentArray |>
getaxes
# `dispatch=:HyperNet` tags the resulting layer so that methods can specialize on
# `CompactLuxLayer{:HyperNet}` (used to customize parameter initialization).
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
# Generate the weights: reshape the generator's flat output into the core network's layout
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
# Run the core network on `y` with the generated (not stored) parameters
@return core_network(y, ps_new)
end
end
HyperNet (generic function with 1 method)
Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to skip initializing the `core_network`
parameters, since those are generated by the weight generator on every forward pass.
julia
# Parameter initialization for HyperNet layers: only the weight generator owns trainable
# parameters. The core network's parameters are produced by the generator on every
# forward pass, so they are deliberately not initialized here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Return a `HyperNet` with these two parts:

- Core network: a 784 → 256 → 10 MLP applied to flattened inputs.
- Weight generator: embeds one of two integer indices into 32 dimensions, then maps it
  through a 64-unit hidden layer to every parameter of the core network.
"""
function create_model()
    # The core network doesn't need to be an MLP; any Lux layer works here.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )
    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
# Cross-entropy loss applied directly to the model's raw (pre-softmax) outputs.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose arg-max prediction matches the
one-hot target.

`model` is evaluated in test mode on inputs `(data_idx, x)` for every batch `(x, y)`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    st_test = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for batch in dataloader
        x, y = batch
        logits, _ = model((data_idx, x), ps, st_test)
        labels = onecold(y)
        n_correct += sum(onecold(logits) .== labels)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
# Log name for dataset index `data_idx`. The order (1 = MNIST, 2 = FashionMNIST) matches
# the order of the tuple returned by `load_datasets`.
_dataset_name(data_idx::Integer) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Return the (train, test) accuracies for dataset `data_idx`, as percentages rounded to two
# decimals. They are computed with the current parameters and states in `train_state`.
function _percent_accuracies(
        model, train_state, train_dataloader, test_dataloader, data_idx)
    ps, st = train_state.parameters, train_state.states
    to_percent(dataloader) = round(
        accuracy(model, ps, st, dataloader, data_idx) * 100; digits=2)
    return to_percent(train_dataloader), to_percent(test_dataloader)
end

"""
    train(; nepochs::Integer=50)

Train a single HyperNet jointly on MNIST and FashionMNIST.

Each epoch does one pass over MNIST and then one over FashionMNIST. The dataset index is
the weight generator's input. After every pass, training time and train/test accuracies
are printed.

# Keywords
- `nepochs`: number of epochs (default `50`).

# Returns
A 2-element `Vector{Float64}` with the final test accuracies (in percent, rounded to two
decimals) for MNIST and FashionMNIST, in that order.
"""
function train(; nepochs::Integer=50)
    model = create_model()
    raw_dataloaders = load_datasets()
    dev = gpu_device()
    # NOTE(review): this seed only controls parameter initialization. The DataLoaders
    # were built without an explicit rng, so their shuffling is not governed by it.
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Wrap every (train, test) loader pair for the device once. Each wrapped loader is
    # iterated several times per epoch anyway (training pass + accuracy pass), so
    # re-wrapping it on every epoch is unnecessary.
    dataloaders = map(pair -> pair .|> dev, raw_dataloaders)

    # Let's train the model, alternating between the two datasets within every epoch.
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime
        train_acc, test_acc = _percent_accuracies(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        train_acc, test_acc = _percent_accuracies(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 72.33321s Training Accuracy: 57.91% Test Accuracy: 53.12%
[ 1/ 50] FashionMNIST Time 0.02845s Training Accuracy: 48.63% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.02912s Training Accuracy: 69.43% Test Accuracy: 65.62%
[ 2/ 50] FashionMNIST Time 0.02984s Training Accuracy: 63.57% Test Accuracy: 56.25%
[ 3/ 50] MNIST Time 0.02899s Training Accuracy: 77.73% Test Accuracy: 68.75%
[ 3/ 50] FashionMNIST Time 0.02873s Training Accuracy: 64.94% Test Accuracy: 62.50%
[ 4/ 50] MNIST Time 0.02721s Training Accuracy: 82.71% Test Accuracy: 65.62%
[ 4/ 50] FashionMNIST Time 0.02093s Training Accuracy: 65.82% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.02106s Training Accuracy: 85.84% Test Accuracy: 68.75%
[ 5/ 50] FashionMNIST Time 0.02239s Training Accuracy: 71.68% Test Accuracy: 53.12%
[ 6/ 50] MNIST Time 0.02133s Training Accuracy: 86.52% Test Accuracy: 71.88%
[ 6/ 50] FashionMNIST Time 0.02146s Training Accuracy: 76.46% Test Accuracy: 53.12%
[ 7/ 50] MNIST Time 0.02169s Training Accuracy: 89.36% Test Accuracy: 78.12%
[ 7/ 50] FashionMNIST Time 0.02320s Training Accuracy: 77.34% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.02666s Training Accuracy: 92.58% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.05175s Training Accuracy: 78.61% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.02292s Training Accuracy: 94.14% Test Accuracy: 78.12%
[ 9/ 50] FashionMNIST Time 0.02085s Training Accuracy: 77.44% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.02117s Training Accuracy: 95.70% Test Accuracy: 78.12%
[ 10/ 50] FashionMNIST Time 0.02135s Training Accuracy: 81.84% Test Accuracy: 68.75%
[ 11/ 50] MNIST Time 0.02182s Training Accuracy: 96.39% Test Accuracy: 81.25%
[ 11/ 50] FashionMNIST Time 0.02111s Training Accuracy: 81.84% Test Accuracy: 65.62%
[ 12/ 50] MNIST Time 0.02164s Training Accuracy: 97.75% Test Accuracy: 81.25%
[ 12/ 50] FashionMNIST Time 0.02159s Training Accuracy: 82.71% Test Accuracy: 68.75%
[ 13/ 50] MNIST Time 0.02879s Training Accuracy: 97.85% Test Accuracy: 78.12%
[ 13/ 50] FashionMNIST Time 0.02271s Training Accuracy: 84.28% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.02079s Training Accuracy: 98.73% Test Accuracy: 78.12%
[ 14/ 50] FashionMNIST Time 0.02091s Training Accuracy: 83.11% Test Accuracy: 62.50%
[ 15/ 50] MNIST Time 0.02113s Training Accuracy: 99.22% Test Accuracy: 78.12%
[ 15/ 50] FashionMNIST Time 0.02264s Training Accuracy: 83.79% Test Accuracy: 65.62%
[ 16/ 50] MNIST Time 0.02123s Training Accuracy: 99.32% Test Accuracy: 78.12%
[ 16/ 50] FashionMNIST Time 0.02098s Training Accuracy: 84.96% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02143s Training Accuracy: 99.51% Test Accuracy: 78.12%
[ 17/ 50] FashionMNIST Time 0.02652s Training Accuracy: 85.64% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.02070s Training Accuracy: 99.71% Test Accuracy: 78.12%
[ 18/ 50] FashionMNIST Time 0.02071s Training Accuracy: 85.55% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.02077s Training Accuracy: 99.90% Test Accuracy: 78.12%
[ 19/ 50] FashionMNIST Time 0.02108s Training Accuracy: 86.82% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.02127s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 20/ 50] FashionMNIST Time 0.02089s Training Accuracy: 87.21% Test Accuracy: 62.50%
[ 21/ 50] MNIST Time 0.02149s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 21/ 50] FashionMNIST Time 0.02119s Training Accuracy: 87.79% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.02002s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 22/ 50] FashionMNIST Time 0.02091s Training Accuracy: 89.06% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.02035s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 23/ 50] FashionMNIST Time 0.02051s Training Accuracy: 90.43% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.02092s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 24/ 50] FashionMNIST Time 0.02109s Training Accuracy: 91.02% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.02049s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 25/ 50] FashionMNIST Time 0.02134s Training Accuracy: 92.48% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.02814s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 26/ 50] FashionMNIST Time 0.02124s Training Accuracy: 91.41% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.02085s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 27/ 50] FashionMNIST Time 0.02181s Training Accuracy: 92.68% Test Accuracy: 78.12%
[ 28/ 50] MNIST Time 0.02123s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 28/ 50] FashionMNIST Time 0.02125s Training Accuracy: 92.68% Test Accuracy: 78.12%
[ 29/ 50] MNIST Time 0.02102s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 29/ 50] FashionMNIST Time 0.02081s Training Accuracy: 93.65% Test Accuracy: 75.00%
[ 30/ 50] MNIST Time 0.02106s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 30/ 50] FashionMNIST Time 0.02752s Training Accuracy: 94.14% Test Accuracy: 78.12%
[ 31/ 50] MNIST Time 0.02086s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 31/ 50] FashionMNIST Time 0.02075s Training Accuracy: 94.53% Test Accuracy: 78.12%
[ 32/ 50] MNIST Time 0.02064s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 32/ 50] FashionMNIST Time 0.02118s Training Accuracy: 94.63% Test Accuracy: 78.12%
[ 33/ 50] MNIST Time 0.02125s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 33/ 50] FashionMNIST Time 0.02206s Training Accuracy: 94.34% Test Accuracy: 78.12%
[ 34/ 50] MNIST Time 0.02159s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 34/ 50] FashionMNIST Time 0.02064s Training Accuracy: 94.73% Test Accuracy: 78.12%
[ 35/ 50] MNIST Time 0.02600s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 35/ 50] FashionMNIST Time 0.02224s Training Accuracy: 95.12% Test Accuracy: 78.12%
[ 36/ 50] MNIST Time 0.02068s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 36/ 50] FashionMNIST Time 0.02117s Training Accuracy: 95.31% Test Accuracy: 78.12%
[ 37/ 50] MNIST Time 0.02063s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 37/ 50] FashionMNIST Time 0.02126s Training Accuracy: 95.31% Test Accuracy: 78.12%
[ 38/ 50] MNIST Time 0.02142s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 38/ 50] FashionMNIST Time 0.02136s Training Accuracy: 95.90% Test Accuracy: 78.12%
[ 39/ 50] MNIST Time 0.02099s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 39/ 50] FashionMNIST Time 0.02024s Training Accuracy: 95.80% Test Accuracy: 78.12%
[ 40/ 50] MNIST Time 0.02162s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 40/ 50] FashionMNIST Time 0.02015s Training Accuracy: 96.19% Test Accuracy: 78.12%
[ 41/ 50] MNIST Time 0.02085s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 41/ 50] FashionMNIST Time 0.02181s Training Accuracy: 96.09% Test Accuracy: 78.12%
[ 42/ 50] MNIST Time 0.02032s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 42/ 50] FashionMNIST Time 0.02672s Training Accuracy: 96.39% Test Accuracy: 78.12%
[ 43/ 50] MNIST Time 0.02700s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 43/ 50] FashionMNIST Time 0.02651s Training Accuracy: 96.39% Test Accuracy: 78.12%
[ 44/ 50] MNIST Time 0.02030s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 44/ 50] FashionMNIST Time 0.02073s Training Accuracy: 96.78% Test Accuracy: 78.12%
[ 45/ 50] MNIST Time 0.02073s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 45/ 50] FashionMNIST Time 0.02101s Training Accuracy: 96.78% Test Accuracy: 78.12%
[ 46/ 50] MNIST Time 0.02036s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 46/ 50] FashionMNIST Time 0.02085s Training Accuracy: 96.88% Test Accuracy: 78.12%
[ 47/ 50] MNIST Time 0.02098s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 47/ 50] FashionMNIST Time 0.02038s Training Accuracy: 96.88% Test Accuracy: 78.12%
[ 48/ 50] MNIST Time 0.02740s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 48/ 50] FashionMNIST Time 0.02044s Training Accuracy: 97.17% Test Accuracy: 78.12%
[ 49/ 50] MNIST Time 0.02146s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 49/ 50] FashionMNIST Time 0.02053s Training Accuracy: 96.97% Test Accuracy: 78.12%
[ 50/ 50] MNIST Time 0.02046s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 50/ 50] FashionMNIST Time 0.02094s Training Accuracy: 97.27% Test Accuracy: 78.12%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 78.12%
[FINAL] FashionMNIST Training Accuracy: 97.27% Test Accuracy: 78.12%
Appendix
julia
# Print Julia build/platform information, then the versions of any GPU backend that is
# loaded and functional on this machine.
using InteractiveUtils
InteractiveUtils.versioninfo()
# The `@isdefined` guards keep this chunk from erroring when a package was never loaded.
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)
This page was generated using Literate.jl.