Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
Reactant
Precompiling ComponentArrays...
997.3 ms ✓ ComponentArrays
1 dependency successfully precompiled in 1 seconds. 45 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
616.2 ms ✓ MLDataDevices → MLDataDevicesComponentArraysExt
1 dependency successfully precompiled in 1 seconds. 48 already precompiled.
Precompiling LuxComponentArraysExt...
527.5 ms ✓ ComponentArrays → ComponentArraysOptimisersExt
1479.3 ms ✓ Lux → LuxComponentArraysExt
2313.9 ms ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
3 dependencies successfully precompiled in 3 seconds. 112 already precompiled.
Precompiling MLDatasets...
387.3 ms ✓ Glob
423.1 ms ✓ WorkerUtilities
458.9 ms ✓ BufferedStreams
369.6 ms ✓ SimpleBufferStream
638.4 ms ✓ InlineStrings
621.4 ms ✓ URIs
480.7 ms ✓ CodecZlib
374.4 ms ✓ InvertedIndices
337.4 ms ✓ PackageExtensionCompat
380.1 ms ✓ BitFlags
859.8 ms ✓ GZip
743.2 ms ✓ ConcurrentUtilities
1215.9 ms ✓ Crayons
502.8 ms ✓ PooledArrays
776.1 ms ✓ ZipFile
894.4 ms ✓ StructTypes
1341.0 ms ✓ SentinelArrays
1057.3 ms ✓ MbedTLS
548.0 ms ✓ LoggingExtras
593.9 ms ✓ MPIPreferences
366.9 ms ✓ InternedStrings
697.1 ms ✓ ExceptionUnwrapping
2420.0 ms ✓ PeriodicTable
1155.1 ms ✓ Chemfiles_jll
1604.3 ms ✓ BFloat16s
3842.2 ms ✓ UnitfulAtomic
583.8 ms ✓ MicrosoftMPI_jll
902.8 ms ✓ libaec_jll
566.6 ms ✓ StringEncodings
515.8 ms ✓ InlineStrings → ParsersExt
447.3 ms ✓ StridedViews
2101.0 ms ✓ StringManipulation
2201.0 ms ✓ ImageShow
1522.4 ms ✓ NPZ
1934.6 ms ✓ OpenSSL
1159.4 ms ✓ OpenMPI_jll
1449.6 ms ✓ MPICH_jll
1151.0 ms ✓ MPItrampoline_jll
775.3 ms ✓ WeakRefStrings
2212.3 ms ✓ AtomsBase
2386.8 ms ✓ Pickle
10657.4 ms ✓ JSON3
1865.7 ms ✓ HDF5_jll
20312.5 ms ✓ PrettyTables
19675.2 ms ✓ HTTP
2413.8 ms ✓ Chemfiles
17362.1 ms ✓ CSV
3202.2 ms ✓ DataDeps
7643.8 ms ✓ HDF5
1955.8 ms ✓ FileIO → HTTPExt
2463.7 ms ✓ MAT
45238.6 ms ✓ DataFrames
1425.2 ms ✓ Transducers → TransducersDataFramesExt
1610.9 ms ✓ BangBang → BangBangDataFramesExt
9289.0 ms ✓ MLDatasets
55 dependencies successfully precompiled in 95 seconds. 145 already precompiled.
Precompiling OneHotArrays...
988.3 ms ✓ OneHotArrays
1 dependency successfully precompiled in 1 seconds. 28 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
748.7 ms ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
1 dependency successfully precompiled in 1 seconds. 35 already precompiled.
Precompiling ComponentArraysReactantExt...
12413.7 ms ✓ ComponentArrays → ComponentArraysReactantExt
1 dependency successfully precompiled in 13 seconds. 96 already precompiled.
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` tuple of `DataLoader`s for the dataset type `dset`.

`n_train` and `n_eval` limit how many samples are read from the `:train` and `:test`
splits; passing `nothing` reads the whole split. Features are reshaped to
`28 × 28 × 1 × N` and targets are one-hot encoded over the labels `0:9`. The effective
batch size is capped at the number of samples in the split. Only the training loader is
created with `shuffle = true`; both loaders use `partial = false`.
"""
function load_dataset(
        ::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int
    ) where {dset}
    # Read the first `limit` samples of `split`, or every sample when `limit` is `nothing`.
    function read_split(split::Symbol, limit)
        data = dset(split)
        last_idx = limit === nothing ? length(data) : limit
        return data[1:last_idx]
    end

    # Convert a batch of raw samples into (image array, one-hot label matrix).
    function as_arrays(samples)
        images = reshape(samples.features, 28, 28, 1, :)
        labels = onehotbatch(samples.targets, 0:9)
        return images, labels
    end

    x_train, y_train = as_arrays(read_split(:train, n_train))
    x_test, y_test = as_arrays(read_split(:test, n_eval))

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize = min(batchsize, size(x_train, 4)), shuffle = true, partial = false
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize = min(batchsize, size(x_test, 4)), shuffle = false, partial = false
    )
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize = 32)

Load MNIST and FashionMNIST and return one `(train_loader, test_loader)` tuple per
dataset, in that order.

When the `CI` environment variable parses to `true`, only 1024 training samples and 32
evaluation samples are used; otherwise the complete splits are loaded.
"""
function load_datasets(batchsize = 32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    datasets = (MNIST, FashionMNIST)
    # Broadcast over the dataset tuple; the remaining arguments act as scalars.
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a layer in which `weight_generator` produces the parameters of `core_network`.

The layer is called with a tuple `(x, y)`: `x` is fed to `weight_generator`, its output
is flattened and wrapped in a `ComponentArray` that uses the parameter axes of
`core_network`, and `core_network` is then applied to `y` with those parameters.
"""
function HyperNet(
        weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer
    )
    # Only the axes (layout) of the core network's parameters are kept; they give
    # structure to the flat vector emitted by the weight generator.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch = :HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. In this case, we overload Lux.initialparameters so that only the weight_generator parameters are initialized;
the core_network parameters are skipped because they are produced by the weight generator in the forward pass.
julia
# Parameter initialization for the HyperNet layer: only the weight generator receives
# parameters here; nothing is initialized for `core_network`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator = generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the HyperNet used in this tutorial.

The core network is a three-layer strided convolutional classifier with 10 outputs. The
weight generator embeds one of two indices and maps the embedding to a vector whose
length equals the number of parameters of the core network.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride = 2),
        Conv((3, 3), 16 => 32, relu; stride = 2),
        Conv((3, 3), 32 => 64, relu; stride = 2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10)
    )

    # The generator must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its result is taken as the prediction. Targets and predictions are
moved with `cpu_device()` and decoded with `onecold` before being compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (inputs, labels) in dataloader
        expected = onecold(to_cpu(labels))
        output = first(model((data_idx, inputs), ps, eval_st))
        predicted = onecold(to_cpu(output))
        n_correct += count(expected .== predicted)
        n_seen += length(expected)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
# Name shown in the log output for dataset index `data_idx` (1 => MNIST, else FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy of `model` on `dataloader` as a percentage rounded to two decimals, evaluated
# with the parameters and states currently held by `train_state`.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits = 2)
end

"""
    train(; nepochs = 50)

Train the HyperNet on MNIST and FashionMNIST and return the final test accuracies.

Each epoch makes one pass over the MNIST training loader followed by one pass over the
FashionMNIST training loader, and prints the elapsed time together with the training and
test accuracy after every pass. After the last epoch both datasets are evaluated once
more and the results are printed with a `[FINAL]` prefix.

# Keywords
- `nepochs`: number of epochs to train for (default `50`).

# Returns
A two-element vector holding the final test accuracy in percent for MNIST and
FashionMNIST, in that order.
"""
function train(; nepochs::Integer = 50)
    dev = reactant_device(; force = true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference call once, from a sample batch and a sample dataset index.
    # These get their own names so they are not confused with the loop variables below.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), CrossEntropyLoss(; logits = Val(true)),
                ((concrete_data_idx, x), y), train_state; return_gradients = Val(false)
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the training loop and keep the value returned by `train()` (the final test
# accuracies) for later use.
test_acc_list = train()
2025-03-11 22:39:32.586648: I external/xla/xla/service/service.cc:152] XLA service 0xc7465e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-11 22:39:32.586698: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741732772.587560 1072320 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741732772.587662 1072320 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741732772.587723 1072320 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741732772.602990 1072320 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1741732846.958816 1072320 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741732846.959998 1072320 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741732846.960007 1072320 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741732846.960014 1072320 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741732846.960020 1072320 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741732846.960026 1072320 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741732846.960033 1072320 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741732846.960039 1072320 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741732846.960045 1072320 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741732846.960052 1072320 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-11 22:40:46.960067: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732846.962802 1072320 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741732846.962838 1072320 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741732846.962845 1072320 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741732846.962851 1072320 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741732846.962858 1072320 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741732846.962864 1072320 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741732846.962870 1072320 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741732846.962877 1072320 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741732846.962883 1072320 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741732846.962889 1072320 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-11 22:40:46.962899: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.324143 1072320 buffer_comparator.cc:156] Difference at 32: -nan, expected 0.585358
E0000 00:00:1741732880.324206 1072320 buffer_comparator.cc:156] Difference at 33: -nan, expected 0.881101
E0000 00:00:1741732880.324213 1072320 buffer_comparator.cc:156] Difference at 34: -nan, expected 0.738887
E0000 00:00:1741732880.324220 1072320 buffer_comparator.cc:156] Difference at 35: -nan, expected 0.603294
E0000 00:00:1741732880.324226 1072320 buffer_comparator.cc:156] Difference at 36: -nan, expected 1.04006
E0000 00:00:1741732880.324233 1072320 buffer_comparator.cc:156] Difference at 37: -nan, expected 0.740676
E0000 00:00:1741732880.324239 1072320 buffer_comparator.cc:156] Difference at 38: -nan, expected 0.766031
E0000 00:00:1741732880.324245 1072320 buffer_comparator.cc:156] Difference at 39: -nan, expected 0.814752
E0000 00:00:1741732880.324265 1072320 buffer_comparator.cc:156] Difference at 40: -nan, expected 0.431557
E0000 00:00:1741732880.324272 1072320 buffer_comparator.cc:156] Difference at 41: -nan, expected 0.672244
2025-03-11 22:41:20.324286: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.328628 1072320 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1741732880.328652 1072320 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1741732880.328659 1072320 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1741732880.328665 1072320 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1741732880.328671 1072320 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1741732880.328678 1072320 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1741732880.328684 1072320 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1741732880.328690 1072320 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1741732880.328696 1072320 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1741732880.328702 1072320 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-03-11 22:41:20.328711: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.331081 1072320 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1741732880.331092 1072320 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1741732880.331095 1072320 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1741732880.331098 1072320 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1741732880.331100 1072320 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1741732880.331103 1072320 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1741732880.331106 1072320 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1741732880.331108 1072320 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1741732880.331111 1072320 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1741732880.331114 1072320 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-03-11 22:41:20.331118: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.333245 1072320 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1741732880.333256 1072320 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1741732880.333259 1072320 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1741732880.333261 1072320 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1741732880.333264 1072320 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1741732880.333267 1072320 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1741732880.333270 1072320 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1741732880.333272 1072320 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1741732880.333275 1072320 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1741732880.333278 1072320 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-03-11 22:41:20.333282: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.335413 1072320 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1741732880.335424 1072320 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1741732880.335426 1072320 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1741732880.335429 1072320 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1741732880.335432 1072320 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1741732880.335435 1072320 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1741732880.335437 1072320 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1741732880.335440 1072320 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1741732880.335443 1072320 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1741732880.335445 1072320 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-03-11 22:41:20.335450: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[ 1/ 50] MNIST Time 35.91804s Training Accuracy: 35.06% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.03871s Training Accuracy: 34.18% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.04073s Training Accuracy: 34.96% Test Accuracy: 37.50%
[ 2/ 50] FashionMNIST Time 0.07132s Training Accuracy: 51.17% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.03580s Training Accuracy: 36.52% Test Accuracy: 28.12%
[ 3/ 50] FashionMNIST Time 0.03345s Training Accuracy: 57.23% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.04741s Training Accuracy: 45.41% Test Accuracy: 34.38%
[ 4/ 50] FashionMNIST Time 0.03490s Training Accuracy: 64.45% Test Accuracy: 56.25%
[ 5/ 50] MNIST Time 0.03310s Training Accuracy: 50.78% Test Accuracy: 46.88%
[ 5/ 50] FashionMNIST Time 0.03279s Training Accuracy: 70.02% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.03607s Training Accuracy: 56.74% Test Accuracy: 40.62%
[ 6/ 50] FashionMNIST Time 0.02994s Training Accuracy: 74.02% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.03922s Training Accuracy: 62.40% Test Accuracy: 40.62%
[ 7/ 50] FashionMNIST Time 0.02892s Training Accuracy: 78.81% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.03913s Training Accuracy: 71.00% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.02928s Training Accuracy: 83.69% Test Accuracy: 62.50%
[ 9/ 50] MNIST Time 0.04293s Training Accuracy: 74.61% Test Accuracy: 43.75%
[ 9/ 50] FashionMNIST Time 0.03059s Training Accuracy: 83.11% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.03220s Training Accuracy: 78.61% Test Accuracy: 46.88%
[ 10/ 50] FashionMNIST Time 0.03007s Training Accuracy: 88.09% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.03057s Training Accuracy: 84.08% Test Accuracy: 37.50%
[ 11/ 50] FashionMNIST Time 0.04358s Training Accuracy: 90.33% Test Accuracy: 59.38%
[ 12/ 50] MNIST Time 0.03569s Training Accuracy: 87.11% Test Accuracy: 43.75%
[ 12/ 50] FashionMNIST Time 0.03765s Training Accuracy: 91.80% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.03755s Training Accuracy: 90.23% Test Accuracy: 46.88%
[ 13/ 50] FashionMNIST Time 0.03633s Training Accuracy: 94.63% Test Accuracy: 62.50%
[ 14/ 50] MNIST Time 0.03024s Training Accuracy: 91.31% Test Accuracy: 56.25%
[ 14/ 50] FashionMNIST Time 0.04041s Training Accuracy: 94.82% Test Accuracy: 59.38%
[ 15/ 50] MNIST Time 0.03308s Training Accuracy: 95.12% Test Accuracy: 53.12%
[ 15/ 50] FashionMNIST Time 0.02890s Training Accuracy: 95.90% Test Accuracy: 62.50%
[ 16/ 50] MNIST Time 0.02936s Training Accuracy: 96.68% Test Accuracy: 50.00%
[ 16/ 50] FashionMNIST Time 0.02841s Training Accuracy: 95.90% Test Accuracy: 59.38%
[ 17/ 50] MNIST Time 0.03651s Training Accuracy: 97.17% Test Accuracy: 56.25%
[ 17/ 50] FashionMNIST Time 0.03268s Training Accuracy: 97.85% Test Accuracy: 62.50%
[ 18/ 50] MNIST Time 0.03884s Training Accuracy: 98.14% Test Accuracy: 53.12%
[ 18/ 50] FashionMNIST Time 0.02806s Training Accuracy: 97.66% Test Accuracy: 62.50%
[ 19/ 50] MNIST Time 0.03930s Training Accuracy: 99.22% Test Accuracy: 56.25%
[ 19/ 50] FashionMNIST Time 0.02913s Training Accuracy: 99.32% Test Accuracy: 59.38%
[ 20/ 50] MNIST Time 0.03430s Training Accuracy: 99.71% Test Accuracy: 53.12%
[ 20/ 50] FashionMNIST Time 0.03048s Training Accuracy: 99.32% Test Accuracy: 59.38%
[ 21/ 50] MNIST Time 0.03063s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 21/ 50] FashionMNIST Time 0.02979s Training Accuracy: 99.51% Test Accuracy: 59.38%
[ 22/ 50] MNIST Time 0.03091s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 22/ 50] FashionMNIST Time 0.02789s Training Accuracy: 99.51% Test Accuracy: 59.38%
[ 23/ 50] MNIST Time 0.03142s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 23/ 50] FashionMNIST Time 0.03848s Training Accuracy: 99.61% Test Accuracy: 62.50%
[ 24/ 50] MNIST Time 0.03005s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 24/ 50] FashionMNIST Time 0.03746s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 25/ 50] MNIST Time 0.03042s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 25/ 50] FashionMNIST Time 0.03885s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] MNIST Time 0.03081s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 26/ 50] FashionMNIST Time 0.03063s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] MNIST Time 0.03073s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 27/ 50] FashionMNIST Time 0.02767s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] MNIST Time 0.02941s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 28/ 50] FashionMNIST Time 0.02802s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 29/ 50] MNIST Time 0.03935s Training Accuracy: 100.00% Test Accuracy: 53.12%
[ 29/ 50] FashionMNIST Time 0.02998s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] MNIST Time 0.03779s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 30/ 50] FashionMNIST Time 0.02995s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] MNIST Time 0.03823s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 31/ 50] FashionMNIST Time 0.03050s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] MNIST Time 0.03040s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 32/ 50] FashionMNIST Time 0.03044s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] MNIST Time 0.03360s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 33/ 50] FashionMNIST Time 0.02847s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] MNIST Time 0.02868s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 34/ 50] FashionMNIST Time 0.03713s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] MNIST Time 0.02848s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 35/ 50] FashionMNIST Time 0.04032s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] MNIST Time 0.02940s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 36/ 50] FashionMNIST Time 0.03713s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] MNIST Time 0.02845s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 37/ 50] FashionMNIST Time 0.03695s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] MNIST Time 0.02979s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 38/ 50] FashionMNIST Time 0.02973s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] MNIST Time 0.02844s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 39/ 50] FashionMNIST Time 0.02848s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] MNIST Time 0.02931s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 40/ 50] FashionMNIST Time 0.02836s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] MNIST Time 0.03853s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 41/ 50] FashionMNIST Time 0.02855s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] MNIST Time 0.03617s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 42/ 50] FashionMNIST Time 0.02867s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] MNIST Time 0.03773s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 43/ 50] FashionMNIST Time 0.02900s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] MNIST Time 0.02822s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 44/ 50] FashionMNIST Time 0.02832s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] MNIST Time 0.02983s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 45/ 50] FashionMNIST Time 0.02831s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] MNIST Time 0.02834s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 46/ 50] FashionMNIST Time 0.03677s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] MNIST Time 0.02830s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 47/ 50] FashionMNIST Time 0.03632s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] MNIST Time 0.02826s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 48/ 50] FashionMNIST Time 0.03593s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] MNIST Time 0.02867s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 49/ 50] FashionMNIST Time 0.02949s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] MNIST Time 0.02858s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 50/ 50] FashionMNIST Time 0.02897s Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 56.25%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
Appendix
julia
using InteractiveUtils
# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices is defined, the backend package
# is defined, and MLDataDevices reports the backend as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
        MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
        MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.