Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
Reactant
Precompiling ComponentArrays...
968.1 ms ✓ ComponentArrays
1 dependency successfully precompiled in 1 seconds. 45 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
636.9 ms ✓ MLDataDevices → MLDataDevicesComponentArraysExt
1 dependency successfully precompiled in 1 seconds. 48 already precompiled.
Precompiling LuxComponentArraysExt...
516.2 ms ✓ ComponentArrays → ComponentArraysOptimisersExt
1442.4 ms ✓ Lux → LuxComponentArraysExt
2113.0 ms ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
3 dependencies successfully precompiled in 2 seconds. 112 already precompiled.
Precompiling MLDatasets...
390.9 ms ✓ Glob
419.0 ms ✓ WorkerUtilities
455.8 ms ✓ BufferedStreams
358.6 ms ✓ SimpleBufferStream
590.8 ms ✓ URIs
461.6 ms ✓ CodecZlib
333.5 ms ✓ PackageExtensionCompat
369.3 ms ✓ BitFlags
669.4 ms ✓ GZip
709.9 ms ✓ ConcurrentUtilities
598.0 ms ✓ ZipFile
544.7 ms ✓ LoggingExtras
813.0 ms ✓ StructTypes
1021.8 ms ✓ MbedTLS
590.1 ms ✓ MPIPreferences
356.6 ms ✓ InternedStrings
511.8 ms ✓ ExceptionUnwrapping
2172.3 ms ✓ PeriodicTable
580.7 ms ✓ Chemfiles_jll
2775.9 ms ✓ UnitfulAtomic
618.5 ms ✓ libaec_jll
478.4 ms ✓ MicrosoftMPI_jll
508.3 ms ✓ InlineStrings → ParsersExt
553.6 ms ✓ StringEncodings
1376.3 ms ✓ Transducers → TransducersDataFramesExt
1852.7 ms ✓ ImageShow
1619.2 ms ✓ BangBang → BangBangDataFramesExt
436.4 ms ✓ StridedViews
1645.0 ms ✓ NPZ
1850.9 ms ✓ OpenSSL
1117.1 ms ✓ OpenMPI_jll
1447.1 ms ✓ MPICH_jll
1120.6 ms ✓ MPItrampoline_jll
783.9 ms ✓ WeakRefStrings
2182.7 ms ✓ AtomsBase
2297.4 ms ✓ Pickle
1772.3 ms ✓ HDF5_jll
9757.5 ms ✓ JSON3
2320.1 ms ✓ Chemfiles
7327.9 ms ✓ HDF5
2346.2 ms ✓ MAT
18230.7 ms ✓ HTTP
16356.5 ms ✓ CSV
1808.5 ms ✓ FileIO → HTTPExt
2992.6 ms ✓ DataDeps
9032.5 ms ✓ MLDatasets
46 dependencies successfully precompiled in 43 seconds. 154 already precompiled.
Precompiling OneHotArrays...
1000.5 ms ✓ OneHotArrays
1 dependency successfully precompiled in 1 seconds. 28 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
747.1 ms ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
1 dependency successfully precompiled in 1 seconds. 35 already precompiled.
Precompiling ComponentArraysReactantExt...
12255.3 ms ✓ ComponentArrays → ComponentArraysReactantExt
1 dependency successfully precompiled in 13 seconds. 96 already precompiled.
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

The `:train` and `:test` splits are truncated to their first `n_train` / `n_eval` samples,
or used in full when the corresponding count is `nothing`. Features are reshaped to
`28 × 28 × 1 × N` and targets are one-hot encoded over the labels `0:9`.

