Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
# Disallow scalar indexing of GPU arrays so that accidental element-by-element
# host/device access raises an error instead of silently running slowly.
CUDA.allowscalar(false)
Loading Datasets
julia
"""
    load_dataset(D, n_train, n_eval, batchsize)

Load the first `n_train` training and first `n_eval` test samples of dataset type `D`
and wrap each split in a `DataLoader`.

Images are reshaped to `28 × 28 × 1 × n` and labels are one-hot encoded over `0:9`.
The batch size is capped at the number of samples in each split. Only the training
loader shuffles.

Returns the tuple `(train_loader, test_loader)`.
"""
function load_dataset(::Type{D}, n_train::Int, n_eval::Int, batchsize::Int) where {D}
    # Fetch `n` samples from `split` and return the (features, one-hot labels) pair.
    function prepare_split(split::Symbol, n::Int)
        imgs, labels = D(split)[1:n]
        features = reshape(imgs, 28, 28, 1, n)
        targets = onehotbatch(labels, 0:9)
        return (features, targets)
    end

    train_data = prepare_split(:train, n_train)
    test_data = prepare_split(:test, n_eval)

    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)
    test_loader = DataLoader(test_data; batchsize=min(batchsize, n_eval), shuffle=false)
    return (train_loader, test_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return a 2-tuple of `(train_loader, test_loader)` pairs, one per dataset, in the
order `(MNIST, FashionMNIST)`. All arguments are forwarded to `load_dataset`.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    dataset_types = (MNIST, FashionMNIST)
    return map(D -> load_dataset(D, n_train, n_eval, batchsize), dataset_types)
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer. Called on the pair `(x, y)`, it:

1. runs `weight_generator` on `x`;
2. flattens the result and reinterprets it with the parameter layout of `core_network`;
3. evaluates `core_network` on `y` using those generated parameters.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Initialize the core network once just to learn its parameter layout (the axes);
    # the sampled values themselves are thrown away.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights and view them as the core network's parameters.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply override it to skip initializing the `core_network`
parameters, since those are generated by the `weight_generator` on every forward pass.
julia
"""
    Lux.initialparameters(rng, hn::CompactLuxLayer{:HyperNet})

Initialize only the `weight_generator` parameters of a HyperNet. The core network's
parameters are not stored, because they are produced from the generator's output on
every forward pass.
"""
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (weight_generator=generator_ps,)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Build the HyperNet used in this tutorial.

- The core network is a small MLP that maps a flattened 28×28 image to 10 logits.
- The weight generator maps a dataset index (1 or 2, via a 2-entry `Embedding`) to a
  flat vector with exactly as many entries as the core network has parameters.
"""
function create_model()
    # Doesn't need to be an MLP; any Lux layer works as the core network.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )
    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
# Cross-entropy loss applied directly to raw model outputs (`logits=Val(true)`),
# i.e. the core network's final `Dense` layer is not followed by a softmax.
const loss = CrossEntropyLoss(; logits=Val(true))
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly. The
model runs in test mode and receives `data_idx` as the weight generator's input.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    st_eval = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        labels = onecold(y)
        logits, _ = model((data_idx, x), ps, st_eval)
        n_correct += sum(labels .== onecold(logits))
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
"""
    _dataset_name(data_idx::Integer)

Return the display name of the dataset at position `data_idx` in the
`(MNIST, FashionMNIST)` tuple produced by `load_datasets`.
"""
_dataset_name(data_idx::Integer) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    _train_test_accuracy(model, train_state, train_dataloader, test_dataloader, data_idx)

Evaluate `model` with the current parameters and states held in `train_state` on both
dataloaders. Returns `(train_acc, test_acc)` as percentages rounded to two decimals.
"""
function _train_test_accuracy(
        model, train_state, train_dataloader, test_dataloader, data_idx)
    ps, st = train_state.parameters, train_state.states
    train_acc = round(accuracy(model, ps, st, train_dataloader, data_idx) * 100; digits=2)
    test_acc = round(accuracy(model, ps, st, test_dataloader, data_idx) * 100; digits=2)
    return train_acc, test_acc
end

"""
    train(; nepochs::Integer=25)

Jointly train a single HyperNet on MNIST and FashionMNIST.

Within every epoch the model makes one pass over the MNIST training loader and then
one over the FashionMNIST loader. The dataset index (1 or 2) is the weight generator's
input, so each dataset gets its own generated core-network weights. The per-epoch
training time and train/test accuracies are printed, followed by a final summary.

# Keywords
- `nepochs`: number of training epochs (default `25`).

# Returns
A 2-element `Vector{Float64}` with the final test accuracies (percent, rounded to two
decimals) for `[MNIST, FashionMNIST]`.
"""
function train(; nepochs::Integer=25)
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        # NOTE(review): GPU work may run asynchronously, so this wall-clock time may not
        # include all device computation — confirm if precise timings matter.
        stime = time()
        for (x, y) in train_dataloader
            # HyperNet consumes the (dataset index, images) pair.
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc, test_acc = _train_test_accuracy(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d] \t %12s \t Time %.5fs \t Training Accuracy: %.2f%% \t Test \
                 Accuracy: %.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc, test_acc = _train_test_accuracy(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)
        @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end
# Run the full training loop; the result holds the final test accuracies (percent)
# for [MNIST, FashionMNIST].
test_acc_list = train()
[ 1/ 25] MNIST Time 69.51804s Training Accuracy: 23.24% Test Accuracy: 25.00%
[ 1/ 25] FashionMNIST Time 0.01361s Training Accuracy: 28.22% Test Accuracy: 25.00%
[ 2/ 25] MNIST Time 0.01258s Training Accuracy: 47.75% Test Accuracy: 31.25%
[ 2/ 25] FashionMNIST Time 0.01260s Training Accuracy: 51.17% Test Accuracy: 37.50%
[ 3/ 25] MNIST Time 0.01241s Training Accuracy: 59.57% Test Accuracy: 59.38%
[ 3/ 25] FashionMNIST Time 0.04608s Training Accuracy: 62.40% Test Accuracy: 62.50%
[ 4/ 25] MNIST Time 0.01120s Training Accuracy: 64.84% Test Accuracy: 43.75%
[ 4/ 25] FashionMNIST Time 0.01117s Training Accuracy: 65.14% Test Accuracy: 50.00%
[ 5/ 25] MNIST Time 0.01115s Training Accuracy: 76.56% Test Accuracy: 56.25%
[ 5/ 25] FashionMNIST Time 0.01113s Training Accuracy: 73.63% Test Accuracy: 65.62%
[ 6/ 25] MNIST Time 0.01244s Training Accuracy: 78.32% Test Accuracy: 62.50%
[ 6/ 25] FashionMNIST Time 0.01249s Training Accuracy: 74.02% Test Accuracy: 68.75%
[ 7/ 25] MNIST Time 0.01266s Training Accuracy: 83.89% Test Accuracy: 65.62%
[ 7/ 25] FashionMNIST Time 0.01240s Training Accuracy: 78.03% Test Accuracy: 62.50%
[ 8/ 25] MNIST Time 0.01271s Training Accuracy: 86.33% Test Accuracy: 65.62%
[ 8/ 25] FashionMNIST Time 0.01274s Training Accuracy: 80.66% Test Accuracy: 71.88%
[ 9/ 25] MNIST Time 0.01486s Training Accuracy: 88.18% Test Accuracy: 65.62%
[ 9/ 25] FashionMNIST Time 0.01454s Training Accuracy: 81.54% Test Accuracy: 68.75%
[ 10/ 25] MNIST Time 0.01112s Training Accuracy: 91.99% Test Accuracy: 68.75%
[ 10/ 25] FashionMNIST Time 0.01099s Training Accuracy: 83.50% Test Accuracy: 68.75%
[ 11/ 25] MNIST Time 0.01093s Training Accuracy: 91.80% Test Accuracy: 68.75%
[ 11/ 25] FashionMNIST Time 0.01093s Training Accuracy: 84.47% Test Accuracy: 71.88%
[ 12/ 25] MNIST Time 0.01097s Training Accuracy: 95.41% Test Accuracy: 62.50%
[ 12/ 25] FashionMNIST Time 0.01096s Training Accuracy: 87.30% Test Accuracy: 65.62%
[ 13/ 25] MNIST Time 0.01091s Training Accuracy: 95.90% Test Accuracy: 62.50%
[ 13/ 25] FashionMNIST Time 0.01094s Training Accuracy: 87.60% Test Accuracy: 78.12%
[ 14/ 25] MNIST Time 0.01211s Training Accuracy: 96.29% Test Accuracy: 62.50%
[ 14/ 25] FashionMNIST Time 0.01113s Training Accuracy: 88.57% Test Accuracy: 75.00%
[ 15/ 25] MNIST Time 0.01075s Training Accuracy: 97.36% Test Accuracy: 65.62%
[ 15/ 25] FashionMNIST Time 0.01072s Training Accuracy: 88.09% Test Accuracy: 62.50%
[ 16/ 25] MNIST Time 0.01072s Training Accuracy: 97.85% Test Accuracy: 62.50%
[ 16/ 25] FashionMNIST Time 0.01187s Training Accuracy: 90.04% Test Accuracy: 68.75%
[ 17/ 25] MNIST Time 0.01168s Training Accuracy: 97.36% Test Accuracy: 65.62%
[ 17/ 25] FashionMNIST Time 0.01150s Training Accuracy: 88.09% Test Accuracy: 68.75%
[ 18/ 25] MNIST Time 0.01068s Training Accuracy: 97.56% Test Accuracy: 62.50%
[ 18/ 25] FashionMNIST Time 0.01184s Training Accuracy: 87.70% Test Accuracy: 59.38%
[ 19/ 25] MNIST Time 0.01088s Training Accuracy: 98.73% Test Accuracy: 65.62%
[ 19/ 25] FashionMNIST Time 0.02154s Training Accuracy: 88.67% Test Accuracy: 65.62%
[ 20/ 25] MNIST Time 0.01196s Training Accuracy: 98.93% Test Accuracy: 65.62%
[ 20/ 25] FashionMNIST Time 0.01073s Training Accuracy: 91.99% Test Accuracy: 71.88%
[ 21/ 25] MNIST Time 0.01076s Training Accuracy: 99.02% Test Accuracy: 62.50%
[ 21/ 25] FashionMNIST Time 0.01153s Training Accuracy: 90.43% Test Accuracy: 68.75%
[ 22/ 25] MNIST Time 0.01068s Training Accuracy: 98.93% Test Accuracy: 65.62%
[ 22/ 25] FashionMNIST Time 0.01180s Training Accuracy: 89.16% Test Accuracy: 59.38%
[ 23/ 25] MNIST Time 0.01151s Training Accuracy: 99.02% Test Accuracy: 65.62%
[ 23/ 25] FashionMNIST Time 0.01228s Training Accuracy: 91.70% Test Accuracy: 68.75%
[ 24/ 25] MNIST Time 0.01132s Training Accuracy: 98.44% Test Accuracy: 62.50%
[ 24/ 25] FashionMNIST Time 0.01064s Training Accuracy: 85.45% Test Accuracy: 68.75%
[ 25/ 25] MNIST Time 0.01101s Training Accuracy: 97.75% Test Accuracy: 62.50%
[ 25/ 25] FashionMNIST Time 0.01145s Training Accuracy: 81.45% Test Accuracy: 62.50%
[FINAL] MNIST Training Accuracy: 98.24% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 81.45% Test Accuracy: 62.50%
Appendix
julia
# Record the environment this page was generated in: Julia version info, plus GPU
# toolchain details when a functional GPU backend is loaded.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    # CUDA details, only if CUDA is loaded and MLDataDevices reports a working device.
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    # Same check for AMD GPUs via AMDGPU.jl.
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
Preferences:
- CUDA_Driver_jll.compat: false
2 devices:
0: Quadro RTX 5000 (sm_75, 12.451 GiB / 16.000 GiB available)
1: Quadro RTX 5000 (sm_75, 15.556 GiB / 16.000 GiB available)
This page was generated using Literate.jl.