Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
   1521.3 ms  ✓ Lux → LuxComponentArraysExt
  1 dependency successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling LuxMLUtilsExt...
   2096.8 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 164 already precompiled.

Loading Datasets

julia
# Load the first `n` samples of the `subset` split (`:train` or `:test`) of `dset`, or every
# sample when `n === nothing`. Returns `(x, y)`: the features reshaped to a
# `28 × 28 × 1 × N` array and the targets one-hot encoded over the labels `0:9`.
function _load_split(::Type{dset}, subset::Symbol, n::Union{Nothing,Int}) where {dset}
    data = dset(subset)
    (; features, targets) = data[1:something(n, length(data))]
    return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
end

"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`
(e.g. `MNIST`, `FashionMNIST`).

# Arguments
- `n_train`, `n_eval`: number of leading train/test samples to keep, or `nothing` to
  keep every sample of that split.
- `batchsize`: requested batch size, capped at the number of available samples of each
  split. Both loaders use `partial=false`, so an incomplete trailing batch is dropped.

Only the training loader is shuffled.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    x_train, y_train = _load_split(dset, :train, n_train)
    x_test, y_test = _load_split(dset, :test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize=min(batchsize, size(x_train, 4)),
        shuffle=true,
        partial=false,
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize=min(batchsize, size(x_test, 4)),
        shuffle=false,
        partial=false,
    )
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize=32)

Return `(mnist, fashion_mnist)`, where each entry is the `(train_loader, test_loader)` pair
produced by `load_dataset` for `MNIST` and `FashionMNIST` respectively.

