Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxAMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays,
      Optimisers, Printf, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)

Loading Datasets

julia
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
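    # Take the first `n_train`/`n_eval` samples, reshape the images to WHCN
    # layout, one-hot encode the labels, and wrap everything in DataLoaders.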
    imgs, labels = dset(:train)[1:n_train]
    x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)

    imgs, labels = dset(:test)[1:n_eval]
    x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)

    return (DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false))
end

function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 4 methods)
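As a quick sanity check, we can materialize the loaders and inspect the first MNIST batch. This snippet is illustrative (the variable names are ours, and the shape comments assume the default batchsize of 256):

julia
(mnist_loaders, fmnist_loaders) = load_datasets()
x, y = first(mnist_loaders[1])  # first training batch of MNIST
size(x)  # (28, 28, 1, 256): width × height × channels × batch
size(y)  # (10, 256): one-hot encoded labels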

Implement a HyperNet Layer

julia
function HyperNet(weight_generator::Lux.AbstractExplicitLayer,
        core_network::Lux.AbstractExplicitLayer)
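    # Record the axes of the core network's parameters once; they let us
    # rebuild structured parameters from the flat vector the generator emits.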
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
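The key trick here is the ComponentArray round-trip: getaxes records the structure of the core network's parameters, and ComponentArray(flat, axes) rebuilds structured parameters from the flat vector produced by the weight generator. Below is a minimal sketch of that round-trip on a standalone Dense layer (the names are illustrative):

julia
layer = Dense(2 => 3)
ps = Lux.initialparameters(Random.default_rng(), layer) |> ComponentArray
ax = getaxes(ps)
flat = getdata(ps)                 # the underlying flat Float32 vector (length 9)
ps_rebuilt = ComponentArray(flat, ax)
ps_rebuilt.weight == ps.weight     # true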

Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured internally; as such, we don't recommend it unless you are familiar with the internals. In this case, we override parameter initialization so that it ignores the core_network parameters, which are instead generated on the fly by the weight generator.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
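With this override in place, Lux.setup allocates parameters only for the weight generator. A quick check on a toy HyperNet (the layer sizes here are arbitrary):

julia
core = Dense(4 => 2)
gen = Dense(2 => Lux.parameterlength(core))
hn = HyperNet(gen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)  # (:weight_generator,)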

Create and Initialize the HyperNet

julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(Embedding(2 => 32), Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network)))

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
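Before training, a single forward pass on dummy data confirms the wiring. This is only a sketch, with a made-up batch and dataset index:

julia
model = create_model()
ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Float32, 28, 28, 1, 4)   # dummy batch of 4 images
data_idx = 1                        # 1 => MNIST, 2 => FashionMNIST
y_pred, st = model((data_idx, x), ps, st)
size(y_pred)  # (10, 4)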

Define Utility Functions

julia
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

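# The loss receives the dataset index alongside the batch; the model uses it
# to generate dataset-specific weights for the forward pass.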
function loss(model, ps, st, (data_idx, x, y))
    y_pred, st = model((data_idx, x), ps, st)
    return logitcrossentropy(y_pred, y), st, (;)
end

function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    cpu_dev = cpu_device()
    for (x, y) in dataloader
        x = x |> gdev
        y = y |> gdev
        target_class = onecold(cpu_dev(y))
        predicted_class = onecold(cpu_dev(model((data_idx, x), ps, st)[1]))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 2 methods)
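To make the loss concrete, here is a tiny worked example of logitcrossentropy (the numbers are arbitrary):

julia
y_pred = Float32[2.0 0.5; 0.1 1.5; -1.0 0.3]  # logits: 3 classes × 2 samples
y = onehotbatch([1, 2], 1:3)                  # true classes of the 2 samples
logitcrossentropy(y_pred, y)                  # mean cross-entropy over the batch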

Training

julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()

    rng = Xoshiro(0)

    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(3.0f-4); transform_variables=dev)

    ### Let's train the model
    nepochs = 10
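    # Alternate between the two datasets every epoch: data_idx both selects
    # the dataloaders and conditions the weight generator.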
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]

        stime = time()
        for (x, y) in train_dataloader
            x = x |> dev
            y = y |> dev
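            # compute_gradients returns the gradients, the loss, auxiliary
            # stats, and an updated TrainState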
            (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                AutoZygote(), loss, (data_idx, x, y), train_state)
            train_state = Lux.Experimental.apply_gradients!(train_state, gs)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                train_dataloader, data_idx, dev) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                test_dataloader, data_idx, dev) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d] \t %12s \t Time %.5fs \t Training Accuracy: %.2f%% \t Test Accuracy: %.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        train_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                train_dataloader, data_idx, dev) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                test_dataloader, data_idx, dev) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: %.2f%%\n" data_name train_acc test_acc
    end
end

train()
[  1/ 10] 	        MNIST 	 Time 70.14708s 	 Training Accuracy: 76.27% 	 Test Accuracy: 78.12%
[  1/ 10] 	 FashionMNIST 	 Time 0.02141s 	 Training Accuracy: 46.78% 	 Test Accuracy: 50.00%
[  2/ 10] 	        MNIST 	 Time 0.02203s 	 Training Accuracy: 74.90% 	 Test Accuracy: 65.62%
[  2/ 10] 	 FashionMNIST 	 Time 0.02252s 	 Training Accuracy: 55.18% 	 Test Accuracy: 59.38%
[  3/ 10] 	        MNIST 	 Time 0.02254s 	 Training Accuracy: 84.86% 	 Test Accuracy: 84.38%
[  3/ 10] 	 FashionMNIST 	 Time 0.02240s 	 Training Accuracy: 62.50% 	 Test Accuracy: 53.12%
[  4/ 10] 	        MNIST 	 Time 0.02226s 	 Training Accuracy: 82.52% 	 Test Accuracy: 81.25%
[  4/ 10] 	 FashionMNIST 	 Time 0.02258s 	 Training Accuracy: 64.65% 	 Test Accuracy: 59.38%
[  5/ 10] 	        MNIST 	 Time 0.02218s 	 Training Accuracy: 84.57% 	 Test Accuracy: 90.62%
[  5/ 10] 	 FashionMNIST 	 Time 0.03661s 	 Training Accuracy: 68.55% 	 Test Accuracy: 53.12%
[  6/ 10] 	        MNIST 	 Time 0.02059s 	 Training Accuracy: 91.89% 	 Test Accuracy: 90.62%
[  6/ 10] 	 FashionMNIST 	 Time 0.02070s 	 Training Accuracy: 68.07% 	 Test Accuracy: 68.75%
[  7/ 10] 	        MNIST 	 Time 0.02107s 	 Training Accuracy: 92.87% 	 Test Accuracy: 96.88%
[  7/ 10] 	 FashionMNIST 	 Time 0.02074s 	 Training Accuracy: 72.07% 	 Test Accuracy: 71.88%
[  8/ 10] 	        MNIST 	 Time 0.02144s 	 Training Accuracy: 94.82% 	 Test Accuracy: 96.88%
[  8/ 10] 	 FashionMNIST 	 Time 0.02087s 	 Training Accuracy: 72.07% 	 Test Accuracy: 62.50%
[  9/ 10] 	        MNIST 	 Time 0.02105s 	 Training Accuracy: 94.43% 	 Test Accuracy: 93.75%
[  9/ 10] 	 FashionMNIST 	 Time 0.02253s 	 Training Accuracy: 70.02% 	 Test Accuracy: 59.38%
[ 10/ 10] 	        MNIST 	 Time 0.02726s 	 Training Accuracy: 95.61% 	 Test Accuracy: 96.88%
[ 10/ 10] 	 FashionMNIST 	 Time 0.02059s 	 Training Accuracy: 77.64% 	 Test Accuracy: 75.00%

[FINAL] 	        MNIST 	 Training Accuracy: 91.11% 	 Test Accuracy: 87.50%
[FINAL] 	 FashionMNIST 	 Training Accuracy: 77.64% 	 Test Accuracy: 75.00%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.5, artifact installation
CUDA driver 12.5
NVIDIA driver 550.54.15, originally for CUDA 12.4

CUDA libraries: 
- CUBLAS: 12.5.2
- CURAND: 10.3.6
- CUFFT: 11.2.3
- CUSOLVER: 11.6.2
- CUSPARSE: 12.4.1
- CUPTI: 23.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.4.2
- CUDA_Driver_jll: 0.9.0+0
- CUDA_Runtime_jll: 0.14.0+1

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.170 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/DaGeB/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.