Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxAMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays,
Optimisers, Printf, Random, Setfield, Statistics, Zygote
CUDA.allowscalar(false)
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair for the dataset type `dset`.

The first `n_train` samples of `dset(:train)` and the first `n_eval` samples of
`dset(:test)` are reshaped into `28 × 28 × 1 × n` image arrays. Their labels are one-hot
encoded over the classes `0:9`. The training loader shuffles and the test loader does not.
Each loader's batch size is capped at the number of samples in its split.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Slice the first `n` samples of `split` and wrap them in a DataLoader.
    function make_loader(split::Symbol, n::Int, shuffle::Bool)
        features, targets = dset(split)[1:n]
        images = reshape(features, 28, 28, 1, n)
        onehot = onehotbatch(targets, 0:9)
        return DataLoader((images, onehot); batchsize=min(batchsize, n), shuffle=shuffle)
    end

    train_loader = make_loader(:train, n_train, true)
    test_loader = make_loader(:test, n_eval, false)
    return (train_loader, test_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return a 2-tuple with the `(train_loader, test_loader)` pairs for MNIST (first) and
FashionMNIST (second), each built by `load_dataset` with the same sizes.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    datasets = (MNIST, FashionMNIST)
    return map(d -> load_dataset(d, n_train, n_eval, batchsize), datasets)
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a `@compact` layer whose `core_network` parameters are not stored directly. Instead
they are produced on every forward pass from the output of `weight_generator`.

The layer is called with a tuple `(x, y)`:
- `x` is fed to `weight_generator`. In this tutorial it is the dataset index (see `loss`).
- `y` is the input to `core_network`.

The flattened output of `weight_generator(x)` is expected to have exactly
`Lux.parameterlength(core_network)` entries, because it is reinterpreted with the axes of
`core_network`'s parameters.
"""
function HyperNet(weight_generator::Lux.AbstractExplicitLayer,
core_network::Lux.AbstractExplicitLayer)
# Capture only the *structure* (axes) of core_network's parameters, so that a flat vector
# can later be viewed as a structured parameter set. The initial values are discarded.
# NOTE(review): this initialization draws from the global `Random.default_rng()`.
ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
ComponentArray |>
getaxes
# `dispatch=:HyperNet` makes the layer a `CompactLuxLayer{:HyperNet}`, which is what
# allows `Lux.initialparameters` to be specialized for it further below.
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
# Generate the weights
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
# NOTE(review): presumably evaluates core_network with the generated `ps_new` in place of
# stored parameters and returns its output via `@return` — confirm against Lux `@compact` docs.
@return core_network(y, ps_new)
end
end
HyperNet (generic function with 1 method)
Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we simply override `Lux.initialparameters` so that it skips initializing the `core_network`
parameters, since those are generated by the `weight_generator` on every forward pass.
julia
# Parameter initialization for layers built by `HyperNet`: only the weight generator gets
# trainable parameters. The `core_network` entry is deliberately left out because its
# weights are produced by the generator at call time.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (weight_generator=generator_ps,)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Build the tutorial's HyperNet. The core network is a `784 → 256 → 10` MLP applied to
flattened images. Its weights are generated by a small network that starts from an
`Embedding(2 => 32)` lookup and ends in a `Dense` layer sized to the core network's
parameter count.
"""
function create_model()
    # The core network does not have to be an MLP; any Lux layer works here.
    core = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core)
    generator = Chain(Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))
    return HyperNet(generator, core)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
# Cross-entropy between the targets `y` and the raw logits `y_pred`. Summing over dims=1
# gives one value per column (sample), and those values are averaged.
function logitcrossentropy(y_pred, y)
    per_sample = sum(y .* logsoftmax(y_pred); dims=1)
    return mean(-per_sample)
end
# Training objective in the `(loss, state, stats)` form that `compute_gradients` expects.
# `batch` is `(data_idx, x, y)`: the dataset index and inputs feed the HyperNet, and `y`
# holds the targets. No extra statistics are reported.
function loss(model, ps, st, batch)
    data_idx, x, y = batch
    y_pred, st_new = model((data_idx, x), ps, st)
    return logitcrossentropy(y_pred, y), st_new, NamedTuple()
end
"""
    accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())

Return the fraction of samples in `dataloader` whose predicted class (the `onecold` of the
model output) matches the `onecold` of the label.

The model runs in test mode with inputs moved to `gdev`. Predictions are brought back to
the host for comparison. An empty `dataloader` gives `0 / 0`, i.e. `NaN`.
"""
function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())
    st = Lux.testmode(st)
    cpu_dev = cpu_device()
    total_correct, total = 0, 0
    for (x, y) in dataloader
        # Only the model input has to live on `gdev`. The labels are decoded on the host
        # directly rather than being copied to the device and straight back.
        target_class = onecold(cpu_dev(y))
        y_pred = first(model((data_idx, gdev(x)), ps, st))
        predicted_class = onecold(cpu_dev(y_pred))
        total_correct += count(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 2 methods)
Training
julia
# Display name for the dataset selected by `data_idx` (1 => MNIST, otherwise FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    _evaluate_accuracies(model, train_state, train_dataloader, test_dataloader, data_idx, dev)

