Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)
Precompiling MLDatasets...
429.4 ms ✓ Glob
382.7 ms ✓ WorkerUtilities
410.5 ms ✓ BufferedStreams
413.8 ms ✓ PaddedViews
327.5 ms ✓ SimpleBufferStream
583.5 ms ✓ URIs
295.6 ms ✓ PackageExtensionCompat
338.0 ms ✓ BitFlags
394.0 ms ✓ StackViews
660.4 ms ✓ GZip
695.0 ms ✓ ConcurrentUtilities
595.9 ms ✓ ZipFile
779.3 ms ✓ StructTypes
1008.3 ms ✓ MbedTLS
555.4 ms ✓ MPIPreferences
806.7 ms ✓ Accessors → AccessorsUnitfulExt
314.3 ms ✓ InternedStrings
2800.8 ms ✓ UnitfulAtomic
2365.1 ms ✓ PeriodicTable
469.4 ms ✓ ExceptionUnwrapping
554.4 ms ✓ Chemfiles_jll
440.5 ms ✓ MicrosoftMPI_jll
570.9 ms ✓ libaec_jll
507.5 ms ✓ StringEncodings
764.3 ms ✓ WeakRefStrings
398.5 ms ✓ StridedViews
436.3 ms ✓ MosaicViews
1839.1 ms ✓ OpenSSL
1548.2 ms ✓ NPZ
1375.4 ms ✓ MPICH_jll
1152.4 ms ✓ MPItrampoline_jll
1161.8 ms ✓ OpenMPI_jll
2230.4 ms ✓ AtomsBase
11048.5 ms ✓ JSON3
2354.9 ms ✓ Pickle
19583.5 ms ✓ CSV
34082.6 ms ✓ JLD2
18873.7 ms ✓ ImageCore
1833.7 ms ✓ HDF5_jll
2319.4 ms ✓ Chemfiles
2108.2 ms ✓ ImageBase
1933.8 ms ✓ ImageShow
7392.4 ms ✓ HDF5
2362.4 ms ✓ MAT
19226.6 ms ✓ HTTP
1869.5 ms ✓ FileIO → HTTPExt
3016.3 ms ✓ DataDeps
8847.3 ms ✓ MLDatasets
48 dependencies successfully precompiled in 65 seconds. 150 already precompiled.
Loading Datasets
julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    if n_train === nothing
        imgs, labels = dset(:train).features, dset(:train).targets
    else
        imgs, labels = dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    if n_eval === nothing
        imgs, labels = dset(:test).features, dset(:test).targets
    else
        imgs, labels = dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end
function load_datasets(batchsize=256)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
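The returned objects are MLUtils `DataLoader`s. As a quick sanity check (a hypothetical snippet, not part of the original run), we can load the datasets and inspect the shape of a single MNIST batch; this assumes the datasets have already been downloaded via DataDeps.

julia
(mnist_loaders, fashion_loaders) = load_datasets()
mnist_train, mnist_test = mnist_loaders

x, y = first(mnist_train)
size(x)   # (28, 28, 1, B) where B is the (possibly capped) batch size
size(y)   # (10, B) — one-hot encoded labels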
Implement a HyperNet Layer
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
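The key trick here is the ComponentArrays round-trip: `getaxes` records the structure of the core network's parameters, and the flat vector produced by the weight generator is re-wrapped with those axes to recover named, correctly shaped parameters. Below is a minimal standalone sketch of that mechanism (illustrative only, with a small `Dense` layer standing in for the core network).

julia
using ComponentArrays, Lux, Random

layer = Dense(3, 2)
ps = Lux.initialparameters(Random.default_rng(), layer)
ax = getaxes(ComponentArray(ps))                    # remembers the (weight, bias) layout

flat = randn(Float32, Lux.parameterlength(layer))   # pretend this came from a generator network
ps_new = ComponentArray(flat, ax)

ps_new.weight, ps_new.bias                          # the flat vector viewed as a 2×3 weight and a bias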
Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend doing this unless you are familiar with the internals. In this case, we simply override initialparameters so that the core_network parameters are never initialized: they are generated by the weight generator at run time.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
Create and Initialize the HyperNet
julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
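Before training, it can be useful to confirm that the override from the previous section is active and that a forward pass works. The following is a hypothetical CPU smoke test (not part of the original run); `model((data_idx, x), ps, st)` mirrors how the model is called during training.

julia
model = create_model()
ps, st = Lux.setup(Xoshiro(0), model)

keys(ps)                          # (:weight_generator,) — no core_network entry

x = randn(Float32, 28, 28, 1, 4)  # a dummy batch of 4 "images"
y, _ = model((1, x), ps, st)      # data_idx = 1 selects the MNIST embedding
size(y)                           # (10, 4): one logit vector per image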
Define Utility Functions
julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
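The `CrossEntropyLoss` object is passed directly to `Training.single_train_step!` below, but it can also be called on logits and one-hot targets for a quick manual check. A hypothetical snippet, reusing `model`, `ps`, and `st` from the smoke test above:

julia
xb = randn(Float32, 28, 28, 1, 4)
yb = onehotbatch(rand(0:9, 4), 0:9)

ŷ, _ = model((1, xb), ps, st)
loss(ŷ, yb)   # scalar cross-entropy; logits=Val(true) means ŷ is treated as unnormalized logits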
Training
julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 86.41144s Training Accuracy: 54.00% Test Accuracy: 53.12%
[ 1/ 50] FashionMNIST Time 0.02907s Training Accuracy: 36.04% Test Accuracy: 31.25%
[ 2/ 50] MNIST Time 0.02835s Training Accuracy: 64.94% Test Accuracy: 62.50%
[ 2/ 50] FashionMNIST Time 0.02838s Training Accuracy: 51.76% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.05963s Training Accuracy: 73.24% Test Accuracy: 53.12%
[ 3/ 50] FashionMNIST Time 0.02768s Training Accuracy: 59.08% Test Accuracy: 50.00%
[ 4/ 50] MNIST Time 0.02138s Training Accuracy: 75.49% Test Accuracy: 59.38%
[ 4/ 50] FashionMNIST Time 0.02150s Training Accuracy: 68.65% Test Accuracy: 50.00%
[ 5/ 50] MNIST Time 0.02270s Training Accuracy: 81.64% Test Accuracy: 62.50%
[ 5/ 50] FashionMNIST Time 0.09385s Training Accuracy: 67.09% Test Accuracy: 65.62%
[ 6/ 50] MNIST Time 0.02166s Training Accuracy: 85.64% Test Accuracy: 68.75%
[ 6/ 50] FashionMNIST Time 0.02131s Training Accuracy: 71.78% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.02086s Training Accuracy: 87.60% Test Accuracy: 71.88%
[ 7/ 50] FashionMNIST Time 0.02124s Training Accuracy: 76.37% Test Accuracy: 75.00%
[ 8/ 50] MNIST Time 0.02026s Training Accuracy: 89.94% Test Accuracy: 75.00%
[ 8/ 50] FashionMNIST Time 0.02188s Training Accuracy: 80.96% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.02065s Training Accuracy: 91.70% Test Accuracy: 71.88%
[ 9/ 50] FashionMNIST Time 0.02066s Training Accuracy: 78.52% Test Accuracy: 71.88%
[ 10/ 50] MNIST Time 0.02248s Training Accuracy: 94.92% Test Accuracy: 71.88%
[ 10/ 50] FashionMNIST Time 0.02141s Training Accuracy: 80.76% Test Accuracy: 75.00%
[ 11/ 50] MNIST Time 0.02352s Training Accuracy: 95.51% Test Accuracy: 75.00%
[ 11/ 50] FashionMNIST Time 0.03260s Training Accuracy: 81.84% Test Accuracy: 68.75%
[ 12/ 50] MNIST Time 0.02039s Training Accuracy: 95.90% Test Accuracy: 75.00%
[ 12/ 50] FashionMNIST Time 0.02034s Training Accuracy: 82.23% Test Accuracy: 78.12%
[ 13/ 50] MNIST Time 0.02214s Training Accuracy: 97.66% Test Accuracy: 75.00%
[ 13/ 50] FashionMNIST Time 0.02058s Training Accuracy: 82.03% Test Accuracy: 78.12%
[ 14/ 50] MNIST Time 0.02051s Training Accuracy: 98.05% Test Accuracy: 78.12%
[ 14/ 50] FashionMNIST Time 0.02128s Training Accuracy: 83.50% Test Accuracy: 84.38%
[ 15/ 50] MNIST Time 0.02040s Training Accuracy: 98.93% Test Accuracy: 81.25%
[ 15/ 50] FashionMNIST Time 0.02037s Training Accuracy: 85.06% Test Accuracy: 81.25%
[ 16/ 50] MNIST Time 0.02180s Training Accuracy: 99.32% Test Accuracy: 81.25%
[ 16/ 50] FashionMNIST Time 0.02008s Training Accuracy: 85.55% Test Accuracy: 75.00%
[ 17/ 50] MNIST Time 0.02034s Training Accuracy: 99.61% Test Accuracy: 81.25%
[ 17/ 50] FashionMNIST Time 0.02047s Training Accuracy: 85.94% Test Accuracy: 78.12%
[ 18/ 50] MNIST Time 0.02575s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 18/ 50] FashionMNIST Time 0.02010s Training Accuracy: 86.62% Test Accuracy: 78.12%
[ 19/ 50] MNIST Time 0.02034s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 19/ 50] FashionMNIST Time 0.02190s Training Accuracy: 87.60% Test Accuracy: 78.12%
[ 20/ 50] MNIST Time 0.02036s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 20/ 50] FashionMNIST Time 0.02037s Training Accuracy: 88.57% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.02174s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 21/ 50] FashionMNIST Time 0.02068s Training Accuracy: 89.84% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.02085s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 22/ 50] FashionMNIST Time 0.02283s Training Accuracy: 89.75% Test Accuracy: 78.12%
[ 23/ 50] MNIST Time 0.02082s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 23/ 50] FashionMNIST Time 0.02086s Training Accuracy: 90.53% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.02338s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 24/ 50] FashionMNIST Time 0.02136s Training Accuracy: 90.92% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.02099s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 25/ 50] FashionMNIST Time 0.02270s Training Accuracy: 90.62% Test Accuracy: 78.12%
[ 26/ 50] MNIST Time 0.02105s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 26/ 50] FashionMNIST Time 0.02280s Training Accuracy: 90.14% Test Accuracy: 78.12%
[ 27/ 50] MNIST Time 0.02436s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 27/ 50] FashionMNIST Time 0.02136s Training Accuracy: 91.02% Test Accuracy: 78.12%
[ 28/ 50] MNIST Time 0.02088s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 28/ 50] FashionMNIST Time 0.02328s Training Accuracy: 90.62% Test Accuracy: 78.12%
[ 29/ 50] MNIST Time 0.02103s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 29/ 50] FashionMNIST Time 0.02093s Training Accuracy: 91.50% Test Accuracy: 78.12%
[ 30/ 50] MNIST Time 0.02463s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 30/ 50] FashionMNIST Time 0.02127s Training Accuracy: 91.60% Test Accuracy: 78.12%
[ 31/ 50] MNIST Time 0.02127s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 31/ 50] FashionMNIST Time 0.02264s Training Accuracy: 92.38% Test Accuracy: 75.00%
[ 32/ 50] MNIST Time 0.02173s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 32/ 50] FashionMNIST Time 0.02072s Training Accuracy: 92.58% Test Accuracy: 78.12%
[ 33/ 50] MNIST Time 0.02226s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 33/ 50] FashionMNIST Time 0.02081s Training Accuracy: 93.07% Test Accuracy: 78.12%
[ 34/ 50] MNIST Time 0.02189s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 34/ 50] FashionMNIST Time 0.02241s Training Accuracy: 93.36% Test Accuracy: 78.12%
[ 35/ 50] MNIST Time 0.02044s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 35/ 50] FashionMNIST Time 0.02066s Training Accuracy: 93.36% Test Accuracy: 78.12%
[ 36/ 50] MNIST Time 0.02227s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 36/ 50] FashionMNIST Time 0.02411s Training Accuracy: 93.95% Test Accuracy: 75.00%
[ 37/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 37/ 50] FashionMNIST Time 0.02723s Training Accuracy: 94.43% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.02085s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 38/ 50] FashionMNIST Time 0.02177s Training Accuracy: 94.34% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.02244s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 39/ 50] FashionMNIST Time 0.02153s Training Accuracy: 94.73% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.02167s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 40/ 50] FashionMNIST Time 0.02313s Training Accuracy: 95.02% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.02104s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 41/ 50] FashionMNIST Time 0.02070s Training Accuracy: 95.21% Test Accuracy: 78.12%
[ 42/ 50] MNIST Time 0.02538s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 42/ 50] FashionMNIST Time 0.02063s Training Accuracy: 95.41% Test Accuracy: 78.12%
[ 43/ 50] MNIST Time 0.02206s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 43/ 50] FashionMNIST Time 0.02219s Training Accuracy: 94.73% Test Accuracy: 78.12%
[ 44/ 50] MNIST Time 0.02107s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 44/ 50] FashionMNIST Time 0.02247s Training Accuracy: 95.31% Test Accuracy: 78.12%
[ 45/ 50] MNIST Time 0.02237s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 45/ 50] FashionMNIST Time 0.02071s Training Accuracy: 95.70% Test Accuracy: 78.12%
[ 46/ 50] MNIST Time 0.02070s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 46/ 50] FashionMNIST Time 0.02237s Training Accuracy: 95.12% Test Accuracy: 81.25%
[ 47/ 50] MNIST Time 0.02072s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 47/ 50] FashionMNIST Time 0.02157s Training Accuracy: 95.41% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.02487s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 48/ 50] FashionMNIST Time 0.02112s Training Accuracy: 95.61% Test Accuracy: 78.12%
[ 49/ 50] MNIST Time 0.02190s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 49/ 50] FashionMNIST Time 0.02341s Training Accuracy: 96.39% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.02079s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 50/ 50] FashionMNIST Time 0.02124s Training Accuracy: 96.68% Test Accuracy: 78.12%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 78.12%
[FINAL] FashionMNIST Training Accuracy: 96.68% Test Accuracy: 78.12%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.576 GiB / 4.750 GiB available)
This page was generated using Literate.jl.