Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
CUDA.allowscalar(false)
Loading Datasets
julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    # Use the whole split when no subset size is given
    train_set = dset(:train)
    imgs, labels = train_set[1:(n_train === nothing ? length(train_set) : n_train)]
    # WHCN layout: 28×28 pixels, 1 channel, batch size inferred by `:`
    x_train, y_train = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)
    test_set = dset(:test)
    imgs, labels = test_set[1:(n_eval === nothing ? length(test_set) : n_eval)]
    x_test, y_test = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)
    # Never request a batch larger than the number of available samples
    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end
function load_datasets(batchsize=256)
    # On CI, train and evaluate on small subsets to keep the run short
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
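The loader stores images in width × height × channels × batch (WHCN) order, the layout Lux layers expect, and encodes labels as one-hot columns with one column per sample. The stdlib-only sketch below shows both transformations plus the batch-size clamp; the random arrays are hypothetical stand-ins for MNIST, and the hypothetical onehot_columns helper takes the place of OneHotArrays' onehotbatch.

```julia
using Random

# Hypothetical stand-ins for MNIST: 5 grayscale 28×28 images and their digits
rng = Xoshiro(0)
fake_imgs = rand(rng, Float32, 28, 28, 5)
fake_labels = [3, 0, 9, 3, 7]

# `:` lets reshape infer the batch dimension from the array length
x = reshape(fake_imgs, 28, 28, 1, :)

# Hypothetical helper: column j gets a single `true` in the row of its class
function onehot_columns(labels, classes)
    y = falses(length(classes), length(labels))
    for (j, l) in enumerate(labels)
        y[findfirst(==(l), classes), j] = true
    end
    return y
end
y = onehot_columns(fake_labels, 0:9)

# Same clamp as in load_dataset: never ask for more samples than exist
batchsize = min(256, size(x, 4))
```

Here x has size (28, 28, 1, 5), and the label 0 lands in row 1 because the classes 0:9 are indexed from 1.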
Implement a HyperNet Layer
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate a flat weight vector from the task input x, then view it
        # with the core network's parameter structure
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
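The key step in HyperNet is that weight_generator emits one flat vector, and ComponentArray(vec(...), ca_axes) reinterprets it with the core network's parameter structure before core_network is called. The stdlib-only sketch below mimics that idea with hand-written slicing. The sizes, the linear generator, and the parameter ordering in the hypothetical unflatten function are illustrative assumptions and need not match ComponentArrays' internal layout.

```julia
using Random

relu(x) = max(x, zero(x))

# Hypothetical core MLP sizes: IN -> HID -> OUT
const IN = 4
const HID = 3
const OUT = 2
# Core parameter count (weights + biases), the analogue of Lux.parameterlength
const NCORE = IN * HID + HID + HID * OUT + OUT

# Hypothetical unflatten: slice a flat vector into the core network's parameters
function unflatten(θ::AbstractVector)
    @assert length(θ) == NCORE
    i = 0
    W1 = reshape(θ[(i + 1):(i + HID * IN)], HID, IN)
    i += HID * IN
    b1 = θ[(i + 1):(i + HID)]
    i += HID
    W2 = reshape(θ[(i + 1):(i + OUT * HID)], OUT, HID)
    i += OUT * HID
    b2 = θ[(i + 1):(i + OUT)]
    return (; W1, b1, W2, b2)
end

# Core network applied with externally supplied parameters
core(x, p) = p.W2 * relu.(p.W1 * x .+ p.b1) .+ p.b2

# Weight generator: a per-task embedding column mapped linearly to NCORE values
rng = Xoshiro(0)
embedding = randn(rng, Float32, 5, 2)        # 5-dim embedding for each of 2 tasks
G = 0.1f0 .* randn(rng, Float32, NCORE, 5)   # linear generator head

function hypernet(task::Int, x)
    θ = G * embedding[:, task]   # flat parameter vector for this task
    return core(x, unflatten(θ))
end

x = randn(rng, Float32, IN, 8)   # batch of 8 inputs
y1 = hypernet(1, x)
y2 = hypernet(2, x)
```

Only embedding and G would be trained here; the core network owns no parameters of its own, and each task index selects a different set of core weights.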
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we override Lux.initialparameters so that it skips initializing the core_network parameters: those are produced by the weight_generator on every forward pass, so only the weight_generator needs trainable parameters.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
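Passing dispatch=:HyperNet to @compact tags the resulting layer with a type parameter, which is what allows a method to target CompactLuxLayer{:HyperNet} without affecting other compact layers. The mechanism is ordinary Julia multiple dispatch on a Symbol type parameter. In the stdlib-only sketch below, ToyLayer and init_params are hypothetical stand-ins for the Lux types and functions, not part of Lux.

```julia
using Random

# Hypothetical stand-in for a layer tagged by a Symbol type parameter
struct ToyLayer{Tag, L <: NamedTuple}
    layers::L
end
ToyLayer{Tag}(layers::NamedTuple) where {Tag} = ToyLayer{Tag, typeof(layers)}(layers)

# Generic method: every sublayer gets parameters (here each sublayer is just a size)
init_params(rng::AbstractRNG, l::ToyLayer) = map(n -> randn(rng, Float32, n), l.layers)

# More specific method, chosen only for ToyLayer{:HyperNet}: skip core_network
init_params(rng::AbstractRNG, l::ToyLayer{:HyperNet}) =
    (; weight_generator=randn(rng, Float32, l.layers.weight_generator),)

layers = (; weight_generator=3, core_network=5)
plain = init_params(Xoshiro(0), ToyLayer{:Plain}(layers))
hyper = init_params(Xoshiro(0), ToyLayer{:HyperNet}(layers))
```

Julia picks the ToyLayer{:HyperNet} method because it is more specific than the generic ToyLayer one, so only the tagged layer drops its core_network entry.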
Create and Initialize the HyperNet
julia
function create_model()
    # Doesn't need to be an MLP; the core network can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )
    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
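Because the last Dense layer of weight_generator must emit one value per core-network parameter, its width is set by Lux.parameterlength(core_network), and that layer dominates the trainable parameter count. The arithmetic below checks this, assuming each Dense layer carries a bias (Lux's default) and Embedding(in => out) stores an out × in matrix; dense_params and embedding_params are hypothetical helpers.

```julia
# Hypothetical helpers for counting parameters
dense_params(in, out) = in * out + out   # weight matrix + bias vector
embedding_params(in, out) = in * out     # one out-dim vector per index

# FlattenLayer has no parameters, so only the two Dense layers count
core_params = dense_params(784, 256) + dense_params(256, 10)

generator_params = embedding_params(2, 32) +
                   dense_params(32, 64) +
                   dense_params(64, core_params)
```

This gives 203,530 core-network parameters and 13,231,626 generator parameters, roughly 65 times as many, almost all of them in the final Dense(64, 203530) layer.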
Define Utility Functions
julia
const loss = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
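For a single batch, accuracy compares onecold of the targets with onecold of the logits (the row of the largest entry in each column), and the loss is a softmax cross-entropy computed directly from logits. The stdlib-only sketch below reproduces both on a hypothetical 3-class batch of 3 samples. It assumes that CrossEntropyLoss(; logits=Val(true)) applies a log-softmax over the class dimension and averages the per-sample losses, which we believe is its default behavior.

```julia
using Statistics

# Hypothetical logits: rows are classes, columns are samples
logits = Float32[2.0 0.1 0.3;
                 0.5 1.5 0.2;
                 0.1 0.2 3.0]
targets = [1, 2, 2]   # true class index per column (what onecold(y) returns)

# onecold on the logits picks the row of the largest entry in each column
predicted = [argmax(col) for col in eachcol(logits)]
acc = mean(predicted .== targets)

# Numerically stable log-softmax along the class dimension
function logsoftmax_cols(z)
    m = maximum(z; dims=1)
    return z .- m .- log.(sum(exp.(z .- m); dims=1))
end

# Cross-entropy with logits: mean over samples of -log p(true class)
lsm = logsoftmax_cols(logits)
ce = -mean(lsm[targets[j], j] for j in eachindex(targets))
```

The third sample is misclassified (class 3 predicted, class 2 expected), so acc is 2/3, and that sample contributes the largest term to ce.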
Training
julia
function train()
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end
    println()
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 89.87371s Training Accuracy: 57.23% Test Accuracy: 50.00%
[ 1/ 50] FashionMNIST Time 0.05731s Training Accuracy: 53.12% Test Accuracy: 50.00%
[ 2/ 50] MNIST Time 0.02819s Training Accuracy: 68.36% Test Accuracy: 62.50%
[ 2/ 50] FashionMNIST Time 0.02842s Training Accuracy: 62.89% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.02884s Training Accuracy: 74.22% Test Accuracy: 56.25%
[ 3/ 50] FashionMNIST Time 0.02907s Training Accuracy: 56.93% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.02596s Training Accuracy: 77.34% Test Accuracy: 62.50%
[ 4/ 50] FashionMNIST Time 0.02039s Training Accuracy: 62.30% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.03110s Training Accuracy: 80.08% Test Accuracy: 68.75%
[ 5/ 50] FashionMNIST Time 0.03448s Training Accuracy: 66.31% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.02991s Training Accuracy: 85.16% Test Accuracy: 68.75%
[ 6/ 50] FashionMNIST Time 0.02568s Training Accuracy: 71.19% Test Accuracy: 78.12%
[ 7/ 50] MNIST Time 0.01992s Training Accuracy: 88.67% Test Accuracy: 75.00%
[ 7/ 50] FashionMNIST Time 0.02022s Training Accuracy: 72.07% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.13196s Training Accuracy: 91.02% Test Accuracy: 84.38%
[ 8/ 50] FashionMNIST Time 0.02167s Training Accuracy: 75.59% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.02541s Training Accuracy: 92.97% Test Accuracy: 81.25%
[ 9/ 50] FashionMNIST Time 0.02206s Training Accuracy: 75.68% Test Accuracy: 56.25%
[ 10/ 50] MNIST Time 0.02689s Training Accuracy: 95.70% Test Accuracy: 81.25%
[ 10/ 50] FashionMNIST Time 0.02438s Training Accuracy: 78.61% Test Accuracy: 65.62%
[ 11/ 50] MNIST Time 0.02680s Training Accuracy: 96.29% Test Accuracy: 84.38%
[ 11/ 50] FashionMNIST Time 0.02421s Training Accuracy: 79.88% Test Accuracy: 65.62%
[ 12/ 50] MNIST Time 0.02495s Training Accuracy: 97.27% Test Accuracy: 87.50%
[ 12/ 50] FashionMNIST Time 0.03477s Training Accuracy: 80.66% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.02354s Training Accuracy: 98.24% Test Accuracy: 87.50%
[ 13/ 50] FashionMNIST Time 0.02188s Training Accuracy: 82.42% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.02191s Training Accuracy: 99.02% Test Accuracy: 87.50%
[ 14/ 50] FashionMNIST Time 0.02102s Training Accuracy: 83.98% Test Accuracy: 71.88%
[ 15/ 50] MNIST Time 0.02096s Training Accuracy: 99.22% Test Accuracy: 87.50%
[ 15/ 50] FashionMNIST Time 0.02264s Training Accuracy: 84.28% Test Accuracy: 75.00%
[ 16/ 50] MNIST Time 0.02114s Training Accuracy: 99.41% Test Accuracy: 87.50%
[ 16/ 50] FashionMNIST Time 0.02110s Training Accuracy: 86.04% Test Accuracy: 71.88%
[ 17/ 50] MNIST Time 0.03495s Training Accuracy: 99.51% Test Accuracy: 87.50%
[ 17/ 50] FashionMNIST Time 0.02549s Training Accuracy: 86.82% Test Accuracy: 78.12%
[ 18/ 50] MNIST Time 0.02230s Training Accuracy: 99.80% Test Accuracy: 87.50%
[ 18/ 50] FashionMNIST Time 0.02086s Training Accuracy: 87.79% Test Accuracy: 78.12%
[ 19/ 50] MNIST Time 0.05698s Training Accuracy: 99.90% Test Accuracy: 87.50%
[ 19/ 50] FashionMNIST Time 0.07558s Training Accuracy: 88.96% Test Accuracy: 75.00%
[ 20/ 50] MNIST Time 0.02034s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 20/ 50] FashionMNIST Time 0.02012s Training Accuracy: 89.26% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.02018s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 21/ 50] FashionMNIST Time 0.06521s Training Accuracy: 90.04% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.02011s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 22/ 50] FashionMNIST Time 0.02768s Training Accuracy: 90.23% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.03188s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 23/ 50] FashionMNIST Time 0.02069s Training Accuracy: 90.92% Test Accuracy: 75.00%
[ 24/ 50] MNIST Time 0.02140s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 24/ 50] FashionMNIST Time 0.02107s Training Accuracy: 91.02% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.02118s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 25/ 50] FashionMNIST Time 0.02140s Training Accuracy: 91.21% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.03259s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 26/ 50] FashionMNIST Time 0.02231s Training Accuracy: 91.50% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.02168s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 27/ 50] FashionMNIST Time 0.02081s Training Accuracy: 92.48% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.02134s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 28/ 50] FashionMNIST Time 0.02085s Training Accuracy: 92.58% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.02097s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 29/ 50] FashionMNIST Time 0.02092s Training Accuracy: 93.16% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.02178s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 30/ 50] FashionMNIST Time 0.03325s Training Accuracy: 92.97% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.02205s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 31/ 50] FashionMNIST Time 0.02087s Training Accuracy: 93.46% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.02121s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 32/ 50] FashionMNIST Time 0.02089s Training Accuracy: 93.36% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.02140s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 33/ 50] FashionMNIST Time 0.02103s Training Accuracy: 93.85% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.02138s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 34/ 50] FashionMNIST Time 0.02500s Training Accuracy: 94.34% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.03562s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 35/ 50] FashionMNIST Time 0.02148s Training Accuracy: 94.73% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.02079s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 36/ 50] FashionMNIST Time 0.02092s Training Accuracy: 94.82% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.02153s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 37/ 50] FashionMNIST Time 0.02109s Training Accuracy: 95.12% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.02080s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 38/ 50] FashionMNIST Time 0.02182s Training Accuracy: 95.02% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.02098s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 39/ 50] FashionMNIST Time 0.03268s Training Accuracy: 95.31% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.02135s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 40/ 50] FashionMNIST Time 0.02154s Training Accuracy: 95.31% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.02120s Training Accuracy: 100.00% Test Accuracy: 87.50%
[ 41/ 50] FashionMNIST Time 0.02276s Training Accuracy: 95.61% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.02087s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 42/ 50] FashionMNIST Time 0.02136s Training Accuracy: 95.90% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.02088s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 43/ 50] FashionMNIST Time 0.02088s Training Accuracy: 95.70% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.03568s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 44/ 50] FashionMNIST Time 0.02237s Training Accuracy: 96.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.02152s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 45/ 50] FashionMNIST Time 0.02090s Training Accuracy: 96.39% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.02163s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 46/ 50] FashionMNIST Time 0.02138s Training Accuracy: 96.29% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.02088s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 47/ 50] FashionMNIST Time 0.02137s Training Accuracy: 96.39% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.02171s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 48/ 50] FashionMNIST Time 0.03456s Training Accuracy: 96.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.02158s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 49/ 50] FashionMNIST Time 0.02398s Training Accuracy: 96.68% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.02082s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 50/ 50] FashionMNIST Time 0.02131s Training Accuracy: 96.39% Test Accuracy: 71.88%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
[FINAL] FashionMNIST Training Accuracy: 96.39% Test Accuracy: 71.88%
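In train, each epoch passes over MNIST and then FashionMNIST before moving on, so a single set of weight_generator parameters and a single Adam state are updated alternately by both tasks. This follows from Julia's combined for-loop syntax, in which the rightmost iterator varies fastest. A tiny check with a hypothetical three-epoch run:

```julia
# Same loop header shape as in train(), with a small stand-in for nepochs = 50
nepochs = 3
order = Tuple{Int, Int}[]
for epoch in 1:nepochs, data_idx in 1:2
    push!(order, (epoch, data_idx))
end
order
```

With nepochs = 50 the real loop makes 100 passes, alternating between the two datasets. The unusually long first MNIST epoch in the log (about 90 s, versus roughly 0.02 s afterwards) most likely reflects Julia compiling the training step on first use rather than the cost of training itself.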
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.232 GiB / 4.750 GiB available)
This page was generated using Literate.jl.