
Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Precompiling LuxComponentArraysExt...
    668.1 ms  ✓ ComponentArrays → ComponentArraysOptimisersExt
   1396.1 ms  ✓ Lux → LuxComponentArraysExt
  2 dependencies successfully precompiled in 2 seconds. 113 already precompiled.
Precompiling LuxMLUtilsExt...
   2046.6 ms  ✓ Lux → LuxMLUtilsExt
  1 dependency successfully precompiled in 2 seconds. 169 already precompiled.

Loading Datasets

julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Take the full training split unless a subset size was requested (e.g. on CI)
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    # Reshape to 28×28×1×N (WHCN) and one-hot encode the digit labels
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    # Shuffle the training batches every epoch; keep the test order deterministic
    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
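
As a quick sanity check, this is what one batch looks like. A hypothetical snippet; the shapes shown assume the CI subset sizes and a batchsize of 32:

julia
train_loader, test_loader = load_dataset(MNIST, 1024, 32, 32)
x, y = first(train_loader)
size(x)  # (28, 28, 1, 32): WHCN image batch
size(y)  # (10, 32): one-hot labels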

Implement a HyperNet Layer

julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
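
The key trick is that getaxes records the layout of the core network's parameters once, so the flat vector emitted by the weight generator can be reinterpreted as a structured parameter tree with no per-call bookkeeping. A minimal standalone sketch of that round-trip, with made-up parameter names:

julia
using ComponentArrays

ps = ComponentArray(; weight=rand(Float32, 3, 3), bias=rand(Float32, 3))
ax = getaxes(ps)                   # capture the structure once
flat = randn(Float32, length(ps))  # any flat vector of the right length
ps2 = ComponentArray(flat, ax)     # reinterpret it with the saved axes
size(ps2.weight), size(ps2.bias)   # ((3, 3), (3,))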

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured internally, so we don't recommend doing this unless you are familiar with the internals. In this case, we simply override parameter initialization to skip the core_network parameters entirely, since the weight generator produces them at runtime.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end

Create and Initialize the HyperNet

julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
create_model (generic function with 1 method)
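
Because initialparameters was overridden above, setting up this model allocates trainable parameters only for the weight generator; the core network's weights are materialized at runtime from its output. A quick hypothetical check:

julia
model = create_model()
ps, st = Lux.setup(Random.default_rng(), model)
keys(ps)  # (:weight_generator,)
Lux.parameterlength(model.layers.core_network)  # length of the generated weight vector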

Define Utility Functions

julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
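
Here onecold inverts the one-hot encoding, returning the hot index of each column, so comparing targets and predictions is a plain elementwise check. For example:

julia
using OneHotArrays

y = onehotbatch([0, 3, 9], 0:9)  # 10×3 one-hot matrix
onecold(y)                       # [1, 4, 10]: raw row indices
onecold(y, 0:9)                  # [0, 3, 9]: mapped back to the labels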

Training

julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index; otherwise it would be embedded as a
        ### constant in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
2025-03-28 04:30:08.099642: I external/xla/xla/service/service.cc:152] XLA service 0x6cb84e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-28 04:30:08.099676: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1743136208.100529 3381069 se_gpu_pjrt_client.cc:1039] Using BFC allocator.
I0000 00:00:1743136208.100607 3381069 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1743136208.100650 3381069 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1743136208.112216 3381069 cuda_dnn.cc:529] Loaded cuDNN version 90400
[  1/ 50]	       MNIST	Time 37.75709s	Training Accuracy: 35.06%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.04267s	Training Accuracy: 34.38%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.07874s	Training Accuracy: 34.47%	Test Accuracy: 40.62%
[  2/ 50]	FashionMNIST	Time 0.03368s	Training Accuracy: 46.48%	Test Accuracy: 53.12%
[  3/ 50]	       MNIST	Time 0.03505s	Training Accuracy: 40.82%	Test Accuracy: 31.25%
[  3/ 50]	FashionMNIST	Time 0.03268s	Training Accuracy: 55.18%	Test Accuracy: 59.38%
[  4/ 50]	       MNIST	Time 0.02995s	Training Accuracy: 48.24%	Test Accuracy: 40.62%
[  4/ 50]	FashionMNIST	Time 0.02902s	Training Accuracy: 63.48%	Test Accuracy: 53.12%
[  5/ 50]	       MNIST	Time 0.02958s	Training Accuracy: 54.88%	Test Accuracy: 43.75%
[  5/ 50]	FashionMNIST	Time 0.02862s	Training Accuracy: 69.53%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.02941s	Training Accuracy: 60.35%	Test Accuracy: 43.75%
[  6/ 50]	FashionMNIST	Time 0.03849s	Training Accuracy: 72.56%	Test Accuracy: 62.50%
[  7/ 50]	       MNIST	Time 0.02909s	Training Accuracy: 66.60%	Test Accuracy: 46.88%
[  7/ 50]	FashionMNIST	Time 0.03742s	Training Accuracy: 76.56%	Test Accuracy: 65.62%
[  8/ 50]	       MNIST	Time 0.02866s	Training Accuracy: 73.34%	Test Accuracy: 46.88%
[  8/ 50]	FashionMNIST	Time 0.03892s	Training Accuracy: 81.84%	Test Accuracy: 68.75%
[  9/ 50]	       MNIST	Time 0.02802s	Training Accuracy: 77.64%	Test Accuracy: 46.88%
[  9/ 50]	FashionMNIST	Time 0.03922s	Training Accuracy: 85.16%	Test Accuracy: 68.75%
[ 10/ 50]	       MNIST	Time 0.03796s	Training Accuracy: 80.96%	Test Accuracy: 53.12%
[ 10/ 50]	FashionMNIST	Time 0.03270s	Training Accuracy: 86.72%	Test Accuracy: 65.62%
[ 11/ 50]	       MNIST	Time 0.03192s	Training Accuracy: 84.08%	Test Accuracy: 56.25%
[ 11/ 50]	FashionMNIST	Time 0.05059s	Training Accuracy: 90.82%	Test Accuracy: 71.88%
[ 12/ 50]	       MNIST	Time 0.02909s	Training Accuracy: 86.72%	Test Accuracy: 43.75%
[ 12/ 50]	FashionMNIST	Time 0.03840s	Training Accuracy: 92.68%	Test Accuracy: 75.00%
[ 13/ 50]	       MNIST	Time 0.02946s	Training Accuracy: 90.04%	Test Accuracy: 50.00%
[ 13/ 50]	FashionMNIST	Time 0.05467s	Training Accuracy: 93.36%	Test Accuracy: 65.62%
[ 14/ 50]	       MNIST	Time 0.03002s	Training Accuracy: 93.36%	Test Accuracy: 50.00%
[ 14/ 50]	FashionMNIST	Time 0.03657s	Training Accuracy: 95.12%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.02908s	Training Accuracy: 94.14%	Test Accuracy: 50.00%
[ 15/ 50]	FashionMNIST	Time 0.02804s	Training Accuracy: 96.48%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03602s	Training Accuracy: 96.68%	Test Accuracy: 56.25%
[ 16/ 50]	FashionMNIST	Time 0.02802s	Training Accuracy: 97.07%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.03713s	Training Accuracy: 98.34%	Test Accuracy: 53.12%
[ 17/ 50]	FashionMNIST	Time 0.02842s	Training Accuracy: 97.95%	Test Accuracy: 71.88%
[ 18/ 50]	       MNIST	Time 0.04627s	Training Accuracy: 99.32%	Test Accuracy: 53.12%
[ 18/ 50]	FashionMNIST	Time 0.02771s	Training Accuracy: 97.07%	Test Accuracy: 71.88%
[ 19/ 50]	       MNIST	Time 0.03132s	Training Accuracy: 99.32%	Test Accuracy: 53.12%
[ 19/ 50]	FashionMNIST	Time 0.03583s	Training Accuracy: 98.83%	Test Accuracy: 71.88%
[ 20/ 50]	       MNIST	Time 0.02813s	Training Accuracy: 99.90%	Test Accuracy: 56.25%
[ 20/ 50]	FashionMNIST	Time 0.03146s	Training Accuracy: 99.41%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.03432s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 21/ 50]	FashionMNIST	Time 0.02969s	Training Accuracy: 99.51%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.03973s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 22/ 50]	FashionMNIST	Time 0.02803s	Training Accuracy: 99.71%	Test Accuracy: 68.75%
[ 23/ 50]	       MNIST	Time 0.03466s	Training Accuracy: 99.90%	Test Accuracy: 56.25%
[ 23/ 50]	FashionMNIST	Time 0.03239s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 24/ 50]	       MNIST	Time 0.02799s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 24/ 50]	FashionMNIST	Time 0.04212s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 25/ 50]	       MNIST	Time 0.02948s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 25/ 50]	FashionMNIST	Time 0.03995s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 26/ 50]	       MNIST	Time 0.02774s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 26/ 50]	FashionMNIST	Time 0.03664s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 27/ 50]	       MNIST	Time 0.03102s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 27/ 50]	FashionMNIST	Time 0.02823s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 28/ 50]	       MNIST	Time 0.04109s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 28/ 50]	FashionMNIST	Time 0.02803s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 29/ 50]	       MNIST	Time 0.03562s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 29/ 50]	FashionMNIST	Time 0.03202s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 30/ 50]	       MNIST	Time 0.02864s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 30/ 50]	FashionMNIST	Time 0.02958s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 31/ 50]	       MNIST	Time 0.02794s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 31/ 50]	FashionMNIST	Time 0.04448s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 32/ 50]	       MNIST	Time 0.02832s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 32/ 50]	FashionMNIST	Time 0.03999s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 33/ 50]	       MNIST	Time 0.02891s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 33/ 50]	FashionMNIST	Time 0.02975s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 34/ 50]	       MNIST	Time 0.02961s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 34/ 50]	FashionMNIST	Time 0.02866s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 35/ 50]	       MNIST	Time 0.03529s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 35/ 50]	FashionMNIST	Time 0.02839s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 36/ 50]	       MNIST	Time 0.03640s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 36/ 50]	FashionMNIST	Time 0.03029s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 37/ 50]	       MNIST	Time 0.02978s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 37/ 50]	FashionMNIST	Time 0.03685s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 38/ 50]	       MNIST	Time 0.02917s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 38/ 50]	FashionMNIST	Time 0.03857s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 39/ 50]	       MNIST	Time 0.02921s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 39/ 50]	FashionMNIST	Time 0.02788s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 40/ 50]	       MNIST	Time 0.02939s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 40/ 50]	FashionMNIST	Time 0.02925s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 41/ 50]	       MNIST	Time 0.03749s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 41/ 50]	FashionMNIST	Time 0.02817s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 42/ 50]	       MNIST	Time 0.03697s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 42/ 50]	FashionMNIST	Time 0.02842s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 43/ 50]	       MNIST	Time 0.03001s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 43/ 50]	FashionMNIST	Time 0.04860s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	       MNIST	Time 0.02960s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 44/ 50]	FashionMNIST	Time 0.03534s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 45/ 50]	       MNIST	Time 0.02829s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 45/ 50]	FashionMNIST	Time 0.02779s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 46/ 50]	       MNIST	Time 0.02769s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 46/ 50]	FashionMNIST	Time 0.02857s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 47/ 50]	       MNIST	Time 0.03701s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 47/ 50]	FashionMNIST	Time 0.02867s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 48/ 50]	       MNIST	Time 0.03833s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 48/ 50]	FashionMNIST	Time 0.03150s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 49/ 50]	       MNIST	Time 0.03445s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 49/ 50]	FashionMNIST	Time 0.03588s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 50/ 50]	       MNIST	Time 0.02951s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 50/ 50]	FashionMNIST	Time 0.03822s	Training Accuracy: 100.00%	Test Accuracy: 65.62%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 65.62%
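
With training done, inference is a single call to the compiled model with a task index and an image batch. A hypothetical sketch, assuming the locals from train() (dev, dataloaders, model_compiled, train_state) are still in scope:

julia
x, _ = first(dev(dataloaders[2][2]))  # one FashionMNIST test batch
idx = ConcreteRNumber(2)              # task index: 2 selects FashionMNIST
logits, _ = model_compiled((idx, x), train_state.parameters, Lux.testmode(train_state.states))
onecold(cpu_device()(logits))         # predicted class indices for the batch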

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.