
Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)
Precompiling MLDatasets...
    429.4 ms  ✓ Glob
    382.7 ms  ✓ WorkerUtilities
    410.5 ms  ✓ BufferedStreams
    413.8 ms  ✓ PaddedViews
    327.5 ms  ✓ SimpleBufferStream
    583.5 ms  ✓ URIs
    295.6 ms  ✓ PackageExtensionCompat
    338.0 ms  ✓ BitFlags
    394.0 ms  ✓ StackViews
    660.4 ms  ✓ GZip
    695.0 ms  ✓ ConcurrentUtilities
    595.9 ms  ✓ ZipFile
    779.3 ms  ✓ StructTypes
   1008.3 ms  ✓ MbedTLS
    555.4 ms  ✓ MPIPreferences
    806.7 ms  ✓ Accessors → AccessorsUnitfulExt
    314.3 ms  ✓ InternedStrings
   2800.8 ms  ✓ UnitfulAtomic
   2365.1 ms  ✓ PeriodicTable
    469.4 ms  ✓ ExceptionUnwrapping
    554.4 ms  ✓ Chemfiles_jll
    440.5 ms  ✓ MicrosoftMPI_jll
    570.9 ms  ✓ libaec_jll
    507.5 ms  ✓ StringEncodings
    764.3 ms  ✓ WeakRefStrings
    398.5 ms  ✓ StridedViews
    436.3 ms  ✓ MosaicViews
   1839.1 ms  ✓ OpenSSL
   1548.2 ms  ✓ NPZ
   1375.4 ms  ✓ MPICH_jll
   1152.4 ms  ✓ MPItrampoline_jll
   1161.8 ms  ✓ OpenMPI_jll
   2230.4 ms  ✓ AtomsBase
  11048.5 ms  ✓ JSON3
   2354.9 ms  ✓ Pickle
  19583.5 ms  ✓ CSV
  34082.6 ms  ✓ JLD2
  18873.7 ms  ✓ ImageCore
   1833.7 ms  ✓ HDF5_jll
   2319.4 ms  ✓ Chemfiles
   2108.2 ms  ✓ ImageBase
   1933.8 ms  ✓ ImageShow
   7392.4 ms  ✓ HDF5
   2362.4 ms  ✓ MAT
  19226.6 ms  ✓ HTTP
   1869.5 ms  ✓ FileIO → HTTPExt
   3016.3 ms  ✓ DataDeps
   8847.3 ms  ✓ MLDatasets
  48 dependencies successfully precompiled in 65 seconds. 150 already precompiled.

Loading Datasets

julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    if n_train === nothing
        imgs, labels = dset(:train)[:]
    else
        imgs, labels = dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    if n_eval === nothing
        imgs, labels = dset(:test)[:]
    else
        imgs, labels = dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    return (
        DataLoader((x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader((x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end

function load_datasets(batchsize=256)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
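
With the full datasets (i.e. outside CI), each returned pair is a shuffled training loader and an unshuffled test loader with 256-element batches. As a quick sanity check, one could peek at the first MNIST training batch (a minimal sketch; the shapes below assume the non-CI path):

julia
mnist_loaders, fmnist_loaders = load_datasets()
x, y = first(mnist_loaders[1])   # first training batch of MNIST
size(x)                          # (28, 28, 1, 256) -- WHCN image batch
size(y)                          # (10, 256)        -- one-hot encoded labels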

Implement a HyperNet Layer

julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend doing it unless you are familiar with the internals. Here we override Lux.initialparameters to skip initializing the core_network parameters: those are generated on the fly by the weight generator, so only the weight generator's parameters need to be stored and trained.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
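
To see the effect of this override, one can set up a small HyperNet and confirm that only the weight generator's parameters are materialized while the forward pass still runs end to end. The layer sizes below are illustrative, not the ones used in the actual model:

julia
core_network = Chain(FlattenLayer(), Dense(784, 16, relu), Dense(16, 10))
weight_generator = Chain(Embedding(2 => 8), Dense(8, Lux.parameterlength(core_network)))
hn = HyperNet(weight_generator, core_network)

ps, st = Lux.setup(Xoshiro(0), hn)
keys(ps)                          # (:weight_generator,) -- no core_network parameters stored

x = randn(Float32, 28, 28, 1, 4)  # dummy batch of four "images"
logits, _ = hn((1, x), ps, st)    # task index 1 selects the first embedding
size(logits)                      # (10, 4)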

Create and Initialize the HyperNet

julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
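
The only contract between the two networks is dimensional: the weight generator's last layer must emit exactly Lux.parameterlength(core_network) values, which are then reshaped into the core network's parameters. A quick check of that number (sketch):

julia
core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
# (784 * 256 + 256) + (256 * 10 + 10) parameters across the two Dense layers
Lux.parameterlength(core_network)   # 203530, the output width of the generator's last Dense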

Define Utility Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
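
onecold is the inverse of onehotbatch: without an explicit label set it returns the 1-based index of the largest entry in each column, which is why the same call is applied to both the one-hot targets and the predicted logits before comparing. For example (sketch):

julia
y = onehotbatch([3, 1, 4], 0:9)
onecold(y)        # [4, 2, 5] -- 1-based positions within 0:9
onecold(y, 0:9)   # [3, 1, 4] -- the original labels recovered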

Training

julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
[  1/ 50]	       MNIST	Time 86.41144s	Training Accuracy: 54.00%	Test Accuracy: 53.12%
[  1/ 50]	FashionMNIST	Time 0.02907s	Training Accuracy: 36.04%	Test Accuracy: 31.25%
[  2/ 50]	       MNIST	Time 0.02835s	Training Accuracy: 64.94%	Test Accuracy: 62.50%
[  2/ 50]	FashionMNIST	Time 0.02838s	Training Accuracy: 51.76%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.05963s	Training Accuracy: 73.24%	Test Accuracy: 53.12%
[  3/ 50]	FashionMNIST	Time 0.02768s	Training Accuracy: 59.08%	Test Accuracy: 50.00%
[  4/ 50]	       MNIST	Time 0.02138s	Training Accuracy: 75.49%	Test Accuracy: 59.38%
[  4/ 50]	FashionMNIST	Time 0.02150s	Training Accuracy: 68.65%	Test Accuracy: 50.00%
[  5/ 50]	       MNIST	Time 0.02270s	Training Accuracy: 81.64%	Test Accuracy: 62.50%
[  5/ 50]	FashionMNIST	Time 0.09385s	Training Accuracy: 67.09%	Test Accuracy: 65.62%
[  6/ 50]	       MNIST	Time 0.02166s	Training Accuracy: 85.64%	Test Accuracy: 68.75%
[  6/ 50]	FashionMNIST	Time 0.02131s	Training Accuracy: 71.78%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.02086s	Training Accuracy: 87.60%	Test Accuracy: 71.88%
[  7/ 50]	FashionMNIST	Time 0.02124s	Training Accuracy: 76.37%	Test Accuracy: 75.00%
[  8/ 50]	       MNIST	Time 0.02026s	Training Accuracy: 89.94%	Test Accuracy: 75.00%
[  8/ 50]	FashionMNIST	Time 0.02188s	Training Accuracy: 80.96%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.02065s	Training Accuracy: 91.70%	Test Accuracy: 71.88%
[  9/ 50]	FashionMNIST	Time 0.02066s	Training Accuracy: 78.52%	Test Accuracy: 71.88%
[ 10/ 50]	       MNIST	Time 0.02248s	Training Accuracy: 94.92%	Test Accuracy: 71.88%
[ 10/ 50]	FashionMNIST	Time 0.02141s	Training Accuracy: 80.76%	Test Accuracy: 75.00%
[ 11/ 50]	       MNIST	Time 0.02352s	Training Accuracy: 95.51%	Test Accuracy: 75.00%
[ 11/ 50]	FashionMNIST	Time 0.03260s	Training Accuracy: 81.84%	Test Accuracy: 68.75%
[ 12/ 50]	       MNIST	Time 0.02039s	Training Accuracy: 95.90%	Test Accuracy: 75.00%
[ 12/ 50]	FashionMNIST	Time 0.02034s	Training Accuracy: 82.23%	Test Accuracy: 78.12%
[ 13/ 50]	       MNIST	Time 0.02214s	Training Accuracy: 97.66%	Test Accuracy: 75.00%
[ 13/ 50]	FashionMNIST	Time 0.02058s	Training Accuracy: 82.03%	Test Accuracy: 78.12%
[ 14/ 50]	       MNIST	Time 0.02051s	Training Accuracy: 98.05%	Test Accuracy: 78.12%
[ 14/ 50]	FashionMNIST	Time 0.02128s	Training Accuracy: 83.50%	Test Accuracy: 84.38%
[ 15/ 50]	       MNIST	Time 0.02040s	Training Accuracy: 98.93%	Test Accuracy: 81.25%
[ 15/ 50]	FashionMNIST	Time 0.02037s	Training Accuracy: 85.06%	Test Accuracy: 81.25%
[ 16/ 50]	       MNIST	Time 0.02180s	Training Accuracy: 99.32%	Test Accuracy: 81.25%
[ 16/ 50]	FashionMNIST	Time 0.02008s	Training Accuracy: 85.55%	Test Accuracy: 75.00%
[ 17/ 50]	       MNIST	Time 0.02034s	Training Accuracy: 99.61%	Test Accuracy: 81.25%
[ 17/ 50]	FashionMNIST	Time 0.02047s	Training Accuracy: 85.94%	Test Accuracy: 78.12%
[ 18/ 50]	       MNIST	Time 0.02575s	Training Accuracy: 99.80%	Test Accuracy: 81.25%
[ 18/ 50]	FashionMNIST	Time 0.02010s	Training Accuracy: 86.62%	Test Accuracy: 78.12%
[ 19/ 50]	       MNIST	Time 0.02034s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 19/ 50]	FashionMNIST	Time 0.02190s	Training Accuracy: 87.60%	Test Accuracy: 78.12%
[ 20/ 50]	       MNIST	Time 0.02036s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 20/ 50]	FashionMNIST	Time 0.02037s	Training Accuracy: 88.57%	Test Accuracy: 75.00%
[ 21/ 50]	       MNIST	Time 0.02174s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 21/ 50]	FashionMNIST	Time 0.02068s	Training Accuracy: 89.84%	Test Accuracy: 75.00%
[ 22/ 50]	       MNIST	Time 0.02085s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 22/ 50]	FashionMNIST	Time 0.02283s	Training Accuracy: 89.75%	Test Accuracy: 78.12%
[ 23/ 50]	       MNIST	Time 0.02082s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 23/ 50]	FashionMNIST	Time 0.02086s	Training Accuracy: 90.53%	Test Accuracy: 75.00%
[ 24/ 50]	       MNIST	Time 0.02338s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 24/ 50]	FashionMNIST	Time 0.02136s	Training Accuracy: 90.92%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.02099s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 25/ 50]	FashionMNIST	Time 0.02270s	Training Accuracy: 90.62%	Test Accuracy: 78.12%
[ 26/ 50]	       MNIST	Time 0.02105s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 26/ 50]	FashionMNIST	Time 0.02280s	Training Accuracy: 90.14%	Test Accuracy: 78.12%
[ 27/ 50]	       MNIST	Time 0.02436s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 27/ 50]	FashionMNIST	Time 0.02136s	Training Accuracy: 91.02%	Test Accuracy: 78.12%
[ 28/ 50]	       MNIST	Time 0.02088s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 28/ 50]	FashionMNIST	Time 0.02328s	Training Accuracy: 90.62%	Test Accuracy: 78.12%
[ 29/ 50]	       MNIST	Time 0.02103s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 29/ 50]	FashionMNIST	Time 0.02093s	Training Accuracy: 91.50%	Test Accuracy: 78.12%
[ 30/ 50]	       MNIST	Time 0.02463s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 30/ 50]	FashionMNIST	Time 0.02127s	Training Accuracy: 91.60%	Test Accuracy: 78.12%
[ 31/ 50]	       MNIST	Time 0.02127s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 31/ 50]	FashionMNIST	Time 0.02264s	Training Accuracy: 92.38%	Test Accuracy: 75.00%
[ 32/ 50]	       MNIST	Time 0.02173s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 32/ 50]	FashionMNIST	Time 0.02072s	Training Accuracy: 92.58%	Test Accuracy: 78.12%
[ 33/ 50]	       MNIST	Time 0.02226s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 33/ 50]	FashionMNIST	Time 0.02081s	Training Accuracy: 93.07%	Test Accuracy: 78.12%
[ 34/ 50]	       MNIST	Time 0.02189s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 34/ 50]	FashionMNIST	Time 0.02241s	Training Accuracy: 93.36%	Test Accuracy: 78.12%
[ 35/ 50]	       MNIST	Time 0.02044s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 35/ 50]	FashionMNIST	Time 0.02066s	Training Accuracy: 93.36%	Test Accuracy: 78.12%
[ 36/ 50]	       MNIST	Time 0.02227s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 36/ 50]	FashionMNIST	Time 0.02411s	Training Accuracy: 93.95%	Test Accuracy: 75.00%
[ 37/ 50]	       MNIST	Time 0.02093s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 37/ 50]	FashionMNIST	Time 0.02723s	Training Accuracy: 94.43%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.02085s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 38/ 50]	FashionMNIST	Time 0.02177s	Training Accuracy: 94.34%	Test Accuracy: 75.00%
[ 39/ 50]	       MNIST	Time 0.02244s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 39/ 50]	FashionMNIST	Time 0.02153s	Training Accuracy: 94.73%	Test Accuracy: 75.00%
[ 40/ 50]	       MNIST	Time 0.02167s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 40/ 50]	FashionMNIST	Time 0.02313s	Training Accuracy: 95.02%	Test Accuracy: 75.00%
[ 41/ 50]	       MNIST	Time 0.02104s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 41/ 50]	FashionMNIST	Time 0.02070s	Training Accuracy: 95.21%	Test Accuracy: 78.12%
[ 42/ 50]	       MNIST	Time 0.02538s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 42/ 50]	FashionMNIST	Time 0.02063s	Training Accuracy: 95.41%	Test Accuracy: 78.12%
[ 43/ 50]	       MNIST	Time 0.02206s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 43/ 50]	FashionMNIST	Time 0.02219s	Training Accuracy: 94.73%	Test Accuracy: 78.12%
[ 44/ 50]	       MNIST	Time 0.02107s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 44/ 50]	FashionMNIST	Time 0.02247s	Training Accuracy: 95.31%	Test Accuracy: 78.12%
[ 45/ 50]	       MNIST	Time 0.02237s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 45/ 50]	FashionMNIST	Time 0.02071s	Training Accuracy: 95.70%	Test Accuracy: 78.12%
[ 46/ 50]	       MNIST	Time 0.02070s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 46/ 50]	FashionMNIST	Time 0.02237s	Training Accuracy: 95.12%	Test Accuracy: 81.25%
[ 47/ 50]	       MNIST	Time 0.02072s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 47/ 50]	FashionMNIST	Time 0.02157s	Training Accuracy: 95.41%	Test Accuracy: 75.00%
[ 48/ 50]	       MNIST	Time 0.02487s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 48/ 50]	FashionMNIST	Time 0.02112s	Training Accuracy: 95.61%	Test Accuracy: 78.12%
[ 49/ 50]	       MNIST	Time 0.02190s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 49/ 50]	FashionMNIST	Time 0.02341s	Training Accuracy: 96.39%	Test Accuracy: 75.00%
[ 50/ 50]	       MNIST	Time 0.02079s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 50/ 50]	FashionMNIST	Time 0.02124s	Training Accuracy: 96.68%	Test Accuracy: 78.12%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[FINAL]	FashionMNIST	Training Accuracy: 96.68%	Test Accuracy: 78.12%
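
For inference after training, the model takes the same tuple input: a task index together with an image batch. A hypothetical sketch, assuming model, ps, and st for a trained HyperNet are available in scope (train() above only returns the final accuracies):

julia
# Hypothetical inference sketch: `model`, `ps`, and `st` are assumed to come from
# a trained HyperNet; they are not returned by `train()` above.
st_inf = Lux.testmode(st)
x_one = randn(Float32, 28, 28, 1, 1) |> gpu_device()   # stand-in for a real MNIST image
logits, _ = model((1, x_one), ps, st_inf)               # task index 1 == MNIST
predicted_digit = onecold(vec(Array(logits)), 0:9)      # map the logits back to a digit in 0:9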

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.576 GiB / 4.750 GiB available)

This page was generated using Literate.jl.