Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Zygote
CUDA.allowscalar(false)
Loading Datasets
julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    # Use the full split when no subset size is requested
    if n_train === nothing
        imgs, labels = dset(:train)[:]
    else
        imgs, labels = dset(:train)[1:n_train]
    end
    # `:` lets reshape infer the sample count, so this also works when n_train is nothing
    x_train, y_train = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)
    if n_eval === nothing
        imgs, labels = dset(:test)[:]
    else
        imgs, labels = dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)
    return (
        DataLoader((x_train, y_train); batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader((x_test, y_test); batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end
function load_datasets(batchsize=256)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
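As a rough illustration of what `onehotbatch(labels, 0:9)` produces in the loader above, the following plain-Julia sketch builds the same kind of class-by-sample Boolean matrix. The helper name `onehot_columns` is hypothetical and is not part of OneHotArrays; the real package returns a dedicated one-hot array type rather than a `BitMatrix`.

```julia
# Plain-Julia stand-in for `onehotbatch(labels, classes)`: one column per label,
# with `true` in the row of the matching class. Hypothetical helper, for illustration only.
function onehot_columns(labels::AbstractVector{<:Integer}, classes)
    out = falses(length(classes), length(labels))
    for (j, label) in enumerate(labels)
        i = findfirst(==(label), classes)
        i === nothing && throw(ArgumentError("label $label is not in $classes"))
        out[i, j] = true
    end
    return out
end

labels = [3, 0, 9]
encoded = onehot_columns(labels, 0:9)
println(size(encoded))  # one row per class, one column per sample
```

Because the classes start at 0, the label `3` lands in row 4 of its column.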
Implement a HyperNet Layer
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
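The key step in the layer above is reinterpreting the generator's flat output as structured parameters for the core network. The sketch below mimics that step in plain Julia for a single dense layer, without Lux or ComponentArrays. The names `unflatten_dense`, `table`, and `core_forward` are assumptions made for this illustration, and the lookup table merely stands in for a trained weight generator.

```julia
# Hypothetical helper: split a flat vector into the weight and bias of one dense layer.
function unflatten_dense(flat::AbstractVector, in_dims::Int, out_dims::Int)
    n_weight = in_dims * out_dims
    length(flat) == n_weight + out_dims ||
        throw(DimensionMismatch("expected $(n_weight + out_dims) values, got $(length(flat))"))
    weight = reshape(view(flat, 1:n_weight), out_dims, in_dims)
    bias = view(flat, (n_weight + 1):length(flat))
    return (; weight, bias)
end

# Stand-in "weight generator": one flat parameter vector (column) per dataset index.
table = hcat(collect(1.0:6.0), [0.5, 0.0, 0.0, 0.5, 1.0, -1.0])

# The core network has no parameters of its own; it uses whatever it is handed.
core_forward(p, y) = p.weight * y .+ p.bias

data_idx = 2
ps_new = unflatten_dense(view(table, :, data_idx), 2, 2)
out = core_forward(ps_new, [2.0, 4.0])
println(out)
```

Switching `data_idx` changes the parameters the core network runs with, which is the behavior the `(x, y)` input pair selects in the layer above.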
Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here we overload Lux.initialparameters so that only the weight_generator parameters are initialized. The core_network parameters are skipped because the weight generator produces them on every forward pass.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
Create and Initialize the HyperNet
julia
function create_model()
    # The core network doesn't need to be an MLP; any Lux layer works
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )
    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
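To get a feel for the sizes involved, the arithmetic below estimates how many values the weight generator must emit and how many trainable parameters it holds itself. This is a back-of-the-envelope sketch: it assumes each `Dense` layer stores a weight matrix plus a bias vector, that `Embedding` stores one vector per entry, and that `FlattenLayer` has no parameters. The helper names are hypothetical.

```julia
# Parameter counts under the assumptions stated above (hypothetical helpers).
dense_params(in_dims, out_dims) = in_dims * out_dims + out_dims  # weight + bias
embedding_params(n_entries, dims) = n_entries * dims

# Core network: Dense(784, 256) followed by Dense(256, 10)
core_len = dense_params(784, 256) + dense_params(256, 10)

# Weight generator: Embedding(2 => 32), Dense(32, 64), Dense(64, core_len)
generator_len = embedding_params(2, 32) + dense_params(32, 64) + dense_params(64, core_len)

println("core network parameters: ", core_len)           # 203530
println("weight generator parameters: ", generator_len)  # 13231626
```

Almost all of the trainable parameters sit in the generator's final `Dense` layer, since its output width equals the full parameter length of the core network.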
Define Utility Functions
julia
const loss = CrossEntropyLoss(; logits=Val(true))
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
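The accuracy computation above compares the index of the largest logit in each column with the index of the `true` entry in the matching one-hot column. A minimal plain-Julia sketch of that comparison for a single batch follows. `column_argmax` and `batch_accuracy` are hypothetical names standing in for `onecold` and the loop body, and the numbers are made up for illustration.

```julia
# Index of the largest entry in each column, a plain-Julia stand-in for `onecold`.
column_argmax(m::AbstractMatrix) = [argmax(col) for col in eachcol(m)]

# Fraction of columns where the predicted class matches the target class.
function batch_accuracy(logits::AbstractMatrix, targets::AbstractMatrix)
    predicted = column_argmax(logits)
    expected = column_argmax(targets)
    return count(predicted .== expected) / length(expected)
end

# Three classes (rows), three samples (columns); illustrative values only.
logits = [2.0 0.1 0.3;
          0.5 1.5 0.9;
          0.1 0.2 0.4]
targets = Bool[1 0 0;
               0 1 0;
               0 0 1]

println(batch_accuracy(logits, targets))
```

Here the first two samples are classified correctly and the third is not, so two of three predictions match.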
Training
julia
function train()
    model = create_model()
    dataloaders = load_datasets()
    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    # Let's train the model, alternating between the two datasets in every epoch
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end
    println()
    # Report the final accuracy on both datasets
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
[ 1/ 50] MNIST Time 88.27585s Training Accuracy: 58.01% Test Accuracy: 46.88%
[ 1/ 50] FashionMNIST Time 0.02446s Training Accuracy: 52.15% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.02460s Training Accuracy: 66.99% Test Accuracy: 68.75%
[ 2/ 50] FashionMNIST Time 0.02514s Training Accuracy: 64.94% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.02564s Training Accuracy: 74.51% Test Accuracy: 62.50%
[ 3/ 50] FashionMNIST Time 0.02598s Training Accuracy: 64.75% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.05349s Training Accuracy: 74.41% Test Accuracy: 59.38%
[ 4/ 50] FashionMNIST Time 0.02134s Training Accuracy: 63.38% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.02132s Training Accuracy: 77.54% Test Accuracy: 62.50%
[ 5/ 50] FashionMNIST Time 0.02118s Training Accuracy: 67.87% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.02120s Training Accuracy: 84.47% Test Accuracy: 59.38%
[ 6/ 50] FashionMNIST Time 0.03953s Training Accuracy: 70.61% Test Accuracy: 65.62%
[ 7/ 50] MNIST Time 0.02204s Training Accuracy: 87.30% Test Accuracy: 75.00%
[ 7/ 50] FashionMNIST Time 0.02045s Training Accuracy: 66.41% Test Accuracy: 46.88%
[ 8/ 50] MNIST Time 0.02065s Training Accuracy: 91.11% Test Accuracy: 81.25%
[ 8/ 50] FashionMNIST Time 0.02127s Training Accuracy: 75.49% Test Accuracy: 53.12%
[ 9/ 50] MNIST Time 0.02198s Training Accuracy: 93.36% Test Accuracy: 84.38%
[ 9/ 50] FashionMNIST Time 0.02122s Training Accuracy: 75.88% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.02040s Training Accuracy: 94.82% Test Accuracy: 84.38%
[ 10/ 50] FashionMNIST Time 0.02252s Training Accuracy: 76.66% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.02274s Training Accuracy: 96.78% Test Accuracy: 84.38%
[ 11/ 50] FashionMNIST Time 0.02050s Training Accuracy: 75.20% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.02255s Training Accuracy: 97.85% Test Accuracy: 87.50%
[ 12/ 50] FashionMNIST Time 0.02147s Training Accuracy: 78.32% Test Accuracy: 59.38%
[ 13/ 50] MNIST Time 0.02286s Training Accuracy: 98.44% Test Accuracy: 87.50%
[ 13/ 50] FashionMNIST Time 0.02389s Training Accuracy: 80.37% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.02165s Training Accuracy: 98.34% Test Accuracy: 87.50%
[ 14/ 50] FashionMNIST Time 0.02102s Training Accuracy: 79.98% Test Accuracy: 59.38%
[ 15/ 50] MNIST Time 0.02389s Training Accuracy: 98.93% Test Accuracy: 84.38%
[ 15/ 50] FashionMNIST Time 0.02076s Training Accuracy: 80.86% Test Accuracy: 62.50%
[ 16/ 50] MNIST Time 0.02198s Training Accuracy: 99.22% Test Accuracy: 84.38%
[ 16/ 50] FashionMNIST Time 0.02283s Training Accuracy: 82.91% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02087s Training Accuracy: 99.41% Test Accuracy: 84.38%
[ 17/ 50] FashionMNIST Time 0.02075s Training Accuracy: 84.57% Test Accuracy: 59.38%
[ 18/ 50] MNIST Time 0.02213s Training Accuracy: 99.71% Test Accuracy: 84.38%
[ 18/ 50] FashionMNIST Time 0.02183s Training Accuracy: 85.25% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.02074s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 19/ 50] FashionMNIST Time 0.02277s Training Accuracy: 85.84% Test Accuracy: 65.62%
[ 20/ 50] MNIST Time 0.02216s Training Accuracy: 99.80% Test Accuracy: 84.38%
[ 20/ 50] FashionMNIST Time 0.02070s Training Accuracy: 86.91% Test Accuracy: 62.50%
[ 21/ 50] MNIST Time 0.02558s Training Accuracy: 99.90% Test Accuracy: 84.38%
[ 21/ 50] FashionMNIST Time 0.02074s Training Accuracy: 87.30% Test Accuracy: 62.50%
[ 22/ 50] MNIST Time 0.02022s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 22/ 50] FashionMNIST Time 0.02249s Training Accuracy: 87.99% Test Accuracy: 62.50%
[ 23/ 50] MNIST Time 0.02069s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 23/ 50] FashionMNIST Time 0.02081s Training Accuracy: 87.60% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.02344s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 24/ 50] FashionMNIST Time 0.02160s Training Accuracy: 88.67% Test Accuracy: 62.50%
[ 25/ 50] MNIST Time 0.02121s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 25/ 50] FashionMNIST Time 0.02252s Training Accuracy: 88.96% Test Accuracy: 65.62%
[ 26/ 50] MNIST Time 0.02087s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 26/ 50] FashionMNIST Time 0.02066s Training Accuracy: 88.96% Test Accuracy: 62.50%
[ 27/ 50] MNIST Time 0.02329s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 27/ 50] FashionMNIST Time 0.02078s Training Accuracy: 89.84% Test Accuracy: 65.62%
[ 28/ 50] MNIST Time 0.02276s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 28/ 50] FashionMNIST Time 0.02285s Training Accuracy: 89.75% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.02087s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 29/ 50] FashionMNIST Time 0.02064s Training Accuracy: 89.65% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.02719s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 30/ 50] FashionMNIST Time 0.02065s Training Accuracy: 88.87% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.02154s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 31/ 50] FashionMNIST Time 0.02290s Training Accuracy: 89.36% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.02069s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 32/ 50] FashionMNIST Time 0.02048s Training Accuracy: 89.65% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.02328s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 33/ 50] FashionMNIST Time 0.02044s Training Accuracy: 90.23% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.02069s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 34/ 50] FashionMNIST Time 0.02398s Training Accuracy: 90.72% Test Accuracy: 65.62%
[ 35/ 50] MNIST Time 0.02258s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 35/ 50] FashionMNIST Time 0.02103s Training Accuracy: 91.02% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.02304s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 36/ 50] FashionMNIST Time 0.02057s Training Accuracy: 90.43% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.02040s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 37/ 50] FashionMNIST Time 0.02203s Training Accuracy: 91.50% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.02213s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 38/ 50] FashionMNIST Time 0.02045s Training Accuracy: 91.41% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.02225s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 39/ 50] FashionMNIST Time 0.02042s Training Accuracy: 91.89% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.02254s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 40/ 50] FashionMNIST Time 0.02756s Training Accuracy: 92.19% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.02421s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 41/ 50] FashionMNIST Time 0.02099s Training Accuracy: 92.38% Test Accuracy: 65.62%
[ 42/ 50] MNIST Time 0.02322s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 42/ 50] FashionMNIST Time 0.02150s Training Accuracy: 93.26% Test Accuracy: 65.62%
[ 43/ 50] MNIST Time 0.02165s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 43/ 50] FashionMNIST Time 0.02323s Training Accuracy: 92.48% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.02272s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 44/ 50] FashionMNIST Time 0.02271s Training Accuracy: 93.07% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.02384s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 45/ 50] FashionMNIST Time 0.02160s Training Accuracy: 92.19% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.02293s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 46/ 50] FashionMNIST Time 0.02301s Training Accuracy: 88.77% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.02104s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 47/ 50] FashionMNIST Time 0.02125s Training Accuracy: 90.53% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.02407s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 48/ 50] FashionMNIST Time 0.02186s Training Accuracy: 90.14% Test Accuracy: 65.62%
[ 49/ 50] MNIST Time 0.02145s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 49/ 50] FashionMNIST Time 0.02530s Training Accuracy: 90.82% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.02194s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 50/ 50] FashionMNIST Time 0.02140s Training Accuracy: 92.09% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
[FINAL] FashionMNIST Training Accuracy: 92.09% Test Accuracy: 75.00%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 3.170 GiB / 4.750 GiB available)
This page was generated using Literate.jl.