Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
CUDA.allowscalar(false)

Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the dataset type `dset`.

The first `n_train` samples of the `:train` split and the first `n_eval` samples of the
`:test` split are reshaped to `28 × 28 × 1 × n` and their labels are one-hot encoded over
the classes `0:9`. The training loader shuffles, the test loader does not. Each loader's
batch size is capped at the number of samples it holds.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Shared construction for both splits; only the split, sample count and shuffling differ.
    function make_loader(split::Symbol, n::Int, shuffle::Bool)
        features, targets = dset(split)[1:n]
        x = reshape(features, 28, 28, 1, n)
        y = onehotbatch(targets, 0:9)
        return DataLoader((x, y); batchsize=min(batchsize, n), shuffle=shuffle)
    end

    train_loader = make_loader(:train, n_train, true)
    test_loader = make_loader(:test, n_eval, false)
    return (train_loader, test_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return the `(train_loader, test_loader)` pairs for MNIST and FashionMNIST, in that order,
each built by `load_dataset` with the same sample counts and batch size.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    datasets = (MNIST, FashionMNIST)
    # Broadcasting over the tuple of dataset types yields one loader pair per dataset.
    return broadcast(load_dataset, datasets, n_train, n_eval, batchsize)
end

load_datasets (generic function with 4 methods)

Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a compact Lux layer that takes a tuple `(x, y)`: `weight_generator(x)` is flattened
and wrapped as the parameters of `core_network`, which is then applied to `y`.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # The core network's parameters are instantiated once here purely to record their
    # ComponentArray axes, i.e. the layout used to interpret the generated flat vector.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))

    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights and give the flat vector the core network's layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end

HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
# Parameter initialisation for the `:HyperNet` compact layer: only the weight generator's
# parameters are created and returned; none are created for the core network here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the HyperNet used in this tutorial: a dense classifier as the core network and
an embedding-based generator whose output length equals the core network's parameter count.
"""
function create_model()
    # The core network does not have to be an MLP; any Lux layer can be used.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))

    # The generator must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))

    return HyperNet(weight_generator, core_network)
end

create_model (generic function with 1 method)

Define Utility Functions
julia
# Training objective handed to `Training.single_train_step!` in `train`.
# NOTE(review): `logits=Val(true)` presumably means the model output is treated as
# unnormalised logits rather than probabilities — confirm against the Lux loss docs.
const loss = CrossEntropyLoss(; logits=Val(true))
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which the class predicted by `model`
matches the target class. The model is called with `(data_idx, x)` as input and with `st`
switched to test mode.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    st_eval = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        # The model returns `(output, state)`; only the output is needed here.
        y_pred, _ = model((data_idx, x), ps, st_eval)
        truth = onecold(y)
        guess = onecold(y_pred)
        n_correct += sum(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end

accuracy (generic function with 1 method)

Training
julia
# Accuracy of the current training state on `dataloader`, as a percentage rounded to two
# decimal places.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train(; nepochs=25, learning_rate=3.0f-4)

Train the HyperNet from `create_model` on the datasets returned by `load_datasets`,
alternating between MNIST and FashionMNIST within every epoch.

After each per-dataset pass the elapsed time and the training/test accuracies are printed;
a final summary line per dataset is printed once training is done.

# Keywords
- `nepochs`: number of epochs; each epoch performs one pass over every dataset.
- `learning_rate`: step size passed to the `Adam` optimiser.

# Returns
A vector of final test accuracies in percent, indexed by dataset
(`1` = MNIST, `2` = FashionMNIST).
"""
function train(; nepochs::Integer=25, learning_rate=3.0f-4)
    model = create_model()
    dataloaders = load_datasets()
    # Display names, in the same order as the loader pairs returned by `load_datasets`.
    data_names = ("MNIST", "FashionMNIST")

    dev = gpu_device()
    rng = Xoshiro(0)

    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))

    ### Lets train the model
    for epoch in 1:nepochs, data_idx in eachindex(data_names)
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        # Wall-clock time of one pass over this dataset's training loader.
        stime = time()
        for (x, y) in train_dataloader
            # `data_idx` is part of the model input so the weight generator can produce
            # dataset-specific parameters for the core network.
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)
        data_name = data_names[data_idx]

        @printf "[%3d/%3d] \t %12s \t Time %.5fs \t Training Accuracy: %.2f%% \t Test \
                 Accuracy: %.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of the fully trained model on every dataset.
    test_acc_list = zeros(Float64, length(data_names))
    for data_idx in eachindex(data_names)
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        train_acc = _accuracy_percent(model, train_state, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, train_state, test_dataloader, data_idx)
        data_name = data_names[data_idx]

        @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()

