Training a Hyper Network
Package Imports
using Lux
using ComponentArrays,
LuxAMDGPU,
LuxCUDA,
MLDatasets,
MLUtils,
OneHotArrays,
Optimisers,
Random,
Setfield,
Statistics,
Zygote
# Disallow scalar indexing of GPU arrays so that any accidental element-by-element
# access to device memory fails loudly instead of silently running very slowly.
CUDA.allowscalar(false)
Activating project at `/var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/examples`
Loading Datasets
"""
    _load_dataset(dset, n_train::Int, n_eval::Int, batchsize::Int)

Load the first `n_train` training samples and the first `n_eval` test samples of the
dataset constructor `dset`.

Images are reshaped to `28 × 28 × 1 × N` and labels are one-hot encoded over `0:9`.
Each split is wrapped in a `DataLoader`: the training loader shuffles, the test loader
does not. Each batch size is capped at the size of its split.

Returns `(train_loader, test_loader)`.
"""
function _load_dataset(dset, n_train::Int, n_eval::Int, batchsize::Int)
    # Turn the first `n` samples of a split into (image tensor, one-hot labels).
    function split_to_arrays(split, n)
        images, targets = dset(split)[1:n]
        return reshape(images, 28, 28, 1, n), onehotbatch(targets, 0:9)
    end

    train_data = split_to_arrays(:train, n_train)
    test_data = split_to_arrays(:test, n_eval)

    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)
    test_loader = DataLoader(test_data; batchsize=min(batchsize, n_eval), shuffle=false)
    return (train_loader, test_loader)
end
"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Build `(train, test)` `DataLoader` pairs for MNIST and FashionMNIST, in that order.
Both datasets use the same sample counts and batch size.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    datasets = (MNIST, FashionMNIST)
    return map(d -> _load_dataset(d, n_train, n_eval, batchsize), datasets)
end
load_datasets (generic function with 4 methods)
Implement a HyperNet Layer
"""
    HyperNet(weight_generator, core_network, ca_axes)

Lux container layer in which `weight_generator` produces the parameters of
`core_network`.

  - `weight_generator`: layer whose (flattened) output is reshaped into the parameters
    of `core_network`.
  - `core_network`: layer evaluated with those generated parameters.
  - `ca_axes`: `ComponentArray` axes of `core_network`'s parameters. They are used to
    reshape the generator's flat output.

Both sub-layers are listed as container fields, so each has an entry in the layer state.
"""
struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <:
# Container fields: Lux builds nested params/state for these two names.
Lux.AbstractExplicitContainerLayer{(:weight_generator, :core_network)}
weight_generator::W
core_network::C
# Axes used to view the generator's flat output as structured core-network parameters.
ca_axes::A
end
"""
    HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer)

Build a `HyperNet` whose weight generator `w` produces the parameters of the core
network `c`.

The `ComponentArray` axes of `c`'s parameters are computed once here. They let the
generator's flat output be viewed as `c`'s structured parameters at call time.
"""
function HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer)
    # Only the *structure* (axes) of `c`'s parameters is needed, never their values.
    # Use a throwaway RNG instead of `Random.default_rng()` so that constructing a
    # layer does not advance the global random stream as a hidden side effect.
    ps_template = Lux.initialparameters(Random.MersenneTwister(0), c)
    ca_axes = getaxes(ComponentArray(ps_template))
    return HyperNet(w, c, ca_axes)
end
# Only the weight generator owns trainable parameters. The core network's parameters
# are produced by the generator on each forward pass rather than stored.
Lux.initialparameters(rng::AbstractRNG, h::HyperNet) =
    (; weight_generator=Lux.initialparameters(rng, h.weight_generator))
"""
    (hn::HyperNet)(x, ps, st::NamedTuple)

Run only the weight generator on `x`. Its output is flattened and viewed as a
`ComponentArray` with `hn.ca_axes`.

Returns `(core_ps, st)`, where `st` carries the generator's updated state.
"""
function (hn::HyperNet)(x, ps, st::NamedTuple)
    gen_out, gen_st = hn.weight_generator(x, ps.weight_generator, st.weight_generator)
    @set! st.weight_generator = gen_st
    core_ps = ComponentArray(vec(gen_out), hn.ca_axes)
    return core_ps, st
