Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
# Turn scalar (element-by-element) indexing of GPU arrays into an error, so accidental
# slow CPU-style access paths on device arrays fail loudly instead of silently crawling.
CUDA.allowscalar(false)
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Load the first `n_train` training samples and the first `n_eval` test samples of the
dataset type `dset` (e.g. `MNIST`). Images are reshaped into a `28×28×1×n` array and labels
are one-hot encoded over the classes `0:9`. Each split is wrapped in a `DataLoader`; only
the training loader shuffles. The batch size of each loader is capped at that split's
sample count.

Returns the tuple `(train_loader, test_loader)`.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Slice the first `n` samples of `split`, reshape the images and one-hot the labels.
    function prepare(split::Symbol, n::Int)
        features, targets = dset(split)[1:n]
        return reshape(features, 28, 28, 1, n), onehotbatch(targets, 0:9)
    end

    train_data = prepare(:train, n_train)
    test_data = prepare(:test, n_eval)

    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)
    test_loader = DataLoader(test_data; batchsize=min(batchsize, n_eval), shuffle=false)
    return (train_loader, test_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return a 2-tuple holding the `(train_loader, test_loader)` pair from `load_dataset` for
`MNIST` (first entry) and `FashionMNIST` (second entry), all built with the same sizes.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    dataset_types = (MNIST, FashionMNIST)
    return map(T -> load_dataset(T, n_train, n_eval, batchsize), dataset_types)
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer that is called as `model((x, y), ps, st)`.
`weight_generator(x)` produces a flat vector that is viewed, through `ComponentArray`
axes, as a full parameter set for `core_network`. `core_network` is then applied to `y`
with those generated parameters, and its output is returned.

NOTE(review): assumes `weight_generator`'s output has exactly
`Lux.parameterlength(core_network)` elements (as arranged in `create_model`) — confirm
for other uses.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Capture only the *layout* (axes) of core_network's parameter structure, so a flat
    # vector can later be reinterpreted with the same nesting. The parameter values drawn
    # from the default RNG here are discarded.
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    # `dispatch=:HyperNet` tags the resulting layer type as `CompactLuxLayer{:HyperNet}`,
    # which lets `Lux.initialparameters` be specialized for this layer.
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
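Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply override `Lux.initialparameters` so that it skips initializing the `core_network` parameters. Those parameters are generated by the `weight_generator` on every forward pass, so only the weight generator needs trainable parameters.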
julia
# Parameter initialization specialized for the HyperNet compact layer: only the weight
# generator owns trainable parameters. The core network is deliberately left out, because
# its parameters are produced by the weight generator on every forward call.
Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet}) =
    (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
Create and Initialize the HyperNet
julia
"""
    create_model()

Build the tutorial's HyperNet. The core network is a small MLP mapping flattened 28×28
images to 10 logits. The weight generator maps a dataset index (1 or 2) through an
embedding and an MLP to a vector with one entry per core-network parameter.
"""
function create_model()
    # The core network doesn't need to be an MLP; any Lux layer works.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core_network)

    # `Embedding(2 => 32)`: one learned 32-dim vector per dataset index
    # (1 = MNIST, 2 = FashionMNIST in `train`).
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
# Cross-entropy loss applied to raw logits (`logits=Val(true)`); the training objective.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class (`onecold` of the
model output) equals the class encoded in the one-hot target. `data_idx` is passed to the
hypernetwork to select the dataset embedding. The model runs with `st` in test mode.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    st_eval = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for batch in dataloader
        x, y = batch
        targets = onecold(y)
        logits, _ = model((data_idx, x), ps, st_eval)
        predictions = onecold(logits)
        n_correct += sum(targets .== predictions)
        n_seen += length(targets)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
# Display names of the datasets, indexed by the `data_idx` used during training.
const _DATASET_NAMES = ("MNIST", "FashionMNIST")

# Accuracy (percent, rounded to 2 decimals) of `train_state`'s parameters on `dataloader`.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train(; nepochs=50)

Train the HyperNet jointly on MNIST and FashionMNIST. Each epoch runs one pass over each
dataset in turn, passing the dataset index to the model so the weight generator can
specialize the core network. Per-epoch timing and train/test accuracy are printed, then a
final summary.

Returns a `Vector{Float64}` with the final test accuracy (in percent) for each dataset, in
the order of `load_datasets`.
"""
function train(; nepochs::Integer=50)
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Train the model, alternating between the datasets within every epoch.
    for epoch in 1:nepochs, data_idx in eachindex(dataloaders)
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        # The very first timing is likely dominated by compilation.
        ttime = time() - stime
        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)
        @printf("[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n",
            epoch, nepochs, _DATASET_NAMES[data_idx], ttime, train_acc, test_acc)
    end

    println()
    test_acc_list = zeros(Float64, length(dataloaders))
    for data_idx in eachindex(dataloaders)
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)
        @printf("[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n",
            _DATASET_NAMES[data_idx], train_acc, test_acc)
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop; `train` returns the final test accuracy (%) per dataset.
test_acc_list = train()
[ 1/ 50] MNIST Time 78.46570s Training Accuracy: 60.06% Test Accuracy: 46.88%
[ 1/ 50] FashionMNIST Time 0.03264s Training Accuracy: 53.52% Test Accuracy: 53.12%
[ 2/ 50] MNIST Time 0.02956s Training Accuracy: 69.82% Test Accuracy: 56.25%
[ 2/ 50] FashionMNIST Time 0.03135s Training Accuracy: 55.57% Test Accuracy: 59.38%
[ 3/ 50] MNIST Time 0.03733s Training Accuracy: 76.07% Test Accuracy: 65.62%
[ 3/ 50] FashionMNIST Time 0.03408s Training Accuracy: 65.33% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.02706s Training Accuracy: 76.95% Test Accuracy: 65.62%
[ 4/ 50] FashionMNIST Time 0.02389s Training Accuracy: 68.07% Test Accuracy: 65.62%
[ 5/ 50] MNIST Time 0.02540s Training Accuracy: 82.52% Test Accuracy: 68.75%
[ 5/ 50] FashionMNIST Time 0.02518s Training Accuracy: 67.48% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.02163s Training Accuracy: 87.60% Test Accuracy: 75.00%
[ 6/ 50] FashionMNIST Time 0.02164s Training Accuracy: 74.32% Test Accuracy: 65.62%
[ 7/ 50] MNIST Time 0.02190s Training Accuracy: 90.14% Test Accuracy: 75.00%
[ 7/ 50] FashionMNIST Time 0.02069s Training Accuracy: 75.78% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.02029s Training Accuracy: 93.75% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.02106s Training Accuracy: 75.88% Test Accuracy: 53.12%
[ 9/ 50] MNIST Time 0.02097s Training Accuracy: 93.26% Test Accuracy: 84.38%
[ 9/ 50] FashionMNIST Time 0.04402s Training Accuracy: 76.95% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.02088s Training Accuracy: 95.61% Test Accuracy: 81.25%
[ 10/ 50] FashionMNIST Time 0.02069s Training Accuracy: 80.86% Test Accuracy: 65.62%
[ 11/ 50] MNIST Time 0.02230s Training Accuracy: 97.36% Test Accuracy: 81.25%
[ 11/ 50] FashionMNIST Time 0.02114s Training Accuracy: 82.71% Test Accuracy: 75.00%
[ 12/ 50] MNIST Time 0.02115s Training Accuracy: 98.54% Test Accuracy: 84.38%
[ 12/ 50] FashionMNIST Time 0.02571s Training Accuracy: 84.28% Test Accuracy: 75.00%
[ 13/ 50] MNIST Time 0.02090s Training Accuracy: 98.73% Test Accuracy: 84.38%
[ 13/ 50] FashionMNIST Time 0.02114s Training Accuracy: 84.96% Test Accuracy: 78.12%
[ 14/ 50] MNIST Time 0.03184s Training Accuracy: 99.22% Test Accuracy: 84.38%
[ 14/ 50] FashionMNIST Time 0.02204s Training Accuracy: 85.35% Test Accuracy: 81.25%
[ 15/ 50] MNIST Time 0.02328s Training Accuracy: 99.51% Test Accuracy: 84.38%
[ 15/ 50] FashionMNIST Time 0.02878s Training Accuracy: 85.94% Test Accuracy: 75.00%
[ 16/ 50] MNIST Time 0.02129s Training Accuracy: 99.71% Test Accuracy: 84.38%
[ 16/ 50] FashionMNIST Time 0.02389s Training Accuracy: 86.62% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.02113s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 17/ 50] FashionMNIST Time 0.02242s Training Accuracy: 88.09% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.02085s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 18/ 50] FashionMNIST Time 0.02884s Training Accuracy: 88.57% Test Accuracy: 78.12%
[ 19/ 50] MNIST Time 0.02230s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 19/ 50] FashionMNIST Time 0.02097s Training Accuracy: 90.14% Test Accuracy: 81.25%
[ 20/ 50] MNIST Time 0.02478s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 20/ 50] FashionMNIST Time 0.02424s Training Accuracy: 90.92% Test Accuracy: 81.25%
[ 21/ 50] MNIST Time 0.04484s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 21/ 50] FashionMNIST Time 0.01982s Training Accuracy: 91.70% Test Accuracy: 81.25%
[ 22/ 50] MNIST Time 0.01972s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 22/ 50] FashionMNIST Time 0.01968s Training Accuracy: 92.29% Test Accuracy: 78.12%
[ 23/ 50] MNIST Time 0.01994s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 23/ 50] FashionMNIST Time 0.02597s Training Accuracy: 92.77% Test Accuracy: 81.25%
[ 24/ 50] MNIST Time 0.02466s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 24/ 50] FashionMNIST Time 0.03392s Training Accuracy: 93.16% Test Accuracy: 81.25%
[ 25/ 50] MNIST Time 0.02114s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 25/ 50] FashionMNIST Time 0.02454s Training Accuracy: 93.85% Test Accuracy: 81.25%
[ 26/ 50] MNIST Time 0.02240s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 26/ 50] FashionMNIST Time 0.03070s Training Accuracy: 94.34% Test Accuracy: 81.25%
[ 27/ 50] MNIST Time 0.04950s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 27/ 50] FashionMNIST Time 0.03720s Training Accuracy: 94.43% Test Accuracy: 81.25%
[ 28/ 50] MNIST Time 0.04596s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 28/ 50] FashionMNIST Time 0.02596s Training Accuracy: 94.82% Test Accuracy: 81.25%
[ 29/ 50] MNIST Time 0.02988s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 29/ 50] FashionMNIST Time 0.02997s Training Accuracy: 95.02% Test Accuracy: 78.12%
[ 30/ 50] MNIST Time 0.02919s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 30/ 50] FashionMNIST Time 0.03933s Training Accuracy: 95.31% Test Accuracy: 78.12%
[ 31/ 50] MNIST Time 0.02382s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 31/ 50] FashionMNIST Time 0.02843s Training Accuracy: 95.61% Test Accuracy: 78.12%
[ 32/ 50] MNIST Time 0.02070s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 32/ 50] FashionMNIST Time 0.02059s Training Accuracy: 95.70% Test Accuracy: 81.25%
[ 33/ 50] MNIST Time 0.02155s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 33/ 50] FashionMNIST Time 0.02189s Training Accuracy: 96.19% Test Accuracy: 81.25%
[ 34/ 50] MNIST Time 0.02188s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 34/ 50] FashionMNIST Time 0.02427s Training Accuracy: 96.68% Test Accuracy: 81.25%
[ 35/ 50] MNIST Time 0.02018s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 35/ 50] FashionMNIST Time 0.02131s Training Accuracy: 96.88% Test Accuracy: 81.25%
[ 36/ 50] MNIST Time 0.03078s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 36/ 50] FashionMNIST Time 0.02035s Training Accuracy: 96.88% Test Accuracy: 81.25%
[ 37/ 50] MNIST Time 0.02239s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 37/ 50] FashionMNIST Time 0.02143s Training Accuracy: 96.88% Test Accuracy: 78.12%
[ 38/ 50] MNIST Time 0.02121s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 38/ 50] FashionMNIST Time 0.02295s Training Accuracy: 97.17% Test Accuracy: 81.25%
[ 39/ 50] MNIST Time 0.02132s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 39/ 50] FashionMNIST Time 0.02120s Training Accuracy: 97.27% Test Accuracy: 81.25%
[ 40/ 50] MNIST Time 0.02339s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 40/ 50] FashionMNIST Time 0.02092s Training Accuracy: 97.46% Test Accuracy: 81.25%
[ 41/ 50] MNIST Time 0.02237s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 41/ 50] FashionMNIST Time 0.02104s Training Accuracy: 97.56% Test Accuracy: 81.25%
[ 42/ 50] MNIST Time 0.02215s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 42/ 50] FashionMNIST Time 0.02132s Training Accuracy: 97.56% Test Accuracy: 81.25%
[ 43/ 50] MNIST Time 0.02194s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 43/ 50] FashionMNIST Time 0.02158s Training Accuracy: 97.75% Test Accuracy: 78.12%
[ 44/ 50] MNIST Time 0.02189s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 44/ 50] FashionMNIST Time 0.02944s Training Accuracy: 97.85% Test Accuracy: 78.12%
[ 45/ 50] MNIST Time 0.02178s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 45/ 50] FashionMNIST Time 0.02185s Training Accuracy: 98.24% Test Accuracy: 81.25%
[ 46/ 50] MNIST Time 0.02149s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 46/ 50] FashionMNIST Time 0.02176s Training Accuracy: 98.24% Test Accuracy: 81.25%
[ 47/ 50] MNIST Time 0.02047s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 47/ 50] FashionMNIST Time 0.02223s Training Accuracy: 98.14% Test Accuracy: 81.25%
[ 48/ 50] MNIST Time 0.01997s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 48/ 50] FashionMNIST Time 0.02178s Training Accuracy: 98.44% Test Accuracy: 81.25%
[ 49/ 50] MNIST Time 0.02526s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 49/ 50] FashionMNIST Time 0.02135s Training Accuracy: 98.44% Test Accuracy: 81.25%
[ 50/ 50] MNIST Time 0.02100s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 50/ 50] FashionMNIST Time 0.02085s Training Accuracy: 98.73% Test Accuracy: 81.25%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
[FINAL] FashionMNIST Training Accuracy: 98.73% Test Accuracy: 81.25%
Appendix
julia
using InteractiveUtils
# Print Julia and system version information for reproducibility.
InteractiveUtils.versioninfo()
# Additionally print GPU toolkit details, but only for backends that are loaded in this
# session (`@isdefined`) and reported usable by `MLDataDevices.functional`.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)
This page was generated using Literate.jl.