Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

# Make scalar (element-by-element) indexing of CUDA arrays an error instead of a silent,
# slow fallback, so accidental per-element access to GPU data is caught early.
CUDA.allowscalar(false)

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build `(train_loader, test_loader)` for dataset type `dset`, using the first `n_train`
samples of the `:train` split and the first `n_eval` samples of the `:test` split.

Images are reshaped to `28 × 28 × 1 × n` and labels are one-hot encoded over `0:9`.
The batch size is capped at the split size; only the training loader is shuffled.
"""
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    # Take the first `n` samples of `split` as (images, one-hot labels).
    function take_split(split, n)
        features, targets = dset(split)[1:n]
        return reshape(features, 28, 28, 1, n), onehotbatch(targets, 0:9)
    end

    train_data = take_split(:train, n_train)
    test_data = take_split(:test, n_eval)

    train_loader = DataLoader(train_data; batchsize=min(batchsize, n_train), shuffle=true)
    test_loader = DataLoader(test_data; batchsize=min(batchsize, n_eval), shuffle=false)
    return train_loader, test_loader
end

"""
    load_datasets(n_train=1024, n_eval=32, batchsize=256)

Return a 2-tuple `(mnist_loaders, fashion_mnist_loaders)`. Each element is the
`(train_loader, test_loader)` pair produced by `load_dataset` with the given sizes.
"""
function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    dataset_types = (MNIST, FashionMNIST)
    return map(d -> load_dataset(d, n_train, n_eval, batchsize), dataset_types)
end
load_datasets (generic function with 4 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a `@compact` layer (dispatch tag `:HyperNet`) called with a tuple `(x, y)`.
`x` is fed to `weight_generator`, and the flattened output is used as the parameter
vector of `core_network`, which is then applied to `y`.
"""
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Only the axes (layout) of the core network's parameters are kept; the sampled
    # values are discarded.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate a flat weight vector and view it with the core network's layout.
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we define it only to skip initializing the core_network parameters, since those weights are produced at call time by the weight_generator rather than trained directly.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    # Only the weight generator gets trainable parameters. The core network's weights
    # are produced by the generator when the layer is called, so none are created here.
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial. An embedding of the dataset index (1 or 2)
generates every parameter of a small MLP classifier for 28×28 images with 10 classes.
"""
function create_model()
    # The core network does not have to be an MLP; any Lux layer works.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))

    # The generator's output size must match the core network's total parameter count.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params))

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
# Cross-entropy loss computed on raw model outputs (`logits=Val(true)`). Used as the
# objective passed to `Training.single_train_step!` during training.
const loss = CrossEntropyLoss(; logits=Val(true))

"""
    accuracy(model, ps, st, dataloader, data_idx)

Fraction of samples in `dataloader` whose predicted class (`onecold` of the model
output) matches the one-hot target. The model is evaluated in test mode, with
`data_idx` passed alongside each batch as the model's first input.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (xb, yb) in dataloader
        logits, _ = model((data_idx, xb), ps, eval_st)
        truth = onecold(yb)
        n_correct += sum(truth .== onecold(logits))
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
# Display name for a dataset index: 1 → "MNIST", anything else → "FashionMNIST".
_dataset_name(data_idx::Integer) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Train and test accuracy of `train_state` on dataset `data_idx`, each as a
# percentage rounded to two decimals. Returns `(train_acc, test_acc)`.
function _accuracy_percentages(
        model, train_state, train_dataloader, test_dataloader, data_idx)
    percent(dataloader) = round(
        accuracy(model, train_state.parameters, train_state.states, dataloader,
            data_idx) * 100;
        digits=2)
    return percent(train_dataloader), percent(test_dataloader)
end

"""
    train(; nepochs=25)

Train a HyperNet jointly on MNIST and FashionMNIST. Each epoch alternates between the
two datasets and prints the training time plus train/test accuracy. A final accuracy
report is printed after training.

Returns a 2-element vector with the final test accuracy (in percent) on MNIST and
FashionMNIST, in that order.
"""
function train(; nepochs::Integer=25)
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))

    # Wrap each (train, test) loader pair for the device once. The wrapped loaders are
    # re-iterated every epoch, so this is loop-invariant work.
    device_dataloaders = map(pair -> pair .|> dev, dataloaders)

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = device_dataloaders[data_idx]

        stime = time()
        for (x, y) in train_dataloader
            # The dataset index goes to the weight generator; the images go to the core.
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc, test_acc = _accuracy_percentages(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d] \t %12s \t Time %.5fs \t Training Accuracy: %.2f%% \t Test \
                 Accuracy: %.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = device_dataloaders[data_idx]
        train_acc, test_acc = _accuracy_percentages(
            model, train_state, train_dataloader, test_dataloader, data_idx)
        data_name = _dataset_name(data_idx)

        @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: \
                 %.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the full training procedure. The result holds the final MNIST and FashionMNIST
