
Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxAMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays,
      Optimisers, Printf, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)
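
A quick aside (not part of the original tutorial): `CUDA.allowscalar(false)` makes accidental scalar indexing of GPU arrays throw an error instead of silently running slowly. The device utilities used throughout come from Lux: `gpu_device()` returns a CUDA or AMDGPU device when one is functional and falls back to the CPU otherwise.

julia
dev = gpu_device()              # CUDA/AMDGPU device if functional, else CPU
x = rand(Float32, 3) |> dev     # move an array to the selected device
x = x |> cpu_device()           # and back to host memory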

Loading Datasets

julia
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    imgs, labels = dset(:train)[1:n_train]
    x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)

    imgs, labels = dset(:test)[1:n_eval]
    x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)

    return (DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false))
end

function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 4 methods)
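
As a quick sanity check (not in the original), the loaders can be destructured and a first batch inspected; the shapes below assume the default arguments above:

julia
(mnist_train, mnist_test), (fmnist_train, fmnist_test) = load_datasets()
x, y = first(mnist_train)
size(x), size(y)    # ((28, 28, 1, 256), (10, 256))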

Implement a HyperNet Layer

julia
struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <:
       Lux.AbstractExplicitContainerLayer{(:weight_generator, :core_network)}
    weight_generator::W
    core_network::C
    ca_axes::A
end

function HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), c) |> ComponentArray |> getaxes
    return HyperNet(w, c, ca_axes)
end

# Only the weight generator carries trainable parameters; the core network's
# parameters are generated on the fly, so they are not initialized here.
function Lux.initialparameters(rng::AbstractRNG, h::HyperNet)
    return (weight_generator=Lux.initialparameters(rng, h.weight_generator),)
end

function (hn::HyperNet)(x, ps, st::NamedTuple)
    # Generate a flat parameter vector from the conditioning input (the dataset
    # index) and reshape it into the core network's parameter structure.
    ps_new, st_ = hn.weight_generator(x, ps.weight_generator, st.weight_generator)
    @set! st.weight_generator = st_
    return ComponentArray(vec(ps_new), hn.ca_axes), st
end

function (hn::HyperNet)((x, y)::T, ps, st::NamedTuple) where {T <: Tuple}
    # Here `x` is the dataset index and `y` is the input batch: generate the
    # core network's parameters from `x`, then run the core network on `y`.
    ps_ca, st = hn(x, ps, st)
    pred, st_ = hn.core_network(y, ps_ca, st.core_network)
    @set! st.core_network = st_
    return pred, st
end
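
A HyperNetwork uses one network (the weight generator) to produce the parameters of another (the core network). The generator emits a flat vector, and the axes stored in ca_axes let ComponentArrays reshape it back into the core network's nested parameter structure. Below is a minimal, illustrative sketch of that round trip on a standalone Dense layer, using the same ComponentArray(data, axes) constructor as above:

julia
layer = Dense(2, 3)
ps = Lux.initialparameters(Random.default_rng(), layer) |> ComponentArray
ax = getaxes(ps)                # axes describing the nested structure
flat = collect(getdata(ps))     # plain Vector, as a weight generator would emit
restored = ComponentArray(flat, ax)
size(restored.weight)           # (3, 2) -- structure recovered from the flat vector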

Create and Initialize the HyperNet

julia
function create_model()
    # The core network doesn't have to be an MLP; it can be any Lux layer.
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(Embedding(2 => 32), Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network)))

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
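
Only the weight generator holds trainable parameters; the core network's are produced at run time. A quick check of this, assuming the helpers above (not in the original):

julia
model = create_model()
ps, st = Lux.setup(Xoshiro(0), model)
Lux.parameterlength(ps)                  # trainable parameters (weight generator only)
Lux.parameterlength(model.core_network)  # 203530 values the generator must emit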

Define Utility Functions

julia
logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))

function loss(model, ps, st, (data_idx, x, y))
    y_pred, st = model((data_idx, x), ps, st)
    return logitcrossentropy(y_pred, y), st, (;)
end

function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device())
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    cpu_dev = cpu_device()
    for (x, y) in dataloader
        x = x |> gdev
        y = y |> gdev
        target_class = onecold(cpu_dev(y))
        predicted_class = onecold(cpu_dev(model((data_idx, x), ps, st)[1]))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 2 methods)
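
A toy check of the loss on hypothetical data (not from the tutorial). logitcrossentropy takes raw logits, so no softmax is applied beforehand:

julia
y_pred = randn(Float32, 10, 4)        # raw logits for a batch of 4
y = onehotbatch([0, 3, 7, 9], 0:9)    # one-hot targets
logitcrossentropy(y_pred, y)          # mean cross-entropy, a Float32 scalar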

Training

julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()

    rng = Xoshiro(0)

    train_state = Lux.Experimental.TrainState(
        rng, model, Adam(3.0f-4); transform_variables=dev)

    ### Let's train the model
    nepochs = 10
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]

        stime = time()
        for (x, y) in train_dataloader
            x = x |> dev
            y = y |> dev
            (gs, _, _, train_state) = Lux.Experimental.compute_gradients(
                AutoZygote(), loss, (data_idx, x, y), train_state)
            train_state = Lux.Experimental.apply_gradients(train_state, gs)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                train_dataloader, data_idx, dev) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                test_dataloader, data_idx, dev) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d] \t %12s \t Time %.5fs \t Training Accuracy: %.2f%% \t Test Accuracy: %.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx]
        train_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                train_dataloader, data_idx, dev) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters, train_state.states,
                test_dataloader, data_idx, dev) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: %.2f%%\n" data_name train_acc test_acc
    end
end

train()
[  1/ 10] 	        MNIST 	 Time 62.70427s 	 Training Accuracy: 76.27% 	 Test Accuracy: 78.12%
[  1/ 10] 	 FashionMNIST 	 Time 0.17746s 	 Training Accuracy: 54.98% 	 Test Accuracy: 50.00%
[  2/ 10] 	        MNIST 	 Time 0.03860s 	 Training Accuracy: 75.49% 	 Test Accuracy: 71.88%
[  2/ 10] 	 FashionMNIST 	 Time 0.05287s 	 Training Accuracy: 56.45% 	 Test Accuracy: 65.62%
[  3/ 10] 	        MNIST 	 Time 0.03839s 	 Training Accuracy: 81.84% 	 Test Accuracy: 78.12%
[  3/ 10] 	 FashionMNIST 	 Time 0.03281s 	 Training Accuracy: 62.21% 	 Test Accuracy: 56.25%
[  4/ 10] 	        MNIST 	 Time 0.03007s 	 Training Accuracy: 82.62% 	 Test Accuracy: 81.25%
[  4/ 10] 	 FashionMNIST 	 Time 0.03060s 	 Training Accuracy: 66.41% 	 Test Accuracy: 56.25%
[  5/ 10] 	        MNIST 	 Time 0.03625s 	 Training Accuracy: 81.54% 	 Test Accuracy: 81.25%
[  5/ 10] 	 FashionMNIST 	 Time 0.02937s 	 Training Accuracy: 65.72% 	 Test Accuracy: 71.88%
[  6/ 10] 	        MNIST 	 Time 0.03049s 	 Training Accuracy: 90.53% 	 Test Accuracy: 90.62%
[  6/ 10] 	 FashionMNIST 	 Time 0.02975s 	 Training Accuracy: 69.14% 	 Test Accuracy: 62.50%
[  7/ 10] 	        MNIST 	 Time 0.04204s 	 Training Accuracy: 92.68% 	 Test Accuracy: 90.62%
[  7/ 10] 	 FashionMNIST 	 Time 0.02972s 	 Training Accuracy: 75.10% 	 Test Accuracy: 68.75%
[  8/ 10] 	        MNIST 	 Time 0.03066s 	 Training Accuracy: 93.85% 	 Test Accuracy: 90.62%
[  8/ 10] 	 FashionMNIST 	 Time 0.03352s 	 Training Accuracy: 74.02% 	 Test Accuracy: 71.88%
[  9/ 10] 	        MNIST 	 Time 0.02887s 	 Training Accuracy: 94.53% 	 Test Accuracy: 93.75%
[  9/ 10] 	 FashionMNIST 	 Time 0.03016s 	 Training Accuracy: 76.76% 	 Test Accuracy: 71.88%
[ 10/ 10] 	        MNIST 	 Time 0.02880s 	 Training Accuracy: 94.73% 	 Test Accuracy: 87.50%
[ 10/ 10] 	 FashionMNIST 	 Time 0.02935s 	 Training Accuracy: 80.08% 	 Test Accuracy: 65.62%

[FINAL] 	        MNIST 	 Training Accuracy: 91.70% 	 Test Accuracy: 78.12%
[FINAL] 	 FashionMNIST 	 Training Accuracy: 80.08% 	 Test Accuracy: 65.62%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(LuxCUDA) && CUDA.functional(); println(); CUDA.versioninfo(); end
if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional(); println(); AMDGPU.versioninfo(); end
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PROJECT = /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/docs
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEBUG = Literate
  JULIA_CPU_THREADS = 2
  JULIA_NUM_THREADS = 48
  JULIA_LOAD_PATH = @:@v#.#:@stdlib
  JULIA_CUDA_HARD_MEMORY_LIMIT = 25%

CUDA runtime 12.3, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0

Toolchain:
- Julia: 1.10.2
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 25%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.443 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.