[ 1/ 25] MNIST Time 68.75160s Training Accuracy: 23.24% Test Accuracy: 25.00%
[ 1/ 25] FashionMNIST Time 0.02731s Training Accuracy: 30.76% Test Accuracy: 18.75%
[ 2/ 25] MNIST Time 0.02331s Training Accuracy: 48.34% Test Accuracy: 34.38%
[ 2/ 25] FashionMNIST Time 0.02395s Training Accuracy: 52.54% Test Accuracy: 43.75%
[ 3/ 25] MNIST Time 0.02325s Training Accuracy: 59.08% Test Accuracy: 56.25%
[ 3/ 25] FashionMNIST Time 0.02429s Training Accuracy: 54.69% Test Accuracy: 46.88%
[ 4/ 25] MNIST Time 0.02403s Training Accuracy: 67.48% Test Accuracy: 53.12%
[ 4/ 25] FashionMNIST Time 0.02461s Training Accuracy: 64.36% Test Accuracy: 43.75%
[ 5/ 25] MNIST Time 0.02330s Training Accuracy: 74.12% Test Accuracy: 65.62%
[ 5/ 25] FashionMNIST Time 0.02051s Training Accuracy: 70.12% Test Accuracy: 62.50%
[ 6/ 25] MNIST Time 0.02121s Training Accuracy: 78.22% Test Accuracy: 56.25%
[ 6/ 25] FashionMNIST Time 0.02064s Training Accuracy: 74.22% Test Accuracy: 71.88%
[ 7/ 25] MNIST Time 0.02084s Training Accuracy: 82.13% Test Accuracy: 65.62%
[ 7/ 25] FashionMNIST Time 0.02181s Training Accuracy: 78.42% Test Accuracy: 62.50%
[ 8/ 25] MNIST Time 0.02076s Training Accuracy: 86.62% Test Accuracy: 65.62%
[ 8/ 25] FashionMNIST Time 0.02141s Training Accuracy: 81.05% Test Accuracy: 62.50%
[ 9/ 25] MNIST Time 0.02179s Training Accuracy: 86.04% Test Accuracy: 65.62%
[ 9/ 25] FashionMNIST Time 0.03888s Training Accuracy: 81.54% Test Accuracy: 62.50%
[ 10/ 25] MNIST Time 0.02146s Training Accuracy: 90.04% Test Accuracy: 68.75%
[ 10/ 25] FashionMNIST Time 0.02209s Training Accuracy: 81.45% Test Accuracy: 56.25%
[ 11/ 25] MNIST Time 0.02179s Training Accuracy: 93.07% Test Accuracy: 68.75%
[ 11/ 25] FashionMNIST Time 0.02217s Training Accuracy: 82.62% Test Accuracy: 59.38%
[ 12/ 25] MNIST Time 0.02153s Training Accuracy: 92.97% Test Accuracy: 65.62%
[ 12/ 25] FashionMNIST Time 0.02145s Training Accuracy: 82.62% Test Accuracy: 71.88%
[ 13/ 25] MNIST Time 0.02114s Training Accuracy: 94.73% Test Accuracy: 62.50%
[ 13/ 25] FashionMNIST Time 0.02240s Training Accuracy: 83.01% Test Accuracy: 62.50%
[ 14/ 25] MNIST Time 0.02893s Training Accuracy: 96.19% Test Accuracy: 65.62%
[ 14/ 25] FashionMNIST Time 0.02083s Training Accuracy: 87.11% Test Accuracy: 68.75%
[ 15/ 25] MNIST Time 0.02143s Training Accuracy: 94.92% Test Accuracy: 68.75%
[ 15/ 25] FashionMNIST Time 0.02212s Training Accuracy: 83.20% Test Accuracy: 62.50%
[ 16/ 25] MNIST Time 0.02130s Training Accuracy: 97.27% Test Accuracy: 62.50%
[ 16/ 25] FashionMNIST Time 0.02078s Training Accuracy: 84.77% Test Accuracy: 65.62%
[ 17/ 25] MNIST Time 0.02120s Training Accuracy: 98.14% Test Accuracy: 62.50%
[ 17/ 25] FashionMNIST Time 0.02212s Training Accuracy: 87.89% Test Accuracy: 65.62%
[ 18/ 25] MNIST Time 0.02139s Training Accuracy: 98.24% Test Accuracy: 62.50%
[ 18/ 25] FashionMNIST Time 0.02761s Training Accuracy: 87.70% Test Accuracy: 65.62%
[ 19/ 25] MNIST Time 0.02073s Training Accuracy: 98.83% Test Accuracy: 62.50%
[ 19/ 25] FashionMNIST Time 0.02146s Training Accuracy: 87.99% Test Accuracy: 62.50%
[ 20/ 25] MNIST Time 0.02160s Training Accuracy: 99.22% Test Accuracy: 62.50%
[ 20/ 25] FashionMNIST Time 0.02161s Training Accuracy: 89.26% Test Accuracy: 71.88%
[ 21/ 25] MNIST Time 0.02276s Training Accuracy: 99.12% Test Accuracy: 62.50%
[ 21/ 25] FashionMNIST Time 0.02150s Training Accuracy: 88.09% Test Accuracy: 59.38%
[ 22/ 25] MNIST Time 0.02120s Training Accuracy: 97.75% Test Accuracy: 62.50%
[ 22/ 25] FashionMNIST Time 0.02150s Training Accuracy: 90.43% Test Accuracy: 68.75%
[ 23/ 25] MNIST Time 0.02125s Training Accuracy: 99.32% Test Accuracy: 62.50%
[ 23/ 25] FashionMNIST Time 0.02158s Training Accuracy: 88.87% Test Accuracy: 65.62%
[ 24/ 25] MNIST Time 0.02130s Training Accuracy: 99.51% Test Accuracy: 62.50%
[ 24/ 25] FashionMNIST Time 0.02071s Training Accuracy: 90.04% Test Accuracy: 62.50%
[ 25/ 25] MNIST Time 0.02275s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 25/ 25] FashionMNIST Time 0.02149s Training Accuracy: 91.11% Test Accuracy: 65.62%
[FINAL] MNIST Training Accuracy: 99.80% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 91.11% Test Accuracy: 65.62%

Appendix
julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# Backend details are printed only when MLDataDevices and the backend package are both
# defined and the corresponding device is reported as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end

Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6
CUDA libraries:
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6
Julia packages:
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0
Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)

This page was generated using Literate.jl.