# test accuracies, in percent.
test_acc_list = train()
[  1/ 25] 	        MNIST 	 Time 69.51804s 	 Training Accuracy: 23.24% 	 Test Accuracy: 25.00%
[  1/ 25] 	 FashionMNIST 	 Time 0.01361s 	 Training Accuracy: 28.22% 	 Test Accuracy: 25.00%
[  2/ 25] 	        MNIST 	 Time 0.01258s 	 Training Accuracy: 47.75% 	 Test Accuracy: 31.25%
[  2/ 25] 	 FashionMNIST 	 Time 0.01260s 	 Training Accuracy: 51.17% 	 Test Accuracy: 37.50%
[  3/ 25] 	        MNIST 	 Time 0.01241s 	 Training Accuracy: 59.57% 	 Test Accuracy: 59.38%
[  3/ 25] 	 FashionMNIST 	 Time 0.04608s 	 Training Accuracy: 62.40% 	 Test Accuracy: 62.50%
[  4/ 25] 	        MNIST 	 Time 0.01120s 	 Training Accuracy: 64.84% 	 Test Accuracy: 43.75%
[  4/ 25] 	 FashionMNIST 	 Time 0.01117s 	 Training Accuracy: 65.14% 	 Test Accuracy: 50.00%
[  5/ 25] 	        MNIST 	 Time 0.01115s 	 Training Accuracy: 76.56% 	 Test Accuracy: 56.25%
[  5/ 25] 	 FashionMNIST 	 Time 0.01113s 	 Training Accuracy: 73.63% 	 Test Accuracy: 65.62%
[  6/ 25] 	        MNIST 	 Time 0.01244s 	 Training Accuracy: 78.32% 	 Test Accuracy: 62.50%
[  6/ 25] 	 FashionMNIST 	 Time 0.01249s 	 Training Accuracy: 74.02% 	 Test Accuracy: 68.75%
[  7/ 25] 	        MNIST 	 Time 0.01266s 	 Training Accuracy: 83.89% 	 Test Accuracy: 65.62%
[  7/ 25] 	 FashionMNIST 	 Time 0.01240s 	 Training Accuracy: 78.03% 	 Test Accuracy: 62.50%
[  8/ 25] 	        MNIST 	 Time 0.01271s 	 Training Accuracy: 86.33% 	 Test Accuracy: 65.62%
[  8/ 25] 	 FashionMNIST 	 Time 0.01274s 	 Training Accuracy: 80.66% 	 Test Accuracy: 71.88%
[  9/ 25] 	        MNIST 	 Time 0.01486s 	 Training Accuracy: 88.18% 	 Test Accuracy: 65.62%
[  9/ 25] 	 FashionMNIST 	 Time 0.01454s 	 Training Accuracy: 81.54% 	 Test Accuracy: 68.75%
[ 10/ 25] 	        MNIST 	 Time 0.01112s 	 Training Accuracy: 91.99% 	 Test Accuracy: 68.75%
[ 10/ 25] 	 FashionMNIST 	 Time 0.01099s 	 Training Accuracy: 83.50% 	 Test Accuracy: 68.75%
[ 11/ 25] 	        MNIST 	 Time 0.01093s 	 Training Accuracy: 91.80% 	 Test Accuracy: 68.75%
[ 11/ 25] 	 FashionMNIST 	 Time 0.01093s 	 Training Accuracy: 84.47% 	 Test Accuracy: 71.88%
[ 12/ 25] 	        MNIST 	 Time 0.01097s 	 Training Accuracy: 95.41% 	 Test Accuracy: 62.50%
[ 12/ 25] 	 FashionMNIST 	 Time 0.01096s 	 Training Accuracy: 87.30% 	 Test Accuracy: 65.62%
[ 13/ 25] 	        MNIST 	 Time 0.01091s 	 Training Accuracy: 95.90% 	 Test Accuracy: 62.50%
[ 13/ 25] 	 FashionMNIST 	 Time 0.01094s 	 Training Accuracy: 87.60% 	 Test Accuracy: 78.12%
[ 14/ 25] 	        MNIST 	 Time 0.01211s 	 Training Accuracy: 96.29% 	 Test Accuracy: 62.50%
[ 14/ 25] 	 FashionMNIST 	 Time 0.01113s 	 Training Accuracy: 88.57% 	 Test Accuracy: 75.00%
[ 15/ 25] 	        MNIST 	 Time 0.01075s 	 Training Accuracy: 97.36% 	 Test Accuracy: 65.62%
[ 15/ 25] 	 FashionMNIST 	 Time 0.01072s 	 Training Accuracy: 88.09% 	 Test Accuracy: 62.50%
[ 16/ 25] 	        MNIST 	 Time 0.01072s 	 Training Accuracy: 97.85% 	 Test Accuracy: 62.50%
[ 16/ 25] 	 FashionMNIST 	 Time 0.01187s 	 Training Accuracy: 90.04% 	 Test Accuracy: 68.75%
[ 17/ 25] 	        MNIST 	 Time 0.01168s 	 Training Accuracy: 97.36% 	 Test Accuracy: 65.62%
[ 17/ 25] 	 FashionMNIST 	 Time 0.01150s 	 Training Accuracy: 88.09% 	 Test Accuracy: 68.75%
[ 18/ 25] 	        MNIST 	 Time 0.01068s 	 Training Accuracy: 97.56% 	 Test Accuracy: 62.50%
[ 18/ 25] 	 FashionMNIST 	 Time 0.01184s 	 Training Accuracy: 87.70% 	 Test Accuracy: 59.38%
[ 19/ 25] 	        MNIST 	 Time 0.01088s 	 Training Accuracy: 98.73% 	 Test Accuracy: 65.62%
[ 19/ 25] 	 FashionMNIST 	 Time 0.02154s 	 Training Accuracy: 88.67% 	 Test Accuracy: 65.62%
[ 20/ 25] 	        MNIST 	 Time 0.01196s 	 Training Accuracy: 98.93% 	 Test Accuracy: 65.62%
[ 20/ 25] 	 FashionMNIST 	 Time 0.01073s 	 Training Accuracy: 91.99% 	 Test Accuracy: 71.88%
[ 21/ 25] 	        MNIST 	 Time 0.01076s 	 Training Accuracy: 99.02% 	 Test Accuracy: 62.50%
[ 21/ 25] 	 FashionMNIST 	 Time 0.01153s 	 Training Accuracy: 90.43% 	 Test Accuracy: 68.75%
[ 22/ 25] 	        MNIST 	 Time 0.01068s 	 Training Accuracy: 98.93% 	 Test Accuracy: 65.62%
[ 22/ 25] 	 FashionMNIST 	 Time 0.01180s 	 Training Accuracy: 89.16% 	 Test Accuracy: 59.38%
[ 23/ 25] 	        MNIST 	 Time 0.01151s 	 Training Accuracy: 99.02% 	 Test Accuracy: 65.62%
[ 23/ 25] 	 FashionMNIST 	 Time 0.01228s 	 Training Accuracy: 91.70% 	 Test Accuracy: 68.75%
[ 24/ 25] 	        MNIST 	 Time 0.01132s 	 Training Accuracy: 98.44% 	 Test Accuracy: 62.50%
[ 24/ 25] 	 FashionMNIST 	 Time 0.01064s 	 Training Accuracy: 85.45% 	 Test Accuracy: 68.75%
[ 25/ 25] 	        MNIST 	 Time 0.01101s 	 Training Accuracy: 97.75% 	 Test Accuracy: 62.50%
[ 25/ 25] 	 FashionMNIST 	 Time 0.01145s 	 Training Accuracy: 81.45% 	 Test Accuracy: 62.50%

[FINAL] 	        MNIST 	 Training Accuracy: 98.24% 	 Test Accuracy: 62.50%
[FINAL] 	 FashionMNIST 	 Training Accuracy: 81.45% 	 Test Accuracy: 62.50%

Appendix

julia
using InteractiveUtils
# Print Julia and system version information for reproducibility.
InteractiveUtils.versioninfo()

# For each GPU backend loaded in this session that MLDataDevices reports as
# functional, also print that backend's version details.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 555.42.6

CUDA libraries: 
- CUBLAS: 12.5.3
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.3
- CUSPARSE: 12.5.1
- CUPTI: 2024.2.1 (API 23.0.0)
- NVML: 12.0.0+555.42.6

Julia packages: 
- CUDA: 5.4.3
- CUDA_Driver_jll: 0.9.2+0
- CUDA_Runtime_jll: 0.14.1+0

Toolchain:
- Julia: 1.10.5
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

Preferences:
- CUDA_Driver_jll.compat: false

2 devices:
  0: Quadro RTX 5000 (sm_75, 12.451 GiB / 16.000 GiB available)
  1: Quadro RTX 5000 (sm_75, 15.556 GiB / 16.000 GiB available)

This page was generated using Literate.jl.