Both loaders drop the trailing partial batch (`partial = false`) and clamp the batch size
to the number of available samples; only the training loader shuffles.
"""
function load_dataset(
        ::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int
    ) where {dset}
    # Training split: take everything unless an explicit sample count was requested.
    train_split = dset(:train)
    train_range = isnothing(n_train) ? (1:length(train_split)) : (1:n_train)
    train_data = train_split[train_range]
    x_train = reshape(train_data.features, 28, 28, 1, :)
    y_train = onehotbatch(train_data.targets, 0:9)

    # Evaluation split: same treatment, driven by `n_eval`.
    test_split = dset(:test)
    test_range = isnothing(n_eval) ? (1:length(test_split)) : (1:n_eval)
    test_data = test_split[test_range]
    x_test = reshape(test_data.features, 28, 28, 1, :)
    y_test = onehotbatch(test_data.targets, 0:9)

    # The batch size never exceeds the number of samples along the fourth dimension.
    train_loader = DataLoader(
        (x_train, y_train);
        batchsize = min(batchsize, size(x_train, 4)), shuffle = true, partial = false
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize = min(batchsize, size(x_test, 4)), shuffle = false, partial = false
    )
    return (train_loader, test_loader)
end
"""
    load_datasets(batchsize = 32)

Call `load_dataset` for `MNIST` and `FashionMNIST` (in that order) and return the two
results as a tuple.

When the `CI` environment variable parses as `true`, only the first 1024 training and 32
evaluation samples are requested; otherwise `nothing` is passed for both counts.
"""
function load_datasets(batchsize = 32)
    running_on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = running_on_ci ? (1024, 32) : (nothing, nothing)
    return map((MNIST, FashionMNIST)) do dset
        load_dataset(dset, n_train, n_eval, batchsize)
    end
end
load_datasets (generic function with 2 methods)
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a layer in which `weight_generator` produces the parameters used by `core_network`.

The returned `@compact` layer takes a tuple `(x, y)`: `x` is passed to `weight_generator`,
whose flattened output is wrapped as a `ComponentArray` with the axes of `core_network`'s
initial parameters, and `core_network` is then called on `y` with those parameters.
"""
function HyperNet(
        weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer
    )
    # Only the axes (the parameter layout) are kept here; the values are not used.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch = :HyperNet) do (x, y)
        # Flatten the generator output and give it the core network's parameter layout.
        flat_weights = vec(weight_generator(x))
        ps_new = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it ignores the initialization of the core_network
parameters.
julia
# Parameter initialization for layers dispatched as `:HyperNet`: only the weight generator
# gets parameters. The returned NamedTuple has the single field `weight_generator`, so no
# parameters are created for the other layers stored in `hn.layers`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator = generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Assemble the HyperNet used in this tutorial.

The core network is a three-layer strided convolutional classifier ending in a 10-way
`Dense` layer. The weight generator is an `Embedding` with two entries followed by two
`Dense` layers; its output length equals `Lux.parameterlength(core_network)`, i.e. one
value per core-network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride = 2),
        Conv((3, 3), 16 => 32, relu; stride = 2),
        Conv((3, 3), 32 => 64, relu; stride = 2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10)
    )

    # The generator must emit exactly as many values as the core network has parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

For every batch `(x, y)` the model is called as `model((data_idx, x), ps, st)` with `st`
switched to test mode; the first element of the result and the targets `y` are moved to
the CPU and decoded with `onecold` before being compared. The result is `NaN` when the
loader yields no batches (`0 / 0`).
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    eval_st = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (inputs, targets) in dataloader
        expected = onecold(cdev(targets))
        output = first(model((data_idx, inputs), ps, eval_st))
        predicted = onecold(cdev(output))
        n_correct += sum(expected .== predicted)
        n_seen += length(expected)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)