When the `CI` environment variable parses as `true`, only the first 1024 training and 32
test samples of each dataset are used (presumably to keep CI runs short); otherwise the
full splits are loaded.
"""
function load_datasets(batchsize=32)
    # Read `CI` once. Accept any casing (some CI systems export `CI=True`), and treat a
    # value that is not a boolean as "not CI" instead of throwing.
    is_ci = something(tryparse(Bool, lowercase(get(ENV, "CI", "false"))), false)
    n_train = is_ci ? 1024 : nothing
    n_eval = is_ci ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Build a compact Lux layer (dispatch tag `:HyperNet`) that is called with a tuple `(x, y)`.
It runs `weight_generator` on `x`, flattens the output with `vec`, and views that vector
through `core_network`'s parameter structure as a `ComponentArray`. It then applies
`core_network` to `y` using those generated parameters.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Only the parameter *structure* (axes) of `core_network` is needed here. Use a
    # throwaway RNG so that constructing the layer does not advance the global default RNG.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.Xoshiro(0), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the `CompactLuxLayer` requires some understanding of how the layer is structured, so we don't recommend doing it unless you are familiar with the internals. Here we only override `Lux.initialparameters` so that it skips initializing the `core_network` parameters, since those are generated by the `weight_generator` at runtime.

julia
# Parameter initialization for layers built with `dispatch=:HyperNet`: only the weight
# generator receives trainable parameters. The core network's parameters are produced by
# the generator on every forward pass, so initializing them here would be wasted work.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial. A small convolutional classifier with 10
outputs is the core network. Its parameters come from a weight generator that embeds a
dataset index (the `Embedding` has 2 entries) and maps it through two `Dense` layers to
one value per core-network parameter.
"""
function create_model()
    # Classifier applied to the images.
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator must emit exactly as many values as the core network has parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` for which `onecold` of the model output
equals `onecold` of the one-hot target. The model is called as
`model((data_idx, x), ps, st)` with `st` switched to test mode. Its first return value is
moved to the CPU before decoding.

Throws an `ArgumentError` if `dataloader` yields no samples, because the ratio would
otherwise silently be `0 / 0 == NaN`.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st = Lux.testmode(st)
    total_correct, total = 0, 0
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += count(target_class .== predicted_class)
        total += length(target_class)
    end
    total == 0 && throw(ArgumentError("accuracy: `dataloader` yielded no samples"))
    return total_correct / total
end
accuracy (generic function with 1 method)

Training

julia
# Human-readable name of dataset index `data_idx` (1 => MNIST, anything else => FashionMNIST).
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy (in percent, rounded to 2 digits) of `model` on `dataloader` for dataset index
# `data_idx`, using the parameters and states currently held by `train_state`.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(model, train_state.parameters, train_state.states, dataloader, data_idx)
    return round(acc * 100; digits=2)
end

"""
    train(; nepochs=50)

Train the HyperNet on MNIST (dataset index 1) and FashionMNIST (dataset index 2) on the
Reactant device. Each epoch takes one pass over each dataset, alternating between them,
and prints the training time and the train/test accuracy for that dataset. After training
it prints the final accuracies.

Returns a 2-element `Vector{Float64}` holding the final test accuracies (in percent) for
MNIST and FashionMNIST, in that order.

# Keywords
- `nepochs`: number of epochs, i.e. passes over both datasets (default `50`).
"""
function train(; nepochs::Integer=50)
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using the first MNIST training batch as the
    # example input. The compiled function is reused for every accuracy evaluation below.
    x = first(first(dataloaders[1][1]))
    example_idx = ConcreteRNumber(1)
    model_compiled = @compile model((example_idx, x), ps, Lux.testmode(st))

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )
        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run training; the result holds the final test accuracies (%) for [MNIST, FashionMNIST].
test_acc_list = train()
2025-04-16 03:44:32.315337: I external/xla/xla/service/service.cc:152] XLA service 0x4e8a2a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-04-16 03:44:32.315369: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1744775072.316093   46370 se_gpu_pjrt_client.cc:1040] Using BFC allocator.
I0000 00:00:1744775072.316192   46370 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1744775072.316237   46370 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1744775072.328247   46370 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1744775135.500686   46370 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1744775135.501979   46370 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1744775135.501987   46370 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1744775135.501994   46370 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1744775135.502000   46370 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1744775135.502007   46370 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1744775135.502013   46370 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1744775135.502019   46370 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1744775135.502025   46370 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1744775135.502031   46370 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-04-16 03:45:35.502061: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775135.504841   46370 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1744775135.504864   46370 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1744775135.504871   46370 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1744775135.504878   46370 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1744775135.504884   46370 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1744775135.504890   46370 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1744775135.504896   46370 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1744775135.504902   46370 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1744775135.504909   46370 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1744775135.504915   46370 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-04-16 03:45:35.504925: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.942897   46370 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1744775170.942939   46370 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1744775170.942943   46370 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1744775170.942945   46370 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1744775170.942948   46370 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1744775170.942951   46370 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1744775170.942953   46370 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1744775170.942956   46370 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1744775170.942960   46370 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1744775170.942963   46370 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-04-16 03:46:10.942972: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.944667   46370 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1744775170.944686   46370 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1744775170.944689   46370 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1744775170.944692   46370 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1744775170.944695   46370 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1744775170.944697   46370 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1744775170.944700   46370 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1744775170.944703   46370 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1744775170.944705   46370 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1744775170.944708   46370 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-04-16 03:46:10.944714: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.946415   46370 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1744775170.946437   46370 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1744775170.946440   46370 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1744775170.946442   46370 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1744775170.946445   46370 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1744775170.946448   46370 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1744775170.946451   46370 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1744775170.946453   46370 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1744775170.946456   46370 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1744775170.946458   46370 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-04-16 03:46:10.946465: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.948234   46370 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1744775170.948260   46370 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1744775170.948263   46370 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1744775170.948266   46370 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1744775170.948269   46370 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1744775170.948271   46370 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1744775170.948274   46370 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1744775170.948277   46370 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1744775170.948279   46370 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1744775170.948282   46370 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-04-16 03:46:10.948290: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744775170.957772   46370 buffer_comparator.cc:156] Difference at 32: -nan, expected 0.585358
E0000 00:00:1744775170.957816   46370 buffer_comparator.cc:156] Difference at 33: -nan, expected 0.881101
E0000 00:00:1744775170.957819   46370 buffer_comparator.cc:156] Difference at 34: -nan, expected 0.738887
E0000 00:00:1744775170.957822   46370 buffer_comparator.cc:156] Difference at 35: -nan, expected 0.603294
E0000 00:00:1744775170.957825   46370 buffer_comparator.cc:156] Difference at 36: -nan, expected 1.04006
E0000 00:00:1744775170.957828   46370 buffer_comparator.cc:156] Difference at 37: -nan, expected 0.740676
E0000 00:00:1744775170.957830   46370 buffer_comparator.cc:156] Difference at 38: -nan, expected 0.766031
E0000 00:00:1744775170.957833   46370 buffer_comparator.cc:156] Difference at 39: -nan, expected 0.814752
E0000 00:00:1744775170.957836   46370 buffer_comparator.cc:156] Difference at 40: -nan, expected 0.431557
E0000 00:00:1744775170.957838   46370 buffer_comparator.cc:156] Difference at 41: -nan, expected 0.672244
2025-04-16 03:46:10.957848: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[  1/ 50]	       MNIST	Time 38.91880s	Training Accuracy: 34.86%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04589s	Training Accuracy: 31.54%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.09114s	Training Accuracy: 37.40%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.03412s	Training Accuracy: 44.14%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.03763s	Training Accuracy: 41.99%	Test Accuracy: 34.38%
[  3/ 50]	FashionMNIST	Time 0.03522s	Training Accuracy: 54.10%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.02859s	Training Accuracy: 51.46%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.03806s	Training Accuracy: 59.08%	Test Accuracy: 50.00%
[  5/ 50]	       MNIST	Time 0.02847s	Training Accuracy: 55.76%	Test Accuracy: 37.50%
[  5/ 50]	FashionMNIST	Time 0.02782s	Training Accuracy: 66.11%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.02832s	Training Accuracy: 62.21%	Test Accuracy: 40.62%
[  6/ 50]	FashionMNIST	Time 0.02887s	Training Accuracy: 72.17%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.03476s	Training Accuracy: 67.87%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.03046s	Training Accuracy: 75.10%	Test Accuracy: 56.25%
[  8/ 50]	       MNIST	Time 0.03663s	Training Accuracy: 73.34%	Test Accuracy: 53.12%
[  8/ 50]	FashionMNIST	Time 0.02997s	Training Accuracy: 79.59%	Test Accuracy: 59.38%
[  9/ 50]	       MNIST	Time 0.03055s	Training Accuracy: 79.88%	Test Accuracy: 50.00%
[  9/ 50]	FashionMNIST	Time 0.04096s	Training Accuracy: 83.40%	Test Accuracy: 59.38%
[ 10/ 50]	       MNIST	Time 0.03340s	Training Accuracy: 83.50%	Test Accuracy: 59.38%
[ 10/ 50]	FashionMNIST	Time 0.03909s	Training Accuracy: 85.94%	Test Accuracy: 56.25%
[ 11/ 50]	       MNIST	Time 0.03312s	Training Accuracy: 86.52%	Test Accuracy: 56.25%
[ 11/ 50]	FashionMNIST	Time 0.03172s	Training Accuracy: 88.57%	Test Accuracy: 56.25%
[ 12/ 50]	       MNIST	Time 0.04441s	Training Accuracy: 88.38%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.03080s	Training Accuracy: 91.11%	Test Accuracy: 62.50%
[ 13/ 50]	       MNIST	Time 0.03999s	Training Accuracy: 92.48%	Test Accuracy: 59.38%
[ 13/ 50]	FashionMNIST	Time 0.03201s	Training Accuracy: 93.07%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.03302s	Training Accuracy: 94.24%	Test Accuracy: 59.38%
[ 14/ 50]	FashionMNIST	Time 0.03256s	Training Accuracy: 94.53%	Test Accuracy: 62.50%
[ 15/ 50]	       MNIST	Time 0.03260s	Training Accuracy: 96.00%	Test Accuracy: 59.38%
[ 15/ 50]	FashionMNIST	Time 0.04075s	Training Accuracy: 95.02%	Test Accuracy: 62.50%
[ 16/ 50]	       MNIST	Time 0.03124s	Training Accuracy: 98.05%	Test Accuracy: 65.62%
[ 16/ 50]	FashionMNIST	Time 0.03106s	Training Accuracy: 96.88%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.03197s	Training Accuracy: 98.54%	Test Accuracy: 62.50%
[ 17/ 50]	FashionMNIST	Time 0.03312s	Training Accuracy: 96.58%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.04307s	Training Accuracy: 99.02%	Test Accuracy: 65.62%
[ 18/ 50]	FashionMNIST	Time 0.03118s	Training Accuracy: 96.68%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.03087s	Training Accuracy: 99.51%	Test Accuracy: 65.62%
[ 19/ 50]	FashionMNIST	Time 0.03746s	Training Accuracy: 98.24%	Test Accuracy: 71.88%
[ 20/ 50]	       MNIST	Time 0.03226s	Training Accuracy: 99.80%	Test Accuracy: 62.50%
[ 20/ 50]	FashionMNIST	Time 0.04409s	Training Accuracy: 98.24%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.04536s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 21/ 50]	FashionMNIST	Time 0.04570s	Training Accuracy: 98.73%	Test Accuracy: 71.88%
[ 22/ 50]	       MNIST	Time 0.03205s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 22/ 50]	FashionMNIST	Time 0.03485s	Training Accuracy: 99.51%	Test Accuracy: 68.75%
[ 23/ 50]	       MNIST	Time 0.04044s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 23/ 50]	FashionMNIST	Time 0.03030s	Training Accuracy: 99.22%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.03964s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.03730s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.03419s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 25/ 50]	FashionMNIST	Time 0.03029s	Training Accuracy: 99.90%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.03194s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 26/ 50]	FashionMNIST	Time 0.04040s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 27/ 50]	       MNIST	Time 0.03205s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	FashionMNIST	Time 0.03129s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.03303s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.03912s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.04245s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 29/ 50]	FashionMNIST	Time 0.03398s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.04756s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 30/ 50]	FashionMNIST	Time 0.03560s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.03565s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 31/ 50]	FashionMNIST	Time 0.05222s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 32/ 50]	       MNIST	Time 0.03304s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.04597s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 33/ 50]	       MNIST	Time 0.03589s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 33/ 50]	FashionMNIST	Time 0.02996s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.04052s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.02995s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 35/ 50]	       MNIST	Time 0.04803s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	FashionMNIST	Time 0.03419s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 36/ 50]	       MNIST	Time 0.03486s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.03301s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.03166s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	FashionMNIST	Time 0.03831s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.03298s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 38/ 50]	FashionMNIST	Time 0.03116s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 39/ 50]	       MNIST	Time 0.02872s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	FashionMNIST	Time 0.03310s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.03679s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	FashionMNIST	Time 0.03688s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.03960s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.03095s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.03294s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.04098s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 43/ 50]	       MNIST	Time 0.03325s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.03876s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.03123s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 44/ 50]	FashionMNIST	Time 0.03593s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.02888s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.02979s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.03526s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.02909s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.03092s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	FashionMNIST	Time 0.03097s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 48/ 50]	       MNIST	Time 0.02879s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	FashionMNIST	Time 0.03584s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 49/ 50]	       MNIST	Time 0.02914s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.02852s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 50/ 50]	       MNIST	Time 0.03321s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.03425s	Training Accuracy: 100.00%	Test Accuracy: 71.88%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 71.88%

Appendix

julia
# Print the Julia / system configuration used to generate this page.
using InteractiveUtils
InteractiveUtils.versioninfo()

# When MLDataDevices is loaded, also print toolkit details for each GPU backend package
# (CUDA, AMDGPU) that is loaded and that MLDataDevices reports as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.