Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote
CUDA.allowscalar(false)

Loading Datasets
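The load_dataset function below turns the integer labels into a 10×N one-hot matrix with onehotbatch(labels, 0:9), and the accuracy utility later recovers class indices with onecold. To make the encoding concrete, here is a dependency-free sketch built on two hypothetical helpers, onehot_sketch and onecold_sketch. It reproduces the layout (one column per sample, one row per class) but returns a dense BitMatrix rather than OneHotArrays' compact one-hot type.

```julia
# Hypothetical stand-in for OneHotArrays.onehotbatch: one row per class,
# one column per label, with a single `true` in each column.
function onehot_sketch(labels::AbstractVector, classes)
    M = falses(length(classes), length(labels))
    for (j, l) in enumerate(labels)
        i = findfirst(==(l), classes)
        i === nothing && error("label $l is not in classes")
        M[i, j] = true
    end
    return M
end

# Hypothetical stand-in for onecold: for each column, take the index of the
# largest entry and map it back to the corresponding class label.
onecold_sketch(M::AbstractMatrix, classes) = [classes[argmax(col)] for col in eachcol(M)]

labels = [3, 0, 9]
Y = onehot_sketch(labels, 0:9)               # 10×3 BitMatrix
decoded = onecold_sketch(Y, 0:9)             # recovers [3, 0, 9]

# onecold also works on raw scores such as logits: the largest score wins.
scores = [0.1 2.0; 0.9 -1.0]
decoded_scores = onecold_sketch(scores, [:a, :b])   # [:b, :a]
```

The accuracy function calls onecold without a label argument, so both target and predicted classes are returned as row indices 1 to 10 instead of digits 0 to 9. Because both sides follow the same convention, the comparison is unaffected.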
julia
function load_dataset(::Type{dset}, n_train::Int, n_eval::Int, batchsize::Int) where {dset}
    imgs, labels = dset(:train)[1:n_train]
    x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9)

    imgs, labels = dset(:test)[1:n_eval]
    x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9)

    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false)
    )
end

function load_datasets(n_train=1024, n_eval=32, batchsize=256)
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 4 methods)

Implement a HyperNet Layer
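The HyperNet layer below stores the parameter layout of core_network as ComponentArray axes (ca_axes). On every call it reinterprets the flat output of weight_generator as the core network's parameters. The dependency-free sketch that follows shows the same reinterpretation for a stack of dense layers. The helpers n_params, unflatten_dense, and core_forward are hypothetical. The weight-then-bias, column-major ordering is an assumption made for illustration, not a statement about ComponentArrays' layout. A random lookup table stands in for the Embedding + MLP weight generator.

```julia
# Hypothetical helper: parameter count of Dense layers given as in => out pairs
# (an out×in weight matrix plus an out-long bias per layer).
n_params(dims) = sum(o * i + o for (i, o) in dims)

# Hypothetical helper: view a flat vector θ as (weight, bias) pairs, layer by
# layer, assuming weights come first and are stored column-major.
function unflatten_dense(θ::AbstractVector, dims)
    length(θ) == n_params(dims) ||
        throw(DimensionMismatch("expected $(n_params(dims)) parameters, got $(length(θ))"))
    layers = []
    offset = 0
    for (i, o) in dims
        W = reshape(view(θ, offset+1:offset+o*i), o, i)   # no copy of θ
        offset += o * i
        b = view(θ, offset+1:offset+o)
        offset += o
        push!(layers, (; weight=W, bias=b))
    end
    return layers
end

# Core network forward pass: ReLU after hidden layers, linear output layer.
function core_forward(layers, x)
    h = x
    for (k, l) in enumerate(layers)
        h = l.weight * h .+ l.bias
        k < length(layers) && (h = max.(h, 0))
    end
    return h
end

# Toy "weight generator": one parameter vector per task index (a lookup table).
dims = [3 => 4, 4 => 2]
table = randn(n_params(dims), 2)
θ = table[:, 2]                                   # parameters for task 2
y = core_forward(unflatten_dense(θ, dims), randn(3, 5))   # 2×5 output
```

In the actual layer, ComponentArray(vec(weight_generator(x)), ca_axes) plays the role of unflatten_dense, and the Lux Chain takes the place of core_forward. Since the core network's parameters are just a function of the generator's output, only the generator is trained. That is also why the initialparameters override later in this section skips the core network.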
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Record the parameter layout of core_network once; only its axes are kept
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights of core_network from the conditioning input x ...
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        # ... and apply core_network to the input batch y with those weights
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we overload Lux.initialparameters only to skip initializing the core_network parameters. The weight generator produces them on every forward pass, so only the weight_generator parameters need to be stored.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end

Create and Initialize the HyperNet
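The weight generator's last Dense layer has to emit every parameter of the core network, so its size is set by Lux.parameterlength(core_network). The arithmetic below is a back-of-the-envelope sketch for the layers used in create_model. It assumes every Dense layer has a bias (the Lux default) and that Embedding(2 => 32) stores one 32-dimensional vector per index. The helper names are hypothetical.

```julia
# Hypothetical helpers; assume Dense layers include a bias (the Lux default)
# and Embedding(n => d) stores one d-dimensional vector per index.
dense_params(in_dim, out_dim) = out_dim * in_dim + out_dim
embedding_params(n_idx, dim) = n_idx * dim

# core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10));
# FlattenLayer has no parameters.
core = dense_params(784, 256) + dense_params(256, 10)

# weight_generator = Chain(Embedding(2 => 32), Dense(32, 64, relu), Dense(64, core))
generator = embedding_params(2, 32) + dense_params(32, 64) + dense_params(64, core)

println("core network parameters (generated, not stored): ", core)
println("weight generator parameters (trained): ", generator)
```

Under these assumptions, the core network has 203,530 parameters, and none of them are stored. The generator holds about 13.2 million, and nearly all of those are in its final Dense(64, ...) layer.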
julia
function create_model()
    # The core network doesn't need to be an MLP; it can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)

Define Utility Functions
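The loss below is CrossEntropyLoss(; logits=Val(true)), which means the model returns raw scores (logits) and the softmax is folded into the loss. The next block is a rough, dependency-free sketch of what such a loss computes. It assumes one-hot targets with one column per sample and averaging over the batch. The hypothetical helpers use the log-sum-exp trick for numerical stability. Lux's actual implementation may differ in details such as epsilon handling and label smoothing.

```julia
# Numerically stable log-softmax along dims=1 (each column is one sample):
# subtracting the column maximum keeps exp from overflowing.
function logsoftmax_cols(z::AbstractMatrix)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Cross-entropy from raw logits with one-hot targets, averaged over the batch.
logit_crossentropy(logits, targets) =
    -sum(targets .* logsoftmax_cols(logits)) / size(logits, 2)

# Two samples, two classes, uninformative logits: loss is log(2) ≈ 0.693.
uniform_loss = logit_crossentropy(zeros(2, 2), [1 0; 0 1])

# A very confident, correct prediction stays finite and close to zero.
z_big = reshape([1000.0, 0.0], 2, 1)
confident_loss = logit_crossentropy(z_big, reshape([1, 0], 2, 1))
```

A naive softmax followed by log would compute exp(1000.0) and overflow to Inf in the second case. Folding the softmax into the loss avoids that.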
julia
# logits=Val(true): the model returns unnormalized scores; softmax is applied inside the loss
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        # onecold returns the index of the largest entry in each column
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)

Training
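The train function below uses Julia's compact nested loop, for epoch in 1:nepochs, data_idx in 1:2, where data_idx is the inner, fastest-varying index. Each epoch therefore makes one pass over MNIST (data_idx = 1) and then one over FashionMNIST (data_idx = 2), and both passes update the single shared weight generator. This small sketch records that visiting order for three epochs instead of 50.

```julia
# Record the (epoch, data_idx) pairs in the order the training loop visits them.
visits = Tuple{Int,Int}[]
for epoch in 1:3, data_idx in 1:2   # same loop shape as train(), fewer epochs
    push!(visits, (epoch, data_idx))
end

# Same naming rule as train(): index 1 is MNIST, index 2 is FashionMNIST.
dataset_order = [d == 1 ? "MNIST" : "FashionMNIST" for (_, d) in visits]
```

Because data_idx is also passed to the Embedding layer, it determines which task's weights the generator produces for that pass.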
julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 50
    # Each epoch trains on MNIST (data_idx = 1) and then on FashionMNIST (data_idx = 2)
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of the shared weight generator on both datasets
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end

    return test_acc_list
end

test_acc_list = train()
[ 1/ 50] MNIST Time 88.45370s Training Accuracy: 58.01% Test Accuracy: 46.88%
[ 1/ 50] FashionMNIST Time 0.02975s Training Accuracy: 50.98% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.06977s Training Accuracy: 67.19% Test Accuracy: 68.75%
[ 2/ 50] FashionMNIST Time 0.02853s Training Accuracy: 59.47% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.02840s Training Accuracy: 78.12% Test Accuracy: 71.88%
[ 3/ 50] FashionMNIST Time 0.02953s Training Accuracy: 69.63% Test Accuracy: 62.50%
[ 4/ 50] MNIST Time 0.03371s Training Accuracy: 78.32% Test Accuracy: 62.50%
[ 4/ 50] FashionMNIST Time 0.02054s Training Accuracy: 65.62% Test Accuracy: 65.62%
[ 5/ 50] MNIST Time 0.02133s Training Accuracy: 77.25% Test Accuracy: 68.75%
[ 5/ 50] FashionMNIST Time 0.02112s Training Accuracy: 72.07% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.02102s Training Accuracy: 86.62% Test Accuracy: 78.12%
[ 6/ 50] FashionMNIST Time 0.02481s Training Accuracy: 74.41% Test Accuracy: 71.88%
[ 7/ 50] MNIST Time 0.02066s Training Accuracy: 88.57% Test Accuracy: 84.38%
[ 7/ 50] FashionMNIST Time 0.02103s Training Accuracy: 73.24% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.02169s Training Accuracy: 90.92% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.02057s Training Accuracy: 76.95% Test Accuracy: 65.62%
[ 9/ 50] MNIST Time 0.02215s Training Accuracy: 93.07% Test Accuracy: 84.38%
[ 9/ 50] FashionMNIST Time 0.02106s Training Accuracy: 80.66% Test Accuracy: 78.12%
[ 10/ 50] MNIST Time 0.02085s Training Accuracy: 95.70% Test Accuracy: 78.12%
[ 10/ 50] FashionMNIST Time 0.02199s Training Accuracy: 81.25% Test Accuracy: 75.00%
[ 11/ 50] MNIST Time 0.02279s Training Accuracy: 97.36% Test Accuracy: 78.12%
[ 11/ 50] FashionMNIST Time 0.02128s Training Accuracy: 81.45% Test Accuracy: 81.25%
[ 12/ 50] MNIST Time 0.02074s Training Accuracy: 98.05% Test Accuracy: 78.12%
[ 12/ 50] FashionMNIST Time 0.02565s Training Accuracy: 81.84% Test Accuracy: 78.12%
[ 13/ 50] MNIST Time 0.02064s Training Accuracy: 98.54% Test Accuracy: 78.12%
[ 13/ 50] FashionMNIST Time 0.02058s Training Accuracy: 81.93% Test Accuracy: 81.25%
[ 14/ 50] MNIST Time 0.02040s Training Accuracy: 99.02% Test Accuracy: 81.25%
[ 14/ 50] FashionMNIST Time 0.02248s Training Accuracy: 85.16% Test Accuracy: 78.12%
[ 15/ 50] MNIST Time 0.02229s Training Accuracy: 99.32% Test Accuracy: 78.12%
[ 15/ 50] FashionMNIST Time 0.02102s Training Accuracy: 85.84% Test Accuracy: 75.00%
[ 16/ 50] MNIST Time 0.02095s Training Accuracy: 99.32% Test Accuracy: 78.12%
[ 16/ 50] FashionMNIST Time 0.02068s Training Accuracy: 86.43% Test Accuracy: 78.12%
[ 17/ 50] MNIST Time 0.02081s Training Accuracy: 99.51% Test Accuracy: 78.12%
[ 17/ 50] FashionMNIST Time 0.02611s Training Accuracy: 86.72% Test Accuracy: 78.12%
[ 18/ 50] MNIST Time 0.02004s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 18/ 50] FashionMNIST Time 0.02113s Training Accuracy: 88.48% Test Accuracy: 78.12%
[ 19/ 50] MNIST Time 0.02119s Training Accuracy: 99.80% Test Accuracy: 78.12%
[ 19/ 50] FashionMNIST Time 0.02131s Training Accuracy: 88.57% Test Accuracy: 84.38%
[ 20/ 50] MNIST Time 0.02227s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 20/ 50] FashionMNIST Time 0.02106s Training Accuracy: 89.06% Test Accuracy: 84.38%
[ 21/ 50] MNIST Time 0.02123s Training Accuracy: 99.90% Test Accuracy: 78.12%
[ 21/ 50] FashionMNIST Time 0.02278s Training Accuracy: 89.45% Test Accuracy: 81.25%
[ 22/ 50] MNIST Time 0.02143s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 22/ 50] FashionMNIST Time 0.02074s Training Accuracy: 89.45% Test Accuracy: 84.38%
[ 23/ 50] MNIST Time 0.02552s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 23/ 50] FashionMNIST Time 0.02061s Training Accuracy: 90.14% Test Accuracy: 84.38%
[ 24/ 50] MNIST Time 0.02316s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 24/ 50] FashionMNIST Time 0.02123s Training Accuracy: 91.70% Test Accuracy: 84.38%
[ 25/ 50] MNIST Time 0.02055s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 25/ 50] FashionMNIST Time 0.02182s Training Accuracy: 91.21% Test Accuracy: 87.50%
[ 26/ 50] MNIST Time 0.02166s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 26/ 50] FashionMNIST Time 0.02073s Training Accuracy: 91.70% Test Accuracy: 90.62%
[ 27/ 50] MNIST Time 0.02021s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 27/ 50] FashionMNIST Time 0.02205s Training Accuracy: 92.19% Test Accuracy: 84.38%
[ 28/ 50] MNIST Time 0.02247s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 28/ 50] FashionMNIST Time 0.02149s Training Accuracy: 92.68% Test Accuracy: 90.62%
[ 29/ 50] MNIST Time 0.02136s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 29/ 50] FashionMNIST Time 0.02383s Training Accuracy: 92.77% Test Accuracy: 87.50%
[ 30/ 50] MNIST Time 0.02186s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 30/ 50] FashionMNIST Time 0.02068s Training Accuracy: 93.85% Test Accuracy: 90.62%
[ 31/ 50] MNIST Time 0.02048s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 31/ 50] FashionMNIST Time 0.02130s Training Accuracy: 93.85% Test Accuracy: 90.62%
[ 32/ 50] MNIST Time 0.02268s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 32/ 50] FashionMNIST Time 0.02070s Training Accuracy: 94.63% Test Accuracy: 90.62%
[ 33/ 50] MNIST Time 0.02138s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 33/ 50] FashionMNIST Time 0.02640s Training Accuracy: 94.24% Test Accuracy: 87.50%
[ 34/ 50] MNIST Time 0.02085s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 34/ 50] FashionMNIST Time 0.02125s Training Accuracy: 95.02% Test Accuracy: 87.50%
[ 35/ 50] MNIST Time 0.02136s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 35/ 50] FashionMNIST Time 0.02131s Training Accuracy: 95.41% Test Accuracy: 90.62%
[ 36/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 36/ 50] FashionMNIST Time 0.02084s Training Accuracy: 95.51% Test Accuracy: 90.62%
[ 37/ 50] MNIST Time 0.02103s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 37/ 50] FashionMNIST Time 0.02101s Training Accuracy: 95.61% Test Accuracy: 90.62%
[ 38/ 50] MNIST Time 0.02128s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 38/ 50] FashionMNIST Time 0.02100s Training Accuracy: 95.80% Test Accuracy: 87.50%
[ 39/ 50] MNIST Time 0.02025s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 39/ 50] FashionMNIST Time 0.02092s Training Accuracy: 96.00% Test Accuracy: 90.62%
[ 40/ 50] MNIST Time 0.02272s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 40/ 50] FashionMNIST Time 0.02348s Training Accuracy: 96.29% Test Accuracy: 87.50%
[ 41/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 41/ 50] FashionMNIST Time 0.02076s Training Accuracy: 96.19% Test Accuracy: 90.62%
[ 42/ 50] MNIST Time 0.02086s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 42/ 50] FashionMNIST Time 0.02919s Training Accuracy: 96.68% Test Accuracy: 87.50%
[ 43/ 50] MNIST Time 0.02177s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 43/ 50] FashionMNIST Time 0.02187s Training Accuracy: 96.58% Test Accuracy: 87.50%
[ 44/ 50] MNIST Time 0.02563s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 44/ 50] FashionMNIST Time 0.02175s Training Accuracy: 96.58% Test Accuracy: 87.50%
[ 45/ 50] MNIST Time 0.02170s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 45/ 50] FashionMNIST Time 0.02085s Training Accuracy: 96.68% Test Accuracy: 87.50%
[ 46/ 50] MNIST Time 0.02218s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 46/ 50] FashionMNIST Time 0.02262s Training Accuracy: 96.78% Test Accuracy: 87.50%
[ 47/ 50] MNIST Time 0.02142s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 47/ 50] FashionMNIST Time 0.02121s Training Accuracy: 96.88% Test Accuracy: 84.38%
[ 48/ 50] MNIST Time 0.02073s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 48/ 50] FashionMNIST Time 0.02216s Training Accuracy: 97.07% Test Accuracy: 87.50%
[ 49/ 50] MNIST Time 0.02142s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 49/ 50] FashionMNIST Time 0.02106s Training Accuracy: 97.36% Test Accuracy: 84.38%
[ 50/ 50] MNIST Time 0.02152s Training Accuracy: 100.00% Test Accuracy: 78.12%
[ 50/ 50] FashionMNIST Time 0.02239s Training Accuracy: 97.46% Test Accuracy: 84.38%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 78.12%
[FINAL] FashionMNIST Training Accuracy: 97.46% Test Accuracy: 84.38%

Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.1
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.170 GiB / 4.750 GiB available)

This page was generated using Literate.jl.