Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Zygote

CUDA.allowscalar(false)

Loading Datasets

julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    if n_train === nothing
        # `[:]` returns all samples as a (features, targets) NamedTuple
        imgs, labels = dset(:train)[:]
    else
        imgs, labels = dset(:train)[1:n_train]
    end
    # `:` lets `reshape` infer the sample count, so this also works when `n_train === nothing`
    x_train, y_train = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    if n_eval === nothing
        imgs, labels = dset(:test)[:]
    else
        imgs, labels = dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    # Cap the batch size at the number of available samples
    return (
        DataLoader((x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader((x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end

function load_datasets(batchsize=256)
    # On CI, use small subsets (1024 training, 32 test samples) to keep the run short
    is_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = is_ci ? 1024 : nothing
    n_eval = is_ci ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
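
To make the array shapes concrete, here is a minimal sketch of the same reshape, one-hot encoding, and batching steps on random stand-in data, so nothing has to be downloaded. The names fake_imgs and fake_labels are made up for illustration; only onehotbatch and DataLoader come from the packages loaded above.

```julia
using MLUtils, OneHotArrays, Random

rng = Xoshiro(0)
# Stand-in for 10 grayscale 28×28 images (hypothetical data, not MNIST)
fake_imgs = rand(rng, Float32, 28, 28, 10)
fake_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Add a singleton channel dimension; `:` infers the number of samples
x = reshape(fake_imgs, 28, 28, 1, :)
# One column per sample, one row per class
y = onehotbatch(fake_labels, 0:9)

loader = DataLoader((x, y); batchsize=4, shuffle=false)
xb, yb = first(loader)

println(size(x))         # (28, 28, 1, 10)
println(size(y))         # (10, 10)
println(size(xb))        # (28, 28, 1, 4)
println(length(loader))  # 3 batches: 4 + 4 + 2 samples
```

The last batch is smaller than batchsize because the loader keeps partial batches by default.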

Implement a HyperNet Layer

julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
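
The key step in HyperNet is turning the flat vector produced by the weight generator back into structured parameters. The sketch below isolates that step with ComponentArrays alone. The parameter layout (layer_1, layer_2) and the values in flat are invented for illustration and are not the tutorial's real network.

```julia
using ComponentArrays

# A small, hypothetical parameter structure used only to obtain the axes
template = (;
    layer_1=(; weight=zeros(Float32, 3, 2), bias=zeros(Float32, 3)),
    layer_2=(; weight=zeros(Float32, 1, 3), bias=zeros(Float32, 1))
)
ca_axes = getaxes(ComponentArray(template))

# Pretend this flat vector came out of a weight generator
flat = Float32.(1:13)

# Reinterpret the flat vector using the stored axes
ps_new = ComponentArray(flat, ca_axes)

println(size(ps_new.layer_1.weight))  # (3, 2)
println(ps_new.layer_1.bias)          # Float32[7.0, 8.0, 9.0]
println(ps_new.layer_2.bias)          # Float32[13.0]
```

Entries are laid out in field order, with each array flattened column by column, which is why the generator's output length has to equal the core network's total parameter count.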

Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend it unless you are familiar with the internals. Here, we override Lux.initialparameters so that the core_network parameters are never initialized: they are produced by the weight_generator on every forward pass, so only the weight_generator needs trainable parameters.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end

Create and Initialize the HyperNet

julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
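
It helps to see how large the last layer of the weight generator becomes. The following back-of-the-envelope count assumes each Dense layer has a bias (the Lux default) and that the flatten layer has no parameters. The helper names dense_params and embedding_params are hypothetical, not part of Lux.

```julia
# Parameter count of a dense layer with bias: weights plus biases
dense_params(n_in, n_out) = n_in * n_out + n_out
# An embedding table stores one vector of length `dim` per entry
embedding_params(n_entries, dim) = n_entries * dim

# Core network: Dense(784, 256) followed by Dense(256, 10)
core = dense_params(784, 256) + dense_params(256, 10)

# Weight generator: Embedding(2 => 32), Dense(32, 64), Dense(64, core)
generator = embedding_params(2, 32) + dense_params(32, 64) + dense_params(64, core)

println(core)       # 203530
println(generator)  # 13231626
```

Under these assumptions the generator holds roughly 65 times as many trainable parameters as the network it produces, almost all of them in its final Dense layer.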

Define Utility Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
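
As a small illustration of the comparison inside accuracy, the sketch below applies onecold to hand-made logits. The labels and logits are invented for the example. Note that onecold without a label set returns class indices (1 to 10), not the original labels (0 to 9); this is harmless because targets and predictions are converted the same way.

```julia
using OneHotArrays

labels = [0, 3, 9]
y = onehotbatch(labels, 0:9)

# Hand-made logits: the largest entry in each column is the predicted class
logits = zeros(Float32, 10, 3)
logits[1, 1] = 2  # predicts index 1 (label 0): correct
logits[4, 2] = 2  # predicts index 4 (label 3): correct
logits[2, 3] = 2  # predicts index 2 (label 1): wrong, target is label 9

target_class = onecold(y)
predicted_class = onecold(logits)

println(target_class)     # [1, 4, 10]
println(predicted_class)  # [1, 4, 2]
println(sum(target_class .== predicted_class) / length(target_class))
```

Passing the label set explicitly, as in onecold(y, 0:9), recovers the original labels instead of indices.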

Training

julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
[  1/ 50]	       MNIST	Time 88.27585s	Training Accuracy: 58.01%	Test Accuracy: 46.88%
[  1/ 50]	FashionMNIST	Time 0.02446s	Training Accuracy: 52.15%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.02460s	Training Accuracy: 66.99%	Test Accuracy: 68.75%
[  2/ 50]	FashionMNIST	Time 0.02514s	Training Accuracy: 64.94%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.02564s	Training Accuracy: 74.51%	Test Accuracy: 62.50%
[  3/ 50]	FashionMNIST	Time 0.02598s	Training Accuracy: 64.75%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.05349s	Training Accuracy: 74.41%	Test Accuracy: 59.38%
[  4/ 50]	FashionMNIST	Time 0.02134s	Training Accuracy: 63.38%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.02132s	Training Accuracy: 77.54%	Test Accuracy: 62.50%
[  5/ 50]	FashionMNIST	Time 0.02118s	Training Accuracy: 67.87%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.02120s	Training Accuracy: 84.47%	Test Accuracy: 59.38%
[  6/ 50]	FashionMNIST	Time 0.03953s	Training Accuracy: 70.61%	Test Accuracy: 65.62%
[  7/ 50]	       MNIST	Time 0.02204s	Training Accuracy: 87.30%	Test Accuracy: 75.00%
[  7/ 50]	FashionMNIST	Time 0.02045s	Training Accuracy: 66.41%	Test Accuracy: 46.88%
[  8/ 50]	       MNIST	Time 0.02065s	Training Accuracy: 91.11%	Test Accuracy: 81.25%
[  8/ 50]	FashionMNIST	Time 0.02127s	Training Accuracy: 75.49%	Test Accuracy: 53.12%
[  9/ 50]	       MNIST	Time 0.02198s	Training Accuracy: 93.36%	Test Accuracy: 84.38%
[  9/ 50]	FashionMNIST	Time 0.02122s	Training Accuracy: 75.88%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.02040s	Training Accuracy: 94.82%	Test Accuracy: 84.38%
[ 10/ 50]	FashionMNIST	Time 0.02252s	Training Accuracy: 76.66%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.02274s	Training Accuracy: 96.78%	Test Accuracy: 84.38%
[ 11/ 50]	FashionMNIST	Time 0.02050s	Training Accuracy: 75.20%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.02255s	Training Accuracy: 97.85%	Test Accuracy: 87.50%
[ 12/ 50]	FashionMNIST	Time 0.02147s	Training Accuracy: 78.32%	Test Accuracy: 59.38%
[ 13/ 50]	       MNIST	Time 0.02286s	Training Accuracy: 98.44%	Test Accuracy: 87.50%
[ 13/ 50]	FashionMNIST	Time 0.02389s	Training Accuracy: 80.37%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.02165s	Training Accuracy: 98.34%	Test Accuracy: 87.50%
[ 14/ 50]	FashionMNIST	Time 0.02102s	Training Accuracy: 79.98%	Test Accuracy: 59.38%
[ 15/ 50]	       MNIST	Time 0.02389s	Training Accuracy: 98.93%	Test Accuracy: 84.38%
[ 15/ 50]	FashionMNIST	Time 0.02076s	Training Accuracy: 80.86%	Test Accuracy: 62.50%
[ 16/ 50]	       MNIST	Time 0.02198s	Training Accuracy: 99.22%	Test Accuracy: 84.38%
[ 16/ 50]	FashionMNIST	Time 0.02283s	Training Accuracy: 82.91%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.02087s	Training Accuracy: 99.41%	Test Accuracy: 84.38%
[ 17/ 50]	FashionMNIST	Time 0.02075s	Training Accuracy: 84.57%	Test Accuracy: 59.38%
[ 18/ 50]	       MNIST	Time 0.02213s	Training Accuracy: 99.71%	Test Accuracy: 84.38%
[ 18/ 50]	FashionMNIST	Time 0.02183s	Training Accuracy: 85.25%	Test Accuracy: 65.62%
[ 19/ 50]	       MNIST	Time 0.02074s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 19/ 50]	FashionMNIST	Time 0.02277s	Training Accuracy: 85.84%	Test Accuracy: 65.62%
[ 20/ 50]	       MNIST	Time 0.02216s	Training Accuracy: 99.80%	Test Accuracy: 84.38%
[ 20/ 50]	FashionMNIST	Time 0.02070s	Training Accuracy: 86.91%	Test Accuracy: 62.50%
[ 21/ 50]	       MNIST	Time 0.02558s	Training Accuracy: 99.90%	Test Accuracy: 84.38%
[ 21/ 50]	FashionMNIST	Time 0.02074s	Training Accuracy: 87.30%	Test Accuracy: 62.50%
[ 22/ 50]	       MNIST	Time 0.02022s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 22/ 50]	FashionMNIST	Time 0.02249s	Training Accuracy: 87.99%	Test Accuracy: 62.50%
[ 23/ 50]	       MNIST	Time 0.02069s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 23/ 50]	FashionMNIST	Time 0.02081s	Training Accuracy: 87.60%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.02344s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 24/ 50]	FashionMNIST	Time 0.02160s	Training Accuracy: 88.67%	Test Accuracy: 62.50%
[ 25/ 50]	       MNIST	Time 0.02121s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 25/ 50]	FashionMNIST	Time 0.02252s	Training Accuracy: 88.96%	Test Accuracy: 65.62%
[ 26/ 50]	       MNIST	Time 0.02087s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 26/ 50]	FashionMNIST	Time 0.02066s	Training Accuracy: 88.96%	Test Accuracy: 62.50%
[ 27/ 50]	       MNIST	Time 0.02329s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 27/ 50]	FashionMNIST	Time 0.02078s	Training Accuracy: 89.84%	Test Accuracy: 65.62%
[ 28/ 50]	       MNIST	Time 0.02276s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 28/ 50]	FashionMNIST	Time 0.02285s	Training Accuracy: 89.75%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.02087s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 29/ 50]	FashionMNIST	Time 0.02064s	Training Accuracy: 89.65%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.02719s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 30/ 50]	FashionMNIST	Time 0.02065s	Training Accuracy: 88.87%	Test Accuracy: 65.62%
[ 31/ 50]	       MNIST	Time 0.02154s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 31/ 50]	FashionMNIST	Time 0.02290s	Training Accuracy: 89.36%	Test Accuracy: 68.75%
[ 32/ 50]	       MNIST	Time 0.02069s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 32/ 50]	FashionMNIST	Time 0.02048s	Training Accuracy: 89.65%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.02328s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 33/ 50]	FashionMNIST	Time 0.02044s	Training Accuracy: 90.23%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.02069s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 34/ 50]	FashionMNIST	Time 0.02398s	Training Accuracy: 90.72%	Test Accuracy: 65.62%
[ 35/ 50]	       MNIST	Time 0.02258s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 35/ 50]	FashionMNIST	Time 0.02103s	Training Accuracy: 91.02%	Test Accuracy: 65.62%
[ 36/ 50]	       MNIST	Time 0.02304s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 36/ 50]	FashionMNIST	Time 0.02057s	Training Accuracy: 90.43%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.02040s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 37/ 50]	FashionMNIST	Time 0.02203s	Training Accuracy: 91.50%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.02213s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 38/ 50]	FashionMNIST	Time 0.02045s	Training Accuracy: 91.41%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.02225s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 39/ 50]	FashionMNIST	Time 0.02042s	Training Accuracy: 91.89%	Test Accuracy: 65.62%
[ 40/ 50]	       MNIST	Time 0.02254s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 40/ 50]	FashionMNIST	Time 0.02756s	Training Accuracy: 92.19%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.02421s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 41/ 50]	FashionMNIST	Time 0.02099s	Training Accuracy: 92.38%	Test Accuracy: 65.62%
[ 42/ 50]	       MNIST	Time 0.02322s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 42/ 50]	FashionMNIST	Time 0.02150s	Training Accuracy: 93.26%	Test Accuracy: 65.62%
[ 43/ 50]	       MNIST	Time 0.02165s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 43/ 50]	FashionMNIST	Time 0.02323s	Training Accuracy: 92.48%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.02272s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 44/ 50]	FashionMNIST	Time 0.02271s	Training Accuracy: 93.07%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.02384s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 45/ 50]	FashionMNIST	Time 0.02160s	Training Accuracy: 92.19%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.02293s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 46/ 50]	FashionMNIST	Time 0.02301s	Training Accuracy: 88.77%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.02104s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 47/ 50]	FashionMNIST	Time 0.02125s	Training Accuracy: 90.53%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.02407s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 48/ 50]	FashionMNIST	Time 0.02186s	Training Accuracy: 90.14%	Test Accuracy: 65.62%
[ 49/ 50]	       MNIST	Time 0.02145s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 49/ 50]	FashionMNIST	Time 0.02530s	Training Accuracy: 90.82%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.02194s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 50/ 50]	FashionMNIST	Time 0.02140s	Training Accuracy: 92.09%	Test Accuracy: 75.00%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[FINAL]	FashionMNIST	Training Accuracy: 92.09%	Test Accuracy: 75.00%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.170 GiB / 4.750 GiB available)

This page was generated using Literate.jl.