
Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)

Loading Datasets

julia
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
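    # Take the first `n_train` training samples and reshape the images to
    # width × height × channels × batch; the labels become one-hot columns.
    imgs, labels = dset(:train)[1:n_train]
    x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)

    # Apply the same preprocessing to the first `n_eval` test samples.
    imgs, labels = dset(:test)[1:n_eval]
    x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)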

    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false)
    )
end

function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 4 methods)
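
`load_datasets` relies on broadcasting over a tuple: the dotted call applies `load_dataset` once per dataset type and returns a tuple holding one `(train, test)` pair per dataset. The sketch below illustrates that behaviour with a hypothetical stand-in, `fake_load`, which returns strings instead of `DataLoader`s. It is not part of the tutorial code.

```julia
# Hypothetical stand-in for `load_dataset`: returns a (train, test) pair of
# descriptive strings instead of real DataLoaders.
function fake_load(name::Symbol, n_train::Int, n_eval::Int)
    return ("$(name)-train-$(n_train)", "$(name)-test-$(n_eval)")
end

# Broadcasting over the tuple, with the integer arguments treated as scalars,
# yields a tuple with one (train, test) pair per entry.
loaders = fake_load.((:mnist, :fashion), 1024, 32)

# Indexing by dataset and destructuring mirrors `dataloaders[data_idx]` in the
# training loop further down.
train_loader, test_loader = loaders[2]
println(train_loader, " / ", test_loader)
```

This is why the training loop can select a dataset with an integer index and unpack the result into a training and a test loader.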

Implement a HyperNet Layer

julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate a flat weight vector and restructure it into the core network's parameters
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here we overload Lux.initialparameters so that only the weight_generator parameters are initialized. The core_network parameters are skipped because the weight generator produces them on every forward pass.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
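
The HyperNet works because a flat vector can be viewed as a structured set of parameters once the axes of a template are known. The following is a minimal sketch of that restructuring step using only ComponentArrays. The two-layer `template` and its field names are hypothetical and chosen only to keep the shapes small.

```julia
using ComponentArrays

# Hypothetical parameter template; only its shapes and field names matter.
template = ComponentArray(;
    layer_1=(weight=zeros(Float32, 2, 3), bias=zeros(Float32, 2)),
    layer_2=(weight=zeros(Float32, 1, 2), bias=zeros(Float32, 1)))

# Keep the axes (the layout), as `HyperNet` does with `ca_axes`.
layout = getaxes(template)

# Stand-in for the weight generator's output: a flat vector of matching length.
flat = collect(Float32, 1:length(template))

# Restructure the flat vector into named, correctly shaped parameters.
ps = ComponentArray(flat, layout)

println(size(ps.layer_1.weight))  # shape taken from the template
println(ps.layer_2.bias)          # last entry of the flat vector
```

The flat vector must have exactly as many entries as the template, which is why the weight generator's last layer is sized with `Lux.parameterlength(core_network)`.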

Create and Initialize the HyperNet

julia
function create_model()
    # The core network doesn't need to be an MLP; any Lux layer works
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
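
To get a feel for the sizes involved, the parameter counts can be worked out by hand. The sketch below assumes that each `Dense` layer carries a bias (the Lux default), that `FlattenLayer` has no parameters, and that `Embedding(2 => 32)` stores one 32-element vector per index. The helper `dense_params` is hypothetical and not part of Lux.

```julia
# Parameters of a dense layer with a bias: one weight per input-output pair
# plus one bias per output.
dense_params(n_in::Int, n_out::Int) = n_in * n_out + n_out

# Core network: Dense(784, 256) followed by Dense(256, 10).
core_params = dense_params(784, 256) + dense_params(256, 10)

# Weight generator: Embedding(2 => 32), Dense(32, 64), Dense(64, core_params).
embedding_params = 2 * 32
generator_params = embedding_params + dense_params(32, 64) +
                   dense_params(64, core_params)

println("core network parameters:     ", core_params)
println("weight generator parameters: ", generator_params)
```

Under these assumptions the generator is far larger than the network it parameterizes, because its final layer needs one output per core network parameter.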

Define Utility Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
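
The accuracy computation compares the most likely class per column of the model output with the class encoded by the one-hot target. As a rough sketch in plain Julia, a column-wise `argmax` can stand in for `onecold` with its default 1-based labels. The logits and targets below are made-up values for illustration.

```julia
# Hypothetical logits for 4 samples and 3 classes (one column per sample).
logits = Float32[2.0 0.1 0.3 0.2
                 0.5 1.5 0.2 0.9
                 0.1 0.2 3.0 0.4]

# One-hot targets in the same layout; the true classes are 1, 2, 3, 1.
targets = Bool[1 0 0 1
               0 1 0 0
               0 0 1 0]

# Column-wise argmax plays the role of `onecold`.
predicted_class = [argmax(col) for col in eachcol(logits)]
target_class = [argmax(col) for col in eachcol(targets)]

# Fraction of samples whose predicted class matches the target class.
acc = sum(predicted_class .== target_class) / length(target_class)
println(acc)
```

Because both sides use the same 1-based indexing, the comparison is valid even though the original labels run from 0 to 9.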

Training

julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
[  1/ 50]	       MNIST	Time 88.45370s	Training Accuracy: 58.01%	Test Accuracy: 46.88%
[  1/ 50]	FashionMNIST	Time 0.02975s	Training Accuracy: 50.98%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.06977s	Training Accuracy: 67.19%	Test Accuracy: 68.75%
[  2/ 50]	FashionMNIST	Time 0.02853s	Training Accuracy: 59.47%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.02840s	Training Accuracy: 78.12%	Test Accuracy: 71.88%
[  3/ 50]	FashionMNIST	Time 0.02953s	Training Accuracy: 69.63%	Test Accuracy: 62.50%
[  4/ 50]	       MNIST	Time 0.03371s	Training Accuracy: 78.32%	Test Accuracy: 62.50%
[  4/ 50]	FashionMNIST	Time 0.02054s	Training Accuracy: 65.62%	Test Accuracy: 65.62%
[  5/ 50]	       MNIST	Time 0.02133s	Training Accuracy: 77.25%	Test Accuracy: 68.75%
[  5/ 50]	FashionMNIST	Time 0.02112s	Training Accuracy: 72.07%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.02102s	Training Accuracy: 86.62%	Test Accuracy: 78.12%
[  6/ 50]	FashionMNIST	Time 0.02481s	Training Accuracy: 74.41%	Test Accuracy: 71.88%
[  7/ 50]	       MNIST	Time 0.02066s	Training Accuracy: 88.57%	Test Accuracy: 84.38%
[  7/ 50]	FashionMNIST	Time 0.02103s	Training Accuracy: 73.24%	Test Accuracy: 65.62%
[  8/ 50]	       MNIST	Time 0.02169s	Training Accuracy: 90.92%	Test Accuracy: 81.25%
[  8/ 50]	FashionMNIST	Time 0.02057s	Training Accuracy: 76.95%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.02215s	Training Accuracy: 93.07%	Test Accuracy: 84.38%
[  9/ 50]	FashionMNIST	Time 0.02106s	Training Accuracy: 80.66%	Test Accuracy: 78.12%
[ 10/ 50]	       MNIST	Time 0.02085s	Training Accuracy: 95.70%	Test Accuracy: 78.12%
[ 10/ 50]	FashionMNIST	Time 0.02199s	Training Accuracy: 81.25%	Test Accuracy: 75.00%
[ 11/ 50]	       MNIST	Time 0.02279s	Training Accuracy: 97.36%	Test Accuracy: 78.12%
[ 11/ 50]	FashionMNIST	Time 0.02128s	Training Accuracy: 81.45%	Test Accuracy: 81.25%
[ 12/ 50]	       MNIST	Time 0.02074s	Training Accuracy: 98.05%	Test Accuracy: 78.12%
[ 12/ 50]	FashionMNIST	Time 0.02565s	Training Accuracy: 81.84%	Test Accuracy: 78.12%
[ 13/ 50]	       MNIST	Time 0.02064s	Training Accuracy: 98.54%	Test Accuracy: 78.12%
[ 13/ 50]	FashionMNIST	Time 0.02058s	Training Accuracy: 81.93%	Test Accuracy: 81.25%
[ 14/ 50]	       MNIST	Time 0.02040s	Training Accuracy: 99.02%	Test Accuracy: 81.25%
[ 14/ 50]	FashionMNIST	Time 0.02248s	Training Accuracy: 85.16%	Test Accuracy: 78.12%
[ 15/ 50]	       MNIST	Time 0.02229s	Training Accuracy: 99.32%	Test Accuracy: 78.12%
[ 15/ 50]	FashionMNIST	Time 0.02102s	Training Accuracy: 85.84%	Test Accuracy: 75.00%
[ 16/ 50]	       MNIST	Time 0.02095s	Training Accuracy: 99.32%	Test Accuracy: 78.12%
[ 16/ 50]	FashionMNIST	Time 0.02068s	Training Accuracy: 86.43%	Test Accuracy: 78.12%
[ 17/ 50]	       MNIST	Time 0.02081s	Training Accuracy: 99.51%	Test Accuracy: 78.12%
[ 17/ 50]	FashionMNIST	Time 0.02611s	Training Accuracy: 86.72%	Test Accuracy: 78.12%
[ 18/ 50]	       MNIST	Time 0.02004s	Training Accuracy: 99.80%	Test Accuracy: 81.25%
[ 18/ 50]	FashionMNIST	Time 0.02113s	Training Accuracy: 88.48%	Test Accuracy: 78.12%
[ 19/ 50]	       MNIST	Time 0.02119s	Training Accuracy: 99.80%	Test Accuracy: 78.12%
[ 19/ 50]	FashionMNIST	Time 0.02131s	Training Accuracy: 88.57%	Test Accuracy: 84.38%
[ 20/ 50]	       MNIST	Time 0.02227s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 20/ 50]	FashionMNIST	Time 0.02106s	Training Accuracy: 89.06%	Test Accuracy: 84.38%
[ 21/ 50]	       MNIST	Time 0.02123s	Training Accuracy: 99.90%	Test Accuracy: 78.12%
[ 21/ 50]	FashionMNIST	Time 0.02278s	Training Accuracy: 89.45%	Test Accuracy: 81.25%
[ 22/ 50]	       MNIST	Time 0.02143s	Training Accuracy: 99.90%	Test Accuracy: 81.25%
[ 22/ 50]	FashionMNIST	Time 0.02074s	Training Accuracy: 89.45%	Test Accuracy: 84.38%
[ 23/ 50]	       MNIST	Time 0.02552s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 23/ 50]	FashionMNIST	Time 0.02061s	Training Accuracy: 90.14%	Test Accuracy: 84.38%
[ 24/ 50]	       MNIST	Time 0.02316s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 24/ 50]	FashionMNIST	Time 0.02123s	Training Accuracy: 91.70%	Test Accuracy: 84.38%
[ 25/ 50]	       MNIST	Time 0.02055s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 25/ 50]	FashionMNIST	Time 0.02182s	Training Accuracy: 91.21%	Test Accuracy: 87.50%
[ 26/ 50]	       MNIST	Time 0.02166s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 26/ 50]	FashionMNIST	Time 0.02073s	Training Accuracy: 91.70%	Test Accuracy: 90.62%
[ 27/ 50]	       MNIST	Time 0.02021s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 27/ 50]	FashionMNIST	Time 0.02205s	Training Accuracy: 92.19%	Test Accuracy: 84.38%
[ 28/ 50]	       MNIST	Time 0.02247s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 28/ 50]	FashionMNIST	Time 0.02149s	Training Accuracy: 92.68%	Test Accuracy: 90.62%
[ 29/ 50]	       MNIST	Time 0.02136s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 29/ 50]	FashionMNIST	Time 0.02383s	Training Accuracy: 92.77%	Test Accuracy: 87.50%
[ 30/ 50]	       MNIST	Time 0.02186s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 30/ 50]	FashionMNIST	Time 0.02068s	Training Accuracy: 93.85%	Test Accuracy: 90.62%
[ 31/ 50]	       MNIST	Time 0.02048s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 31/ 50]	FashionMNIST	Time 0.02130s	Training Accuracy: 93.85%	Test Accuracy: 90.62%
[ 32/ 50]	       MNIST	Time 0.02268s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 32/ 50]	FashionMNIST	Time 0.02070s	Training Accuracy: 94.63%	Test Accuracy: 90.62%
[ 33/ 50]	       MNIST	Time 0.02138s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 33/ 50]	FashionMNIST	Time 0.02640s	Training Accuracy: 94.24%	Test Accuracy: 87.50%
[ 34/ 50]	       MNIST	Time 0.02085s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 34/ 50]	FashionMNIST	Time 0.02125s	Training Accuracy: 95.02%	Test Accuracy: 87.50%
[ 35/ 50]	       MNIST	Time 0.02136s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 35/ 50]	FashionMNIST	Time 0.02131s	Training Accuracy: 95.41%	Test Accuracy: 90.62%
[ 36/ 50]	       MNIST	Time 0.02093s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 36/ 50]	FashionMNIST	Time 0.02084s	Training Accuracy: 95.51%	Test Accuracy: 90.62%
[ 37/ 50]	       MNIST	Time 0.02103s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 37/ 50]	FashionMNIST	Time 0.02101s	Training Accuracy: 95.61%	Test Accuracy: 90.62%
[ 38/ 50]	       MNIST	Time 0.02128s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 38/ 50]	FashionMNIST	Time 0.02100s	Training Accuracy: 95.80%	Test Accuracy: 87.50%
[ 39/ 50]	       MNIST	Time 0.02025s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 39/ 50]	FashionMNIST	Time 0.02092s	Training Accuracy: 96.00%	Test Accuracy: 90.62%
[ 40/ 50]	       MNIST	Time 0.02272s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 40/ 50]	FashionMNIST	Time 0.02348s	Training Accuracy: 96.29%	Test Accuracy: 87.50%
[ 41/ 50]	       MNIST	Time 0.02093s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 41/ 50]	FashionMNIST	Time 0.02076s	Training Accuracy: 96.19%	Test Accuracy: 90.62%
[ 42/ 50]	       MNIST	Time 0.02086s	Training Accuracy: 100.00%	Test Accuracy: 81.25%
[ 42/ 50]	FashionMNIST	Time 0.02919s	Training Accuracy: 96.68%	Test Accuracy: 87.50%
[ 43/ 50]	       MNIST	Time 0.02177s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 43/ 50]	FashionMNIST	Time 0.02187s	Training Accuracy: 96.58%	Test Accuracy: 87.50%
[ 44/ 50]	       MNIST	Time 0.02563s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 44/ 50]	FashionMNIST	Time 0.02175s	Training Accuracy: 96.58%	Test Accuracy: 87.50%
[ 45/ 50]	       MNIST	Time 0.02170s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 45/ 50]	FashionMNIST	Time 0.02085s	Training Accuracy: 96.68%	Test Accuracy: 87.50%
[ 46/ 50]	       MNIST	Time 0.02218s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 46/ 50]	FashionMNIST	Time 0.02262s	Training Accuracy: 96.78%	Test Accuracy: 87.50%
[ 47/ 50]	       MNIST	Time 0.02142s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 47/ 50]	FashionMNIST	Time 0.02121s	Training Accuracy: 96.88%	Test Accuracy: 84.38%
[ 48/ 50]	       MNIST	Time 0.02073s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 48/ 50]	FashionMNIST	Time 0.02216s	Training Accuracy: 97.07%	Test Accuracy: 87.50%
[ 49/ 50]	       MNIST	Time 0.02142s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 49/ 50]	FashionMNIST	Time 0.02106s	Training Accuracy: 97.36%	Test Accuracy: 84.38%
[ 50/ 50]	       MNIST	Time 0.02152s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 50/ 50]	FashionMNIST	Time 0.02239s	Training Accuracy: 97.46%	Test Accuracy: 84.38%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[FINAL]	FashionMNIST	Training Accuracy: 97.46%	Test Accuracy: 84.38%
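
With the default `n_eval=32`, every test accuracy is a multiple of 1/32, so the test figures above move in steps of 3.125 percentage points and a difference between epochs can correspond to a single image. A quick check of that arithmetic follows. The count of 25 correct predictions is inferred from the reported 78.12% and is not read from the log.

```julia
n_eval = 32                   # default evaluation-set size in `load_datasets`
resolution = 100 / n_eval     # smallest possible change in test accuracy, in percent

n_correct = 25                # assumed count, inferred from the reported figure
acc = round(n_correct / n_eval * 100; digits=2)

println("resolution: ", resolution, " percentage points")
println("accuracy for ", n_correct, "/", n_eval, " correct: ", acc, "%")
```

Evaluating on a larger `n_eval` would give a finer-grained and less noisy estimate of test accuracy.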

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.1
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.170 GiB / 4.750 GiB available)

This page was generated using Literate.jl.