Return `(train_acc, test_acc)`: the `accuracy` of `model` with the parameters and states in
`train_state` on each dataloader, as a percentage rounded to two decimal places. The
training set is evaluated first.
"""
function _evaluate_accuracies(model, train_state, train_dataloader, test_dataloader,
        data_idx, dev)
    ps, st = train_state.parameters, train_state.states
    as_percent(dataloader) = round(
        accuracy(model, ps, st, dataloader, data_idx, dev) * 100; digits=2)
    return as_percent(train_dataloader), as_percent(test_dataloader)
end

"""
    train()

Train the HyperNet from `create_model` on both MNIST and FashionMNIST. Each of the 10
epochs does one pass over MNIST followed by one pass over FashionMNIST. Per-pass timing and
accuracies are printed, followed by a final accuracy report for both datasets. Returns
`nothing`.
"""
function train()
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    rng = Xoshiro(0)
    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(3.0f-4); transform_variables=dev)

    ### Let's train the model
    nepochs = 10
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        stime = time()
        for (x, y) in train_dataloader
            x = x |> dev
            y = y |> dev
            (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                AutoZygote(), loss, (data_idx, x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
        end
        ttime = time() - stime
        train_acc, test_acc = _evaluate_accuracies(
            model, train_state, train_dataloader, test_dataloader, data_idx, dev)
        data_name = _dataset_name(data_idx)
        @printf "[%3d/%3d] \t %12s \t Time %.5fs \t Training Accuracy: %.2f%% \t Test Accuracy: %.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()
    # Re-evaluate both datasets with the final parameters. MNIST numbers can differ from its
    # last per-epoch report because a FashionMNIST pass ran after it.
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        train_acc, test_acc = _evaluate_accuracies(
            model, train_state, train_dataloader, test_dataloader, data_idx, dev)
        data_name = _dataset_name(data_idx)
        @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: %.2f%%\n" data_name train_acc test_acc
    end
    return nothing
end

train()
[ 1/ 10] MNIST Time 68.31514s Training Accuracy: 76.27% Test Accuracy: 78.12%
[ 1/ 10] FashionMNIST Time 0.02705s Training Accuracy: 46.78% Test Accuracy: 50.00%
[ 2/ 10] MNIST Time 0.02770s Training Accuracy: 74.90% Test Accuracy: 65.62%
[ 2/ 10] FashionMNIST Time 0.02785s Training Accuracy: 55.18% Test Accuracy: 59.38%
[ 3/ 10] MNIST Time 0.02701s Training Accuracy: 84.86% Test Accuracy: 84.38%
[ 3/ 10] FashionMNIST Time 0.02642s Training Accuracy: 62.50% Test Accuracy: 53.12%
[ 4/ 10] MNIST Time 0.02796s Training Accuracy: 82.52% Test Accuracy: 81.25%
[ 4/ 10] FashionMNIST Time 0.02666s Training Accuracy: 64.65% Test Accuracy: 59.38%
[ 5/ 10] MNIST Time 0.02572s Training Accuracy: 84.57% Test Accuracy: 90.62%
[ 5/ 10] FashionMNIST Time 0.02200s Training Accuracy: 68.55% Test Accuracy: 53.12%
[ 6/ 10] MNIST Time 0.02240s Training Accuracy: 91.89% Test Accuracy: 90.62%
[ 6/ 10] FashionMNIST Time 0.04166s Training Accuracy: 68.07% Test Accuracy: 68.75%
[ 7/ 10] MNIST Time 0.02017s Training Accuracy: 92.87% Test Accuracy: 96.88%
[ 7/ 10] FashionMNIST Time 0.02147s Training Accuracy: 72.07% Test Accuracy: 71.88%
[ 8/ 10] MNIST Time 0.02108s Training Accuracy: 94.82% Test Accuracy: 96.88%
[ 8/ 10] FashionMNIST Time 0.02077s Training Accuracy: 72.07% Test Accuracy: 62.50%
[ 9/ 10] MNIST Time 0.02023s Training Accuracy: 94.43% Test Accuracy: 93.75%
[ 9/ 10] FashionMNIST Time 0.02067s Training Accuracy: 70.02% Test Accuracy: 59.38%
[ 10/ 10] MNIST Time 0.02009s Training Accuracy: 95.61% Test Accuracy: 96.88%
[ 10/ 10] FashionMNIST Time 0.02038s Training Accuracy: 77.64% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 91.11% Test Accuracy: 87.50%
[FINAL] FashionMNIST Training Accuracy: 77.64% Test Accuracy: 75.00%
Appendix
julia
using InteractiveUtils
# Record the Julia/system configuration that produced the outputs above.
InteractiveUtils.versioninfo()
# Report CUDA details only when LuxCUDA is loaded and CUDA reports itself functional.
if @isdefined(LuxCUDA) && CUDA.functional()
println()
CUDA.versioninfo()
end
# Same for AMD GPUs. In the run shown below, `LuxAMDGPU.functional()` logs a warning
# because AMDGPU is not functional on this machine.
# NOTE(review): `AMDGPU` is presumably brought into scope by `using LuxAMDGPU` — confirm.
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
println()
AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_AMDGPU_LOGGING_ENABLED = true
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 8
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15
CUDA libraries:
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15
Julia packages:
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0
Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 1.732 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19
This page was generated using Literate.jl.