end
"""
    (hn::HyperNet)((x, y)::Tuple, ps, st::NamedTuple)

Generate core-network parameters from `x`, then evaluate the core network on `y` with
them.

Returns `(prediction, st)`, with the states of both sub-layers updated.
"""
function (hn::HyperNet)((x, y)::T, ps, st::NamedTuple) where {T <: Tuple}
    # Stage 1: the weight generator turns `x` into core-network parameters.
    generated_ps, st_out = hn(x, ps, st)
    # Stage 2: run the core network on `y` with those generated parameters.
    ŷ, core_st = hn.core_network(y, generated_ps, st_out.core_network)
    @set! st_out.core_network = core_st
    return ŷ, st_out
end
Create and Initialize the HyperNet
"""
    create_model()

Build the example `HyperNet` and its parameters and state.

  - Core network: an MLP classifier for 28×28 images with 10 outputs.
  - Weight generator: maps a dataset index (via a 2-entry embedding) to a flat vector
    holding all of the core network's parameters.

Parameters and state are initialised from the default RNG seeded with `0` and moved to
the GPU device. Returns `(model, ps, st)`.
"""
function create_model()
    # Any Lux layer could serve as the core network; an MLP keeps the example small.
    core = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    n_core_params = Lux.parameterlength(core)
    generator = Chain(Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params))
    model = HyperNet(generator, core)

    # Seed right before `Lux.setup` so the parameter initialisation is reproducible.
    rng = Random.default_rng()
    Random.seed!(rng, 0)
    device = gpu_device()
    ps, st = map(device, Lux.setup(rng, model))
    return model, ps, st
end
create_model (generic function with 1 method)
Define Utility Functions
"""
    logitcrossentropy(y_pred, y)

Mean cross-entropy between logits `y_pred` and one-hot targets `y`. Classes run along
dimension 1 and samples along the remaining dimension(s).
"""
function logitcrossentropy(y_pred, y)
    per_sample_nll = -sum(y .* logsoftmax(y_pred); dims=1)
    return mean(per_sample_nll)
end
"""
    loss(data_idx, x, y, model, ps, st)

Logit cross-entropy of `model`'s prediction on `x` against one-hot targets `y`. The core
network's weights are generated from `data_idx`.

Returns `(loss_value, updated_state)`.
"""
function loss(data_idx, x, y, model, ps, st)
    logits, st_new = model((data_idx, x), ps, st)
    return logitcrossentropy(logits, y), st_new