Training
julia
"""
    train()

Train one HyperNet on MNIST and FashionMNIST and return the final test accuracies.

The model, data loaders, parameters and states are moved to a Reactant device. For 50
epochs the function runs one pass over each dataset's training loader in turn (dataset
index 1 = MNIST, 2 = FashionMNIST) using `Adam(0.0003f0)` and a cross-entropy loss on
logits, then prints the elapsed time and the train/test accuracy for that dataset.

Returns a two-element vector of test accuracies in percent (rounded to two decimals),
ordered `[MNIST, FashionMNIST]`.
"""
function train()
    dev = reactant_device(; force = true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference pass once, from a sample batch and a sample dataset index.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    # Accuracy of the compiled model in percent, rounded to two decimals. The train state
    # is passed in explicitly so this helper does not capture a reassigned variable.
    percent_accuracy(ts, loader, idx) = round(
        accuracy(model_compiled, ts.parameters, ts.states, loader, idx) * 100;
        digits = 2
    )

    dataset_names = ("MNIST", "FashionMNIST")

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for data_idx in 1:2
            train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

            ### Wrapping the index lets it be traced; otherwise it would be embedded as
            ### a constant in the IR
            concrete_data_idx = ConcreteRNumber(data_idx)

            stime = time()
            for (x, y) in train_dataloader
                (_, _, _, train_state) = Training.single_train_step!(
                    AutoEnzyme(), CrossEntropyLoss(; logits = Val(true)),
                    ((concrete_data_idx, x), y), train_state; return_gradients = Val(false)
                )
            end
            ttime = time() - stime

            train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
            test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

            data_name = dataset_names[data_idx]

            @printf(
                "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                Accuracy: %3.2f%%\n",
                epoch, nepochs, data_name, ttime, train_acc, test_acc
            )
        end
    end

    println()

    # Final evaluation pass over both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_names[data_idx]

        @printf(
            "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
            %3.2f%%\n",
            data_name, train_acc, test_acc
        )
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the full training loop and keep the vector of final test accuracies it returns.
test_acc_list = train()
2025-03-08 15:39:37.638068: I external/xla/xla/service/service.cc:152] XLA service 0xee76fb0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-08 15:39:37.638111: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741448377.638985 3567203 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741448377.639060 3567203 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741448377.639111 3567203 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741448377.650453 3567203 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1741448502.810384 3567203 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741448502.810444 3567203 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741448502.810452 3567203 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741448502.810459 3567203 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741448502.810466 3567203 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741448502.810472 3567203 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741448502.810479 3567203 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741448502.810486 3567203 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741448502.810492 3567203 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741448502.810499 3567203 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-08 15:41:42.810515: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448502.813414 3567203 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741448502.813449 3567203 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741448502.813457 3567203 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741448502.813464 3567203 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741448502.813471 3567203 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741448502.813478 3567203 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741448502.813485 3567203 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741448502.813491 3567203 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741448502.813498 3567203 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741448502.813505 3567203 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-08 15:41:42.813516: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.637669 3567203 buffer_comparator.cc:156] Difference at 32: 0, expected 0.585358
E0000 00:00:1741448528.637730 3567203 buffer_comparator.cc:156] Difference at 33: 0, expected 0.881101
E0000 00:00:1741448528.637740 3567203 buffer_comparator.cc:156] Difference at 34: 0, expected 0.738887
E0000 00:00:1741448528.637746 3567203 buffer_comparator.cc:156] Difference at 35: 0, expected 0.603294
E0000 00:00:1741448528.637753 3567203 buffer_comparator.cc:156] Difference at 36: 0, expected 1.04006
E0000 00:00:1741448528.637759 3567203 buffer_comparator.cc:156] Difference at 37: 0.364706, expected 0.740676
E0000 00:00:1741448528.637766 3567203 buffer_comparator.cc:156] Difference at 38: 0.992157, expected 0.766031
E0000 00:00:1741448528.637772 3567203 buffer_comparator.cc:156] Difference at 40: 0.611765, expected 0.431557
E0000 00:00:1741448528.637781 3567203 buffer_comparator.cc:156] Difference at 41: 0, expected 0.672244
E0000 00:00:1741448528.637788 3567203 buffer_comparator.cc:156] Difference at 42: 0, expected 0.692783
2025-03-08 15:42:08.637803: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.641748 3567203 buffer_comparator.cc:156] Difference at 160: 0.704499, expected 2.24976
E0000 00:00:1741448528.641764 3567203 buffer_comparator.cc:156] Difference at 161: 0.996866, expected 1.51732
E0000 00:00:1741448528.641769 3567203 buffer_comparator.cc:156] Difference at 162: 0.776841, expected 2.23336
E0000 00:00:1741448528.641773 3567203 buffer_comparator.cc:156] Difference at 163: 0.660587, expected 1.99652
E0000 00:00:1741448528.641777 3567203 buffer_comparator.cc:156] Difference at 164: 1.03633, expected 1.7657
E0000 00:00:1741448528.641781 3567203 buffer_comparator.cc:156] Difference at 165: 0.912078, expected 1.87991
E0000 00:00:1741448528.641785 3567203 buffer_comparator.cc:156] Difference at 166: 0.809391, expected 1.82301
E0000 00:00:1741448528.641789 3567203 buffer_comparator.cc:156] Difference at 167: 0.822743, expected 2.01534
E0000 00:00:1741448528.641793 3567203 buffer_comparator.cc:156] Difference at 168: 0.422923, expected 2.17682
E0000 00:00:1741448528.641797 3567203 buffer_comparator.cc:156] Difference at 169: 0.858019, expected 2.28508
2025-03-08 15:42:08.641803: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.644027 3567203 buffer_comparator.cc:156] Difference at 160: 0.704499, expected 2.24976
E0000 00:00:1741448528.644043 3567203 buffer_comparator.cc:156] Difference at 161: 0.996866, expected 1.51732
E0000 00:00:1741448528.644047 3567203 buffer_comparator.cc:156] Difference at 162: 0.776841, expected 2.23336
E0000 00:00:1741448528.644051 3567203 buffer_comparator.cc:156] Difference at 163: 0.660587, expected 1.99652
E0000 00:00:1741448528.644055 3567203 buffer_comparator.cc:156] Difference at 164: 1.03633, expected 1.7657
E0000 00:00:1741448528.644059 3567203 buffer_comparator.cc:156] Difference at 165: 0.912078, expected 1.87991
E0000 00:00:1741448528.644063 3567203 buffer_comparator.cc:156] Difference at 166: 0.809391, expected 1.82301
E0000 00:00:1741448528.644067 3567203 buffer_comparator.cc:156] Difference at 167: 0.822743, expected 2.01534
E0000 00:00:1741448528.644071 3567203 buffer_comparator.cc:156] Difference at 168: 0.422923, expected 2.17682
E0000 00:00:1741448528.644075 3567203 buffer_comparator.cc:156] Difference at 169: 0.858019, expected 2.28508
2025-03-08 15:42:08.644081: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.646290 3567203 buffer_comparator.cc:156] Difference at 320: 0.635997, expected 2.66964
E0000 00:00:1741448528.646305 3567203 buffer_comparator.cc:156] Difference at 321: 0.937885, expected 1.74014
E0000 00:00:1741448528.646309 3567203 buffer_comparator.cc:156] Difference at 322: 0.613587, expected 2.52041
E0000 00:00:1741448528.646313 3567203 buffer_comparator.cc:156] Difference at 323: 0.704588, expected 2.17431
E0000 00:00:1741448528.646317 3567203 buffer_comparator.cc:156] Difference at 324: 0.827663, expected 1.93476
E0000 00:00:1741448528.646321 3567203 buffer_comparator.cc:156] Difference at 325: 0.78249, expected 2.28937
E0000 00:00:1741448528.646325 3567203 buffer_comparator.cc:156] Difference at 326: 0.937569, expected 2.2459
E0000 00:00:1741448528.646329 3567203 buffer_comparator.cc:156] Difference at 327: 0.916275, expected 2.22886
E0000 00:00:1741448528.646333 3567203 buffer_comparator.cc:156] Difference at 328: 0.879483, expected 2.35168
E0000 00:00:1741448528.646337 3567203 buffer_comparator.cc:156] Difference at 329: 0.745365, expected 2.40895
2025-03-08 15:42:08.646345: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.648554 3567203 buffer_comparator.cc:156] Difference at 320: 0.635997, expected 2.66964
E0000 00:00:1741448528.648569 3567203 buffer_comparator.cc:156] Difference at 321: 0.937885, expected 1.74014
E0000 00:00:1741448528.648574 3567203 buffer_comparator.cc:156] Difference at 322: 0.613587, expected 2.52041
E0000 00:00:1741448528.648578 3567203 buffer_comparator.cc:156] Difference at 323: 0.704588, expected 2.17431
E0000 00:00:1741448528.648582 3567203 buffer_comparator.cc:156] Difference at 324: 0.827663, expected 1.93476
E0000 00:00:1741448528.648586 3567203 buffer_comparator.cc:156] Difference at 325: 0.78249, expected 2.28937
E0000 00:00:1741448528.648590 3567203 buffer_comparator.cc:156] Difference at 326: 0.937569, expected 2.2459
E0000 00:00:1741448528.648594 3567203 buffer_comparator.cc:156] Difference at 327: 0.916275, expected 2.22886
E0000 00:00:1741448528.648598 3567203 buffer_comparator.cc:156] Difference at 328: 0.879483, expected 2.35168
E0000 00:00:1741448528.648602 3567203 buffer_comparator.cc:156] Difference at 329: 0.745365, expected 2.40895
2025-03-08 15:42:08.648608: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[ 1/ 50] MNIST Time 28.54848s Training Accuracy: 35.16% Test Accuracy: 40.62%
[ 1/ 50] FashionMNIST Time 0.03629s Training Accuracy: 32.42% Test Accuracy: 46.88%
[ 2/ 50] MNIST Time 0.03454s Training Accuracy: 33.69% Test Accuracy: 31.25%
[ 2/ 50] FashionMNIST Time 0.08035s Training Accuracy: 47.75% Test Accuracy: 50.00%
[ 3/ 50] MNIST Time 0.03006s Training Accuracy: 41.80% Test Accuracy: 46.88%
[ 3/ 50] FashionMNIST Time 0.03112s Training Accuracy: 54.88% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.02650s Training Accuracy: 46.97% Test Accuracy: 37.50%
[ 4/ 50] FashionMNIST Time 0.02468s Training Accuracy: 60.55% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.02426s Training Accuracy: 54.49% Test Accuracy: 46.88%
[ 5/ 50] FashionMNIST Time 0.06597s Training Accuracy: 65.92% Test Accuracy: 59.38%
[ 6/ 50] MNIST Time 0.02475s Training Accuracy: 59.77% Test Accuracy: 50.00%
[ 6/ 50] FashionMNIST Time 0.02469s Training Accuracy: 70.31% Test Accuracy: 62.50%
[ 7/ 50] MNIST Time 0.02462s Training Accuracy: 65.33% Test Accuracy: 56.25%
[ 7/ 50] FashionMNIST Time 0.02478s Training Accuracy: 75.59% Test Accuracy: 62.50%
[ 8/ 50] MNIST Time 0.02465s Training Accuracy: 71.09% Test Accuracy: 43.75%
[ 8/ 50] FashionMNIST Time 0.05393s Training Accuracy: 80.27% Test Accuracy: 59.38%
[ 9/ 50] MNIST Time 0.02572s Training Accuracy: 75.88% Test Accuracy: 43.75%
[ 9/ 50] FashionMNIST Time 0.02551s Training Accuracy: 81.54% Test Accuracy: 59.38%
[ 10/ 50] MNIST Time 0.02569s Training Accuracy: 80.66% Test Accuracy: 50.00%
[ 10/ 50] FashionMNIST Time 0.04109s Training Accuracy: 87.11% Test Accuracy: 59.38%
[ 11/ 50] MNIST Time 0.02470s Training Accuracy: 85.06% Test Accuracy: 46.88%
[ 11/ 50] FashionMNIST Time 0.02467s Training Accuracy: 88.09% Test Accuracy: 71.88%
[ 12/ 50] MNIST Time 0.02448s Training Accuracy: 87.60% Test Accuracy: 43.75%
[ 12/ 50] FashionMNIST Time 0.03749s Training Accuracy: 90.23% Test Accuracy: 65.62%
[ 13/ 50] MNIST Time 0.02538s Training Accuracy: 89.16% Test Accuracy: 46.88%
[ 13/ 50] FashionMNIST Time 0.02459s Training Accuracy: 93.46% Test Accuracy: 68.75%
[ 14/ 50] MNIST Time 0.02439s Training Accuracy: 93.36% Test Accuracy: 50.00%
[ 14/ 50] FashionMNIST Time 0.02466s Training Accuracy: 94.53% Test Accuracy: 71.88%
[ 15/ 50] MNIST Time 0.02458s Training Accuracy: 94.34% Test Accuracy: 46.88%
[ 15/ 50] FashionMNIST Time 0.02491s Training Accuracy: 95.31% Test Accuracy: 62.50%
[ 16/ 50] MNIST Time 0.02456s Training Accuracy: 96.97% Test Accuracy: 56.25%
[ 16/ 50] FashionMNIST Time 0.02503s Training Accuracy: 97.27% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02466s Training Accuracy: 98.34% Test Accuracy: 53.12%
[ 17/ 50] FashionMNIST Time 0.02532s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.03622s Training Accuracy: 98.83% Test Accuracy: 59.38%
[ 18/ 50] FashionMNIST Time 0.02458s Training Accuracy: 98.34% Test Accuracy: 68.75%
[ 19/ 50] MNIST Time 0.02458s Training Accuracy: 99.51% Test Accuracy: 56.25%
[ 19/ 50] FashionMNIST Time 0.02460s Training Accuracy: 99.02% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.03623s Training Accuracy: 99.90% Test Accuracy: 56.25%
[ 20/ 50] FashionMNIST Time 0.02473s Training Accuracy: 99.22% Test Accuracy: 71.88%
[ 21/ 50] MNIST Time 0.02575s Training Accuracy: 99.90% Test Accuracy: 56.25%
[ 21/ 50] FashionMNIST Time 0.02552s Training Accuracy: 99.41% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.02480s Training Accuracy: 99.90% Test Accuracy: 56.25%
[ 22/ 50] FashionMNIST Time 0.02492s Training Accuracy: 99.71% Test Accuracy: 65.62%
[ 23/ 50] MNIST Time 0.02481s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 23/ 50] FashionMNIST Time 0.02462s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.02482s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 24/ 50] FashionMNIST Time 0.02573s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.02607s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 25/ 50] FashionMNIST Time 0.03856s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.02559s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 26/ 50] FashionMNIST Time 0.02586s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.02484s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 27/ 50] FashionMNIST Time 0.02403s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 28/ 50] MNIST Time 0.02475s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 28/ 50] FashionMNIST Time 0.02506s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 29/ 50] MNIST Time 0.02460s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 29/ 50] FashionMNIST Time 0.02485s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.02459s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.02460s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.03671s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 31/ 50] FashionMNIST Time 0.02446s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.02446s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 32/ 50] FashionMNIST Time 0.02643s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 33/ 50] MNIST Time 0.03690s Training Accuracy: 100.00% Test Accuracy: 59.38%
[ 33/ 50] FashionMNIST Time 0.02458s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 34/ 50] MNIST Time 0.02463s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.02466s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 35/ 50] MNIST Time 0.02674s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.02461s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 36/ 50] MNIST Time 0.02461s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.02500s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.02464s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.02452s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 38/ 50] MNIST Time 0.02491s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.03697s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 39/ 50] MNIST Time 0.02473s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.02454s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 40/ 50] MNIST Time 0.02459s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.04066s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 41/ 50] MNIST Time 0.02638s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.02535s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 42/ 50] MNIST Time 0.02436s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.02486s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 43/ 50] MNIST Time 0.02606s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.02600s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.02477s Training Accuracy: 100.00% Test Accuracy: 65.62%
[ 44/ 50] FashionMNIST Time 0.02505s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 45/ 50] MNIST Time 0.02749s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.02627s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 46/ 50] MNIST Time 0.03704s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.02567s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 47/ 50] MNIST Time 0.02569s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.02553s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.02431s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.02501s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 49/ 50] MNIST Time 0.02633s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.02652s Training Accuracy: 100.00% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.02583s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.02661s Training Accuracy: 100.00% Test Accuracy: 75.00%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 75.00%
Appendix
julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# Print GPU toolchain details as well, but only for a backend whose package is defined in
# this session and that MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
        MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.