Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf,
    Random, Reactant
Precompiling LuxComponentArraysExt...
668.1 ms ✓ ComponentArrays → ComponentArraysOptimisersExt
1396.1 ms ✓ Lux → LuxComponentArraysExt
2 dependencies successfully precompiled in 2 seconds. 113 already precompiled.
Precompiling LuxMLUtilsExt...
2046.6 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 169 already precompiled.
Loading Datasets
julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    # On CI we only use a small subset of the data so the tutorial runs quickly
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
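Each element of the returned tuple is a (train, test) pair of DataLoaders whose batches are 28×28×1×B image arrays paired with one-hot encoded labels. As a quick sanity check (a small sketch; the variable names here are ours):
julia
mnist_loaders, fmnist_loaders = load_datasets(32)
x, y = first(mnist_loaders[1])  # first training batch from MNIST
size(x)                         # (28, 28, 1, 32)
size(y)                         # (10, 32)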
Implement a HyperNet Layer
julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Record the axes of the core network's parameters so we can rebuild a
    # structured ComponentArray from the flat vector the generator emits
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the core network's weights from the task input `x`
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend it unless you are familiar with the internals. In this case, we simply override parameter initialization to skip the core_network parameters, since those are produced by the weight_generator at call time.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
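With this override in place, Lux.setup allocates parameters only for the weight generator. A minimal sketch (the tiny core/gen layers here are ours, sized so the generator's output length matches the core network's parameter count):
julia
core = Dense(2 => 3)                          # 9 parameters (6 weights + 3 biases)
gen = Dense(1 => Lux.parameterlength(core))   # emits one flat parameter vector
hn = HyperNet(gen, core)
ps, st = Lux.setup(Random.default_rng(), hn)
keys(ps)  # (:weight_generator,) — the core network's parameters are generated per call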
Create and Initialize the HyperNet
julia
function create_model()
    # The core network that performs the classification
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    # The weight generator maps a task index (1 = MNIST, 2 = FashionMNIST) to a flat
    # vector holding all of the core network's parameters
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
create_model (generic function with 1 method)
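Before compiling anything with Reactant, the model can be exercised eagerly on the CPU (a quick sketch; the dummy batch is ours). The first element of the input tuple is the task index fed to the Embedding; the second is the image batch:
julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 28, 28, 1, 4)   # dummy batch of four 28×28 grayscale images
logits, _ = model((1, x), ps, st)  # task index 1 selects the MNIST embedding
size(logits)                       # (10, 4)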
Define Utility Functions
julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
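For example (a sketch reusing the eager model, ps, and st from the check above), an untrained network should score near the 10% chance level:
julia
_, mnist_test = load_datasets()[1]      # the MNIST (train, test) loader pair
accuracy(model, ps, st, mnist_test, 1)  # ≈ 0.1 for random weights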
Training
julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index; otherwise it would be embedded as a
        ### constant in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
test_acc_list = train()
2025-03-28 04:30:08.099642: I external/xla/xla/service/service.cc:152] XLA service 0x6cb84e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-28 04:30:08.099676: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1743136208.100529 3381069 se_gpu_pjrt_client.cc:1039] Using BFC allocator.
I0000 00:00:1743136208.100607 3381069 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1743136208.100650 3381069 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1743136208.112216 3381069 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1743136265.603630 3381069 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
2025-03-28 04:31:05.604668: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[... repeated buffer_comparator.cc / gemm_fusion_autotuner.cc precision-mismatch warnings truncated ...]
[ 1/ 50] MNIST Time 37.75709s Training Accuracy: 35.06% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04267s Training Accuracy: 34.38% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.07874s Training Accuracy: 34.47% Test Accuracy: 40.62%
[ 2/ 50] FashionMNIST Time 0.03368s Training Accuracy: 46.48% Test Accuracy: 53.12%
[ 3/ 50] MNIST Time 0.03505s Training Accuracy: 40.82% Test Accuracy: 31.25%
[ 3/ 50] FashionMNIST Time 0.03268s Training Accuracy: 55.18% Test Accuracy: 59.38%
[ 4/ 50] MNIST Time 0.02995s Training Accuracy: 48.24% Test Accuracy: 40.62%
[ 4/ 50] FashionMNIST Time 0.02902s Training Accuracy: 63.48% Test Accuracy: 53.12%
[ 5/ 50] MNIST Time 0.02958s Training Accuracy: 54.88% Test Accuracy: 43.75%
[ 5/ 50] FashionMNIST Time 0.02862s Training Accuracy: 69.53% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.02941s Training Accuracy: 60.35% Test Accuracy: 43.75%
[ 6/ 50] FashionMNIST Time 0.03849s Training Accuracy: 72.56% Test Accuracy: 62.50%
[ 7/ 50] MNIST Time 0.02909s Training Accuracy: 66.60% Test Accuracy: 46.88%
[ 7/ 50] FashionMNIST Time 0.03742s Training Accuracy: 76.56% Test Accuracy: 65.62%
[ 8/ 50] MNIST Time 0.02866s Training Accuracy: 73.34% Test Accuracy: 46.88%
[ 8/ 50] FashionMNIST Time 0.03892s Training Accuracy: 81.84% Test Accuracy: 68.75%
[ 9/ 50] MNIST Time 0.02802s Training Accuracy: 77.64% Test Accuracy: 46.88%
[ 9/ 50] FashionMNIST Time 0.03922s Training Accuracy: 85.16% Test Accuracy: 68.75%
[ 10/ 50] MNIST Time 0.03796s Training Accuracy: 80.96% Test Accuracy: 53.12%
[ 10/ 50] FashionMNIST Time 0.03270s Training Accuracy: 86.72% Test Accuracy: 65.62%
[ 11/ 50] MNIST Time 0.03192s Training Accuracy: 84.08% Test Accuracy: 56.25%
[ 11/ 50] FashionMNIST Time 0.05059s Training Accuracy: 90.82% Test Accuracy: 71.88%
[ 12/ 50] MNIST Time 0.02909s Training Accuracy: 86.72% Test Accuracy: 43.75%
[ 12/ 50] FashionMNIST Time 0.03840s Training Accuracy: 92.68% Test Accuracy: 75.00%
[ 13/ 50] MNIST Time 0.02946s Training Accuracy: 90.04% Test Accuracy: 50.00%
[ 13/ 50] FashionMNIST Time 0.05467s Training Accuracy: 93.36% Test Accuracy: 65.62%
[ 14/ 50] MNIST Time 0.03002s Training Accuracy: 93.36% Test Accuracy: 50.00%
[ 14/ 50] FashionMNIST Time 0.03657s Training Accuracy: 95.12% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.02908s Training Accuracy: 94.14% Test Accuracy: 50.00%
[ 15/ 50] FashionMNIST Time 0.02804s Training Accuracy: 96.48% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03602s Training Accuracy: 96.68% Test Accuracy: 56.25%
[ 16/ 50] FashionMNIST Time 0.02802s Training Accuracy: 97.07% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03713s Training Accuracy: 98.34% Test Accuracy: 53.12%
[ 17/ 50] FashionMNIST Time 0.02842s Training Accuracy: 97.95% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.04627s Training Accuracy: 99.32% Test Accuracy: 53.12%
[ 18/ 50] FashionMNIST Time 0.02771s Training Accuracy: 97.07% Test Accuracy: 71.88%
[ 19/ 50] MNIST Time 0.03132s Training Accuracy: 99.32% Test Accuracy: 53.12%
[ 19/ 50] FashionMNIST Time 0.03583s Training Accuracy: 98.83% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.02813s Training Accuracy: 99.90% Test Accuracy: 56.25%
[ 20/ 50] FashionMNIST Time 0.03146s Training Accuracy: 99.41% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.03432s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 21/ 50] FashionMNIST Time 0.02969s Training Accuracy: 99.51% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.03973s Training Accuracy: 99.90% Test Accuracy: 53.12%
[ 22/ 50] FashionMNIST Time 0.02803s Training Accuracy: 99.71% Test Accuracy: 68.75%
[ 23/ 50] MNIST Time 0.03466s Training Accuracy: 99.90% Test Accuracy: 56.25%
[ 23/ 50] FashionMNIST Time 0.03239s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 24/ 50] MNIST Time 0.02799s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 24/ 50] FashionMNIST Time 0.04212s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] MNIST Time 0.02948s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 25/ 50] FashionMNIST Time 0.03995s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 26/ 50] MNIST Time 0.02774s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 26/ 50] FashionMNIST Time 0.03664s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] MNIST Time 0.03102s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 27/ 50] FashionMNIST Time 0.02823s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 28/ 50] MNIST Time 0.04109s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 28/ 50] FashionMNIST Time 0.02803s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 29/ 50] MNIST Time 0.03562s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 29/ 50] FashionMNIST Time 0.03202s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 30/ 50] MNIST Time 0.02864s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 30/ 50] FashionMNIST Time 0.02958s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 31/ 50] MNIST Time 0.02794s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 31/ 50] FashionMNIST Time 0.04448s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 32/ 50] MNIST Time 0.02832s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 32/ 50] FashionMNIST Time 0.03999s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 33/ 50] MNIST Time 0.02891s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 33/ 50] FashionMNIST Time 0.02975s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 34/ 50] MNIST Time 0.02961s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 34/ 50] FashionMNIST Time 0.02866s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 35/ 50] MNIST Time 0.03529s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 35/ 50] FashionMNIST Time 0.02839s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 36/ 50] MNIST Time 0.03640s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 36/ 50] FashionMNIST Time 0.03029s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 37/ 50] MNIST Time 0.02978s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 37/ 50] FashionMNIST Time 0.03685s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 38/ 50] MNIST Time 0.02917s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 38/ 50] FashionMNIST Time 0.03857s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 39/ 50] MNIST Time 0.02921s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 39/ 50] FashionMNIST Time 0.02788s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 40/ 50] MNIST Time 0.02939s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 40/ 50] FashionMNIST Time 0.02925s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 41/ 50] MNIST Time 0.03749s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 41/ 50] FashionMNIST Time 0.02817s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 42/ 50] MNIST Time 0.03697s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 42/ 50] FashionMNIST Time 0.02842s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 43/ 50] MNIST Time 0.03001s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 43/ 50] FashionMNIST Time 0.04860s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] MNIST Time 0.02960s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 44/ 50] FashionMNIST Time 0.03534s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 45/ 50] MNIST Time 0.02829s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 45/ 50] FashionMNIST Time 0.02779s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 46/ 50] MNIST Time 0.02769s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 46/ 50] FashionMNIST Time 0.02857s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 47/ 50] MNIST Time 0.03701s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 47/ 50] FashionMNIST Time 0.02867s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 48/ 50] MNIST Time 0.03833s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 48/ 50] FashionMNIST Time 0.03150s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 49/ 50] MNIST Time 0.03445s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 49/ 50] FashionMNIST Time 0.03588s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 50/ 50] MNIST Time 0.02951s Training Accuracy: 100.00% Test Accuracy: 56.25%
[ 50/ 50] FashionMNIST Time 0.03822s Training Accuracy: 100.00% Test Accuracy: 65.62%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 56.25%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 65.62%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.