Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
1521.3 ms ✓ Lux → LuxComponentArraysExt
1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling LuxMLUtilsExt...
2096.8 ms ✓ Lux → LuxMLUtilsExt
1 dependency successfully precompiled in 2 seconds. 164 already precompiled.
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build `(train_loader, test_loader)` for the dataset type `dset` (e.g. `MNIST`).

`n_train` / `n_eval` restrict the train / test splits to their first `n` samples;
`nothing` keeps the whole split. Images are reshaped to `28 × 28 × 1 × N` and labels are
one-hot encoded over `0:9`. The training loader shuffles and the test loader does not.
Both use `partial=false`, so a trailing incomplete batch is dropped.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Slice the first `n` samples (or every sample when `n === nothing`) out of `split`
    # and turn them into (image array, one-hot label matrix).
    function prepare_split(split, n)
        raw = dset(split)
        n_take = n === nothing ? length(raw) : n
        (; features, targets) = raw[1:n_take]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    x_train, y_train = prepare_split(:train, n_train)
    x_test, y_test = prepare_split(:test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize=32)

Return a tuple `(mnist_loaders, fashion_mnist_loaders)`, each being the
`(train_loader, test_loader)` pair produced by `load_dataset`.

When the environment variable `CI` parses as `true`, only the first 1024 training and
32 test samples of each dataset are used.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    # `nothing` means "use the full split".
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    return map((MNIST, FashionMNIST)) do dset
        load_dataset(dset, n_train, n_eval, batchsize)
    end
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a hypernetwork layer. Given an input `(x, y)`:
1. `weight_generator(x)` produces a flat vector.
2. That vector is reinterpreted with the parameter structure (axes) of `core_network`.
3. `core_network` is applied to `y` using those generated parameters.

The returned layer dispatches as `CompactLuxLayer{:HyperNet}`.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the *structure* (axes) of the core network's parameters is needed here; the
    # sampled values are discarded. A private RNG is used so that constructing the layer
    # does not draw from (and thereby perturb) the global default RNG.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.Xoshiro(0), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on a `CompactLuxLayer` requires some understanding of how the layer is structured internally; as such, we don't recommend doing it unless you are familiar with those internals. Here, we override `Lux.initialparameters` so that it skips initializing the `core_network` parameters, since those weights are produced by the weight generator at call time.
julia
# Only the weight generator owns trainable parameters. The core network's weights are
# produced by the generator on every forward pass, so their initialization is skipped.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Return the tutorial's HyperNet.

The core network is a three-layer strided CNN followed by global mean pooling and a
64 → 10 dense head. The weight generator maps a dataset index (one of two embedding
entries) to a full parameter vector for that core network.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's output size must equal the number of core-network parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params),
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose arg-max prediction matches the
arg-max of the one-hot target. `model` is run in test mode on inputs `(data_idx, x)`,
and labels and predictions are moved to the CPU before comparison.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_test = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        logits, _ = model((data_idx, x), ps, st_test)
        preds = onecold(to_cpu(logits))
        n_correct += count(labels .== preds)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
# Accuracy of `model` on `dataloader` for dataset `data_idx`, using the current parameters
# and states stored in `train_state`, as a percentage rounded to two decimals.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train()

Train the HyperNet on MNIST (dataset index 1) and FashionMNIST (dataset index 2).
Each of the 50 epochs trains on MNIST and then on FashionMNIST, printing the time
and the train/test accuracy after each pass.

Returns a 2-element vector with the final test accuracy (in %) for MNIST and
FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once from a sample batch. The dataset index is passed as
    # a traced ConcreteRNumber, so the same compiled function serves both datasets.
    x_sample = first(first(dataloaders[1][1]))
    idx_sample = ConcreteRNumber(1)
    model_compiled = @compile model((idx_sample, x_sample), ps, Lux.testmode(st))

    dataset_names = ("MNIST", "FashionMNIST")

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]
        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = dataset_names[data_idx]
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop; the result holds the final MNIST and FashionMNIST test
# accuracies (in %), in that order.
test_acc_list = train()
2025-04-16 03:44:32.315337: I external/xla/xla/service/service.cc:152] XLA service 0x4e8a2a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-04-16 03:44:32.315369: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1744775072.316093 46370 se_gpu_pjrt_client.cc:1040] Using BFC allocator.
I0000 00:00:1744775072.316192 46370 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1744775072.316237 46370 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1744775072.328247 46370 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1744775135.500686 46370 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1744775135.501979 46370 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1744775135.501987 46370 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1744775135.501994 46370 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1744775135.502000 46370 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1744775135.502007 46370 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1744775135.502013 46370 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1744775135.502019 46370 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1744775135.502025 46370 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1744775135.502031 46370 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-04-16 03:45:35.502061: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775135.504841 46370 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1744775135.504864 46370 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1744775135.504871 46370 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1744775135.504878 46370 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1744775135.504884 46370 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1744775135.504890 46370 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1744775135.504896 46370 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1744775135.504902 46370 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1744775135.504909 46370 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1744775135.504915 46370 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-04-16 03:45:35.504925: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.942897 46370 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1744775170.942939 46370 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1744775170.942943 46370 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1744775170.942945 46370 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1744775170.942948 46370 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1744775170.942951 46370 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1744775170.942953 46370 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1744775170.942956 46370 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1744775170.942960 46370 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1744775170.942963 46370 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-04-16 03:46:10.942972: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.944667 46370 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1744775170.944686 46370 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1744775170.944689 46370 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1744775170.944692 46370 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1744775170.944695 46370 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1744775170.944697 46370 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1744775170.944700 46370 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1744775170.944703 46370 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1744775170.944705 46370 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1744775170.944708 46370 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-04-16 03:46:10.944714: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.946415 46370 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1744775170.946437 46370 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1744775170.946440 46370 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1744775170.946442 46370 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1744775170.946445 46370 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1744775170.946448 46370 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1744775170.946451 46370 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1744775170.946453 46370 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1744775170.946456 46370 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1744775170.946458 46370 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-04-16 03:46:10.946465: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.948234 46370 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1744775170.948260 46370 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1744775170.948263 46370 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1744775170.948266 46370 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1744775170.948269 46370 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1744775170.948271 46370 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1744775170.948274 46370 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1744775170.948277 46370 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1744775170.948279 46370 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1744775170.948282 46370 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-04-16 03:46:10.948290: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.957772 46370 buffer_comparator.cc:156] Difference at 32: -nan, expected 0.585358
E0000 00:00:1744775170.957816 46370 buffer_comparator.cc:156] Difference at 33: -nan, expected 0.881101
E0000 00:00:1744775170.957819 46370 buffer_comparator.cc:156] Difference at 34: -nan, expected 0.738887
E0000 00:00:1744775170.957822 46370 buffer_comparator.cc:156] Difference at 35: -nan, expected 0.603294
E0000 00:00:1744775170.957825 46370 buffer_comparator.cc:156] Difference at 36: -nan, expected 1.04006
E0000 00:00:1744775170.957828 46370 buffer_comparator.cc:156] Difference at 37: -nan, expected 0.740676
E0000 00:00:1744775170.957830 46370 buffer_comparator.cc:156] Difference at 38: -nan, expected 0.766031
E0000 00:00:1744775170.957833 46370 buffer_comparator.cc:156] Difference at 39: -nan, expected 0.814752
E0000 00:00:1744775170.957836 46370 buffer_comparator.cc:156] Difference at 40: -nan, expected 0.431557
E0000 00:00:1744775170.957838 46370 buffer_comparator.cc:156] Difference at 41: -nan, expected 0.672244
2025-04-16 03:46:10.957848: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[ 1/ 50] MNIST Time 38.91880s Training Accuracy: 34.86% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.04589s Training Accuracy: 31.54% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.09114s Training Accuracy: 37.40% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.03412s Training Accuracy: 44.14% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.03763s Training Accuracy: 41.99% Test Accuracy: 34.38%
[ 3/ 50] FashionMNIST Time 0.03522s Training Accuracy: 54.10% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.02859s Training Accuracy: 51.46% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.03806s Training Accuracy: 59.08% Test Accuracy: 50.00%
[ 5/ 50] MNIST Time 0.02847s Training Accuracy: 55.76% Test Accuracy: 37.50%
[ 5/ 50] FashionMNIST Time 0.02782s Training Accuracy: 66.11% Test Accuracy: 56.25%
[ 6/ 50] MNIST Time 0.02832s Training Accuracy: 62.21% Test Accuracy: 40.62%
[ 6/ 50] FashionMNIST Time 0.02887s Training Accuracy: 72.17% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.03476s Training Accuracy: 67.87% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.03046s Training Accuracy: 75.10% Test Accuracy: 56.25%
[ 8/ 50] MNIST Time 0.03663s Training Accuracy: 73.34% Test Accuracy: 53.12%
[ 8/ 50] FashionMNIST Time 0.02997s Training Accuracy: 79.59% Test Accuracy: 59.38%
[ 9/ 50] MNIST Time 0.03055s Training Accuracy: 79.88% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.04096s Training Accuracy: 83.40% Test Accuracy: 59.38%
[ 10/ 50] MNIST Time 0.03340s Training Accuracy: 83.50% Test Accuracy: 59.38%
[ 10/ 50] FashionMNIST Time 0.03909s Training Accuracy: 85.94% Test Accuracy: 56.25%
[ 11/ 50] MNIST Time 0.03312s Training Accuracy: 86.52% Test Accuracy: 56.25%
[ 11/ 50] FashionMNIST Time 0.03172s Training Accuracy: 88.57% Test Accuracy: 56.25%
[ 12/ 50] MNIST Time 0.04441s Training Accuracy: 88.38% Test Accuracy: 50.00%
[ 12/ 50] FashionMNIST Time 0.03080s Training Accuracy: 91.11% Test Accuracy: 62.50%
[ 13/ 50] MNIST Time 0.03999s Training Accuracy: 92.48% Test Accuracy: 59.38%
[ 13/ 50] FashionMNIST Time 0.03201s Training Accuracy: 93.07% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.03302s Training Accuracy: 94.24% Test Accuracy: 59.38%
[ 14/ 50] FashionMNIST Time 0.03256s Training Accuracy: 94.53% Test Accuracy: 62.50%
[ 15/ 50] MNIST Time 0.03260s Training Accuracy: 96.00% Test Accuracy: 59.38%
[ 15/ 50] FashionMNIST Time 0.04075s Training Accuracy: 95.02% Test Accuracy: 62.50%
[ 16/ 50] MNIST Time 0.03124s Training Accuracy: 98.05% Test Accuracy: 65.62%
[ 16/ 50] FashionMNIST Time 0.03106s Training Accuracy: 96.88% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.03197s Training Accuracy: 98.54% Test Accuracy: 62.50%
[ 17/ 50] FashionMNIST Time 0.03312s Training Accuracy: 96.58% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.04307s Training Accuracy: 99.02% Test Accuracy: 65.62%
[ 18/ 50] FashionMNIST Time 0.03118s Training Accuracy: 96.68% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.03087s Training Accuracy: 99.51% Test Accuracy: 65.62%
[ 19/ 50] FashionMNIST Time 0.03746s Training Accuracy: 98.24% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.03226s Training Accuracy: 99.80% Test Accuracy: 62.50%
[ 20/ 50] FashionMNIST Time 0.04409s Training Accuracy: 98.24% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.04536s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 21/ 50] FashionMNIST Time 0.04570s Training Accuracy: 98.73% Test Accuracy: 71.88%
[ 22/ 50] MNIST Time 0.03205s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 22/ 50] FashionMNIST Time 0.03485s Training Accuracy: 99.51% Test Accuracy: 68.75%
[ 23/ 50] MNIST Time 0.04044s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.03030s Training Accuracy: 99.22% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.03964s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.03730s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03419s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 25/ 50] FashionMNIST Time 0.03029s Training Accuracy: 99.90% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.03194s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 26/ 50] FashionMNIST Time 0.04040s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.03205s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 27/ 50] FashionMNIST Time 0.03129s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.03303s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.03912s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.04245s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.03398s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.04756s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.03560s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.03565s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.05222s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.03304s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.04597s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.03589s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.02996s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.04052s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.02995s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.04803s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.03419s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.03486s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.03301s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.03166s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.03831s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.03298s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.03116s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.02872s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.03310s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.03679s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.03688s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.03960s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.03095s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.03294s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.04098s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.03325s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.03876s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.03123s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.03593s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.02888s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.02979s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 46/ 50] MNIST Time 0.03526s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.02909s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.03092s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.03097s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 48/ 50] MNIST Time 0.02879s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.03584s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 49/ 50] MNIST Time 0.02914s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.02852s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 50/ 50] MNIST Time 0.03321s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.03425s Training Accuracy: 100.00% Test Accuracy: 71.88%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 71.88%
Appendix
julia
# Print Julia/system information, then GPU toolkit details for any loaded GPU backend
# that MLDataDevices reports as functional.
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.