end
"""
    accuracy(model, ps, st, dataloader, data_idx)

Fraction of samples in `dataloader` whose predicted class matches the one-hot target.
The model uses weights generated for dataset `data_idx`, and `st` is switched to test
mode first.

Inputs are moved to the GPU device for the forward pass. Predictions and targets are
compared on the host via `onecold`.

If `dataloader` yields no samples, the result is `0 / 0`, i.e. `NaN`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    dev = gpu_device()
    cpu_dev = cpu_device()
    for (x, y) in dataloader
        # Targets are only needed on the host for `onecold`. Sending them to the GPU
        # first (as before) just added a host->device->host round trip per batch.
        target_class = onecold(cpu_dev(y))
        ŷ = first(model((data_idx, dev(x)), ps, st))
        predicted_class = onecold(cpu_dev(ŷ))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
Training
# Loss and gradient (w.r.t. `ps`) for one batch.
#
# Kept as a separate function so the closure handed to `pullback` captures only
# arguments that are never reassigned. Inlined in `train`, the closure captured `st`,
# `x` and `y`, all reassigned there, so Julia boxed them and the loop was type-unstable.
function _loss_and_gradient(model, ps, st, data_idx, x, y)
    (l, st_new), back = pullback(p -> loss(data_idx, x, y, model, p, st), ps)
    gs = back((one(l), nothing))[1]
    return l, st_new, gs
end

# Accuracy on `dataloader` as a percentage rounded to two decimals.
function _accuracy_percent(model, ps, st, dataloader, data_idx)
    return round(accuracy(model, ps, st, dataloader, data_idx) * 100; digits=2)
end

# Display name of the dataset selected by `data_idx`.
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train()

Train the example `HyperNet` jointly on MNIST (`data_idx = 1`) and FashionMNIST
(`data_idx = 2`) for 9 epochs using Adam, alternating datasets within each epoch.

After each dataset pass it prints the epoch time plus train/test accuracy. At the end
it prints final accuracies for both datasets.
"""
function train()
    model, ps, st = create_model()

    dataloaders = load_datasets()

    opt = Adam(0.001f0)
    st_opt = Optimisers.setup(opt, ps)

    dev = gpu_device()

    ### Warm up the model: compile forward and backward passes on a single MNIST
    ### training sample so the first epoch's timing is not dominated by compilation.
    img = dev(dataloaders[1][1].data[1][:, :, :, 1:1])
    lab = dev(dataloaders[1][1].data[2][:, 1:1])
    loss(1, img, lab, model, ps, st)
    _loss_and_gradient(model, ps, st, 1, img, lab)

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        for data_idx in 1:2
            train_dataloader, test_dataloader = dataloaders[data_idx]

            stime = time()
            for (x, y) in train_dataloader
                _, st, gs = _loss_and_gradient(model, ps, st, data_idx, dev(x), dev(y))
                st_opt, ps = Optimisers.update(st_opt, ps, gs)
            end
            ttime = time() - stime

            train_acc = _accuracy_percent(model, ps, st, train_dataloader, data_idx)
            test_acc = _accuracy_percent(model, ps, st, test_dataloader, data_idx)
            data_name = _dataset_name(data_idx)

            println("[$epoch/$nepochs] \t $data_name Time $(round(ttime; digits=2))s \t " *
                    "Training Accuracy: $(train_acc)% \t Test Accuracy: $(test_acc)%")
        end
    end

    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        train_acc = _accuracy_percent(model, ps, st, train_dataloader, data_idx)
        test_acc = _accuracy_percent(model, ps, st, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)
        println("[FINAL] \t $data_name Training Accuracy: $(train_acc)% \t " *
                "Test Accuracy: $(test_acc)%")
    end
    return nothing
end
# Run the full training loop; prints per-epoch and final train/test accuracies for
# MNIST and FashionMNIST.
train()
[1/9] MNIST Time 5.35s Training Accuracy: 54.98% Test Accuracy: 50.0%
[1/9] FashionMNIST Time 0.04s Training Accuracy: 57.03% Test Accuracy: 46.88%
[2/9] MNIST Time 0.04s Training Accuracy: 68.55% Test Accuracy: 75.0%
[2/9] FashionMNIST Time 0.04s Training Accuracy: 48.63% Test Accuracy: 46.88%
[3/9] MNIST Time 0.04s Training Accuracy: 78.52% Test Accuracy: 84.38%
[3/9] FashionMNIST Time 0.09s Training Accuracy: 60.94% Test Accuracy: 56.25%
[4/9] MNIST Time 0.03s Training Accuracy: 85.64% Test Accuracy: 84.38%
[4/9] FashionMNIST Time 0.03s Training Accuracy: 70.61% Test Accuracy: 68.75%
[5/9] MNIST Time 0.03s Training Accuracy: 86.43% Test Accuracy: 81.25%
[5/9] FashionMNIST Time 0.07s Training Accuracy: 69.24% Test Accuracy: 59.38%
[6/9] MNIST Time 0.03s Training Accuracy: 90.14% Test Accuracy: 78.12%
[6/9] FashionMNIST Time 0.03s Training Accuracy: 72.95% Test Accuracy: 65.62%
[7/9] MNIST Time 0.03s Training Accuracy: 92.29% Test Accuracy: 87.5%
[7/9] FashionMNIST Time 0.03s Training Accuracy: 75.49% Test Accuracy: 68.75%
[8/9] MNIST Time 0.06s Training Accuracy: 94.53% Test Accuracy: 90.62%
[8/9] FashionMNIST Time 0.03s Training Accuracy: 76.56% Test Accuracy: 68.75%
[9/9] MNIST Time 0.03s Training Accuracy: 96.39% Test Accuracy: 90.62%
[9/9] FashionMNIST Time 0.03s Training Accuracy: 74.51% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 96.19% Test Accuracy: 93.75%
[FINAL] FashionMNIST Training Accuracy: 74.51% Test Accuracy: 68.75%
This page was generated using Literate.jl.