Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
    Reactant
Precompiling ComponentArrays...
    997.3 ms  ✓ ComponentArrays
  1 dependency successfully precompiled in 1 seconds. 45 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
    616.2 ms  ✓ MLDataDevices → MLDataDevicesComponentArraysExt
  1 dependency successfully precompiled in 1 seconds. 48 already precompiled.
Precompiling LuxComponentArraysExt...
    527.5 ms  ✓ ComponentArrays → ComponentArraysOptimisersExt
   1479.3 ms  ✓ Lux → LuxComponentArraysExt
   2313.9 ms  ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
  3 dependencies successfully precompiled in 3 seconds. 112 already precompiled.
Precompiling MLDatasets...
    387.3 ms  ✓ Glob
    423.1 ms  ✓ WorkerUtilities
    458.9 ms  ✓ BufferedStreams
    369.6 ms  ✓ SimpleBufferStream
    638.4 ms  ✓ InlineStrings
    621.4 ms  ✓ URIs
    480.7 ms  ✓ CodecZlib
    374.4 ms  ✓ InvertedIndices
    337.4 ms  ✓ PackageExtensionCompat
    380.1 ms  ✓ BitFlags
    859.8 ms  ✓ GZip
    743.2 ms  ✓ ConcurrentUtilities
   1215.9 ms  ✓ Crayons
    502.8 ms  ✓ PooledArrays
    776.1 ms  ✓ ZipFile
    894.4 ms  ✓ StructTypes
   1341.0 ms  ✓ SentinelArrays
   1057.3 ms  ✓ MbedTLS
    548.0 ms  ✓ LoggingExtras
    593.9 ms  ✓ MPIPreferences
    366.9 ms  ✓ InternedStrings
    697.1 ms  ✓ ExceptionUnwrapping
   2420.0 ms  ✓ PeriodicTable
   1155.1 ms  ✓ Chemfiles_jll
   1604.3 ms  ✓ BFloat16s
   3842.2 ms  ✓ UnitfulAtomic
    583.8 ms  ✓ MicrosoftMPI_jll
    902.8 ms  ✓ libaec_jll
    566.6 ms  ✓ StringEncodings
    515.8 ms  ✓ InlineStrings → ParsersExt
    447.3 ms  ✓ StridedViews
   2101.0 ms  ✓ StringManipulation
   2201.0 ms  ✓ ImageShow
   1522.4 ms  ✓ NPZ
   1934.6 ms  ✓ OpenSSL
   1159.4 ms  ✓ OpenMPI_jll
   1449.6 ms  ✓ MPICH_jll
   1151.0 ms  ✓ MPItrampoline_jll
    775.3 ms  ✓ WeakRefStrings
   2212.3 ms  ✓ AtomsBase
   2386.8 ms  ✓ Pickle
  10657.4 ms  ✓ JSON3
   1865.7 ms  ✓ HDF5_jll
  20312.5 ms  ✓ PrettyTables
  19675.2 ms  ✓ HTTP
   2413.8 ms  ✓ Chemfiles
  17362.1 ms  ✓ CSV
   3202.2 ms  ✓ DataDeps
   7643.8 ms  ✓ HDF5
   1955.8 ms  ✓ FileIO → HTTPExt
   2463.7 ms  ✓ MAT
  45238.6 ms  ✓ DataFrames
   1425.2 ms  ✓ Transducers → TransducersDataFramesExt
   1610.9 ms  ✓ BangBang → BangBangDataFramesExt
   9289.0 ms  ✓ MLDatasets
  55 dependencies successfully precompiled in 95 seconds. 145 already precompiled.
Precompiling OneHotArrays...
    988.3 ms  ✓ OneHotArrays
  1 dependency successfully precompiled in 1 seconds. 28 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
    748.7 ms  ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
  1 dependency successfully precompiled in 1 seconds. 35 already precompiled.
Precompiling ComponentArraysReactantExt...
  12413.7 ms  ✓ ComponentArrays → ComponentArraysReactantExt
  1 dependency successfully precompiled in 13 seconds. 96 already precompiled.

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

# Arguments
- `dset`: dataset type; it is constructed with a split symbol (`:train` / `:test`) and
  indexing it with a range must yield `features` and `targets`.
- `n_train`: number of leading training samples to keep, or `nothing` to keep all.
- `n_eval`: number of leading test samples to keep, or `nothing` to keep all.
- `batchsize`: requested batch size; clamped to the number of samples in each split.

# Returns
A tuple `(train_loader, test_loader)`. Images are reshaped to `28 × 28 × 1 × N` and the
labels are one-hot encoded over the classes `0:9`. Only the training loader shuffles;
both loaders are created with `partial = false`.
"""
function load_dataset(
        ::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int
    ) where {dset}
    # Load one split and convert it to (images, one-hot labels).
    function load_split(split_name::Symbol, n_samples)
        data = dset(split_name)
        last_idx = n_samples === nothing ? length(data) : n_samples
        (; features, targets) = data[1:last_idx]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    x_train, y_train = load_split(:train, n_train)
    x_test, y_test = load_split(:test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize = min(batchsize, size(x_train, 4)), shuffle = true, partial = false
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize = min(batchsize, size(x_test, 4)), shuffle = false, partial = false
    )
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize = 32)

Load the MNIST and FashionMNIST data loaders, in that order.

Returns a 2-tuple whose entries are the `(train_loader, test_loader)` pairs produced by
`load_dataset`. When the `CI` environment variable parses to `true`, only 1024 training
and 32 test samples are used per dataset; otherwise the full splits are loaded.
"""
function load_datasets(batchsize = 32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return broadcast(load_dataset, (MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
"""
    HyperNet(weight_generator, core_network)

Create a layer in which `weight_generator` produces the parameters of `core_network`.

The returned layer is called on a tuple `(x, y)`: `x` is passed to `weight_generator`,
its output is flattened and given the parameter layout of `core_network`, and
`core_network` is then evaluated on `y` with those generated parameters.
"""
function HyperNet(
        weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer
    )
    # Only the axes (the parameter layout) of the core network are needed here; the
    # parameter values themselves come from the weight generator at call time.
    core_ps = ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    ca_axes = getaxes(core_ps)
    return @compact(; ca_axes, weight_generator, core_network, dispatch = :HyperNet) do (x, y)
        # Generate a flat weight vector and view it with the core network's layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips initializing the core_network parameters, since those are generated by the weight_generator on every forward pass.

julia
# Parameter initialization for the HyperNet layer: only the weight generator gets
# parameters. The core network's parameters are produced by the generator in the forward
# pass, so none are initialized for it here.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator = generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Build the HyperNet used in this tutorial.

The core network is a small convolutional classifier with 10 outputs. The weight
generator embeds one of two dataset indices and maps it to a vector with exactly as many
entries as the core network has parameters.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride = 2),
        Conv((3, 3), 16 => 32, relu; stride = 2),
        Conv((3, 3), 32 => 64, relu; stride = 2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10)
    )

    # The generator's output size must match the core network's parameter count.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its result is decoded with `onecold`. Targets and outputs are moved
to the CPU before they are compared.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    st_test = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        labels = onecold(to_cpu(y))
        output = first(model((data_idx, x), ps, st_test))
        predictions = onecold(to_cpu(output))
        n_correct += sum(labels .== predictions)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
"""
    train()

Train the HyperNet on MNIST and FashionMNIST and report accuracies.

The model is trained for 50 epochs with `Adam(0.0003f0)`; within every epoch one pass is
made over MNIST and then one over FashionMNIST. After each pass the training and test
accuracy (in percent, rounded to two decimals) are printed. When training has finished,
both datasets are evaluated once more.

Returns a two-element vector with the final test accuracies in percent, ordered as
`[MNIST, FashionMNIST]`.
"""
function train()
    dev = reactant_device(; force = true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference pass once, using a sample batch from the first dataset
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = @compile model((sample_idx, sample_x), ps, Lux.testmode(st))

    dataset_names = ("MNIST", "FashionMNIST")

    # Accuracy of the compiled model in percent, rounded to two decimals
    percent_accuracy(state, loader, idx) = round(
        accuracy(model_compiled, state.parameters, state.states, loader, idx) * 100;
        digits = 2
    )

    # Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_loader, test_loader = dataloaders[data_idx] .|> dev

        # Wrapping the index allows it to be traced; otherwise it would be embedded as a
        # constant in the IR
        traced_idx = ConcreteRNumber(data_idx)

        started_at = time()
        for (x, y) in train_loader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), CrossEntropyLoss(; logits = Val(true)),
                ((traced_idx, x), y), train_state; return_gradients = Val(false)
            )
        end
        elapsed = time() - started_at

        train_acc = percent_accuracy(train_state, train_loader, traced_idx)
        test_acc = percent_accuracy(train_state, test_loader, traced_idx)

        data_name = dataset_names[data_idx]

        @printf(
            "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
             Accuracy: %3.2f%%\n",
            epoch, nepochs, data_name, elapsed, train_acc, test_acc
        )
    end

    println()

    # Final evaluation of both datasets with the trained parameters
    final_test_acc = zeros(2)
    for data_idx in 1:2
        train_loader, test_loader = dataloaders[data_idx] .|> dev
        traced_idx = ConcreteRNumber(data_idx)

        train_acc = percent_accuracy(train_state, train_loader, traced_idx)
        test_acc = percent_accuracy(train_state, test_loader, traced_idx)

        data_name = dataset_names[data_idx]

        @printf(
            "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: %3.2f%%\n",
            data_name, train_acc, test_acc
        )
        final_test_acc[data_idx] = test_acc
    end
    return final_test_acc
end

# Run the full training loop; the result holds the final test accuracies (in percent)
# for MNIST and FashionMNIST, in that order.
test_acc_list = train()
2025-03-11 22:39:32.586648: I external/xla/xla/service/service.cc:152] XLA service 0xc7465e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-11 22:39:32.586698: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741732772.587560 1072320 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741732772.587662 1072320 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741732772.587723 1072320 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741732772.602990 1072320 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1741732846.958816 1072320 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741732846.959998 1072320 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741732846.960007 1072320 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741732846.960014 1072320 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741732846.960020 1072320 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741732846.960026 1072320 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741732846.960033 1072320 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741732846.960039 1072320 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741732846.960045 1072320 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741732846.960052 1072320 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-11 22:40:46.960067: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732846.962802 1072320 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741732846.962838 1072320 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741732846.962845 1072320 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741732846.962851 1072320 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741732846.962858 1072320 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741732846.962864 1072320 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741732846.962870 1072320 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741732846.962877 1072320 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741732846.962883 1072320 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741732846.962889 1072320 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-11 22:40:46.962899: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.324143 1072320 buffer_comparator.cc:156] Difference at 32: -nan, expected 0.585358
E0000 00:00:1741732880.324206 1072320 buffer_comparator.cc:156] Difference at 33: -nan, expected 0.881101
E0000 00:00:1741732880.324213 1072320 buffer_comparator.cc:156] Difference at 34: -nan, expected 0.738887
E0000 00:00:1741732880.324220 1072320 buffer_comparator.cc:156] Difference at 35: -nan, expected 0.603294
E0000 00:00:1741732880.324226 1072320 buffer_comparator.cc:156] Difference at 36: -nan, expected 1.04006
E0000 00:00:1741732880.324233 1072320 buffer_comparator.cc:156] Difference at 37: -nan, expected 0.740676
E0000 00:00:1741732880.324239 1072320 buffer_comparator.cc:156] Difference at 38: -nan, expected 0.766031
E0000 00:00:1741732880.324245 1072320 buffer_comparator.cc:156] Difference at 39: -nan, expected 0.814752
E0000 00:00:1741732880.324265 1072320 buffer_comparator.cc:156] Difference at 40: -nan, expected 0.431557
E0000 00:00:1741732880.324272 1072320 buffer_comparator.cc:156] Difference at 41: -nan, expected 0.672244
2025-03-11 22:41:20.324286: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.328628 1072320 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1741732880.328652 1072320 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1741732880.328659 1072320 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1741732880.328665 1072320 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1741732880.328671 1072320 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1741732880.328678 1072320 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1741732880.328684 1072320 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1741732880.328690 1072320 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1741732880.328696 1072320 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1741732880.328702 1072320 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-03-11 22:41:20.328711: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.331081 1072320 buffer_comparator.cc:156] Difference at 160: -nan, expected 2.24976
E0000 00:00:1741732880.331092 1072320 buffer_comparator.cc:156] Difference at 161: -nan, expected 1.51732
E0000 00:00:1741732880.331095 1072320 buffer_comparator.cc:156] Difference at 162: -nan, expected 2.23336
E0000 00:00:1741732880.331098 1072320 buffer_comparator.cc:156] Difference at 163: -nan, expected 1.99652
E0000 00:00:1741732880.331100 1072320 buffer_comparator.cc:156] Difference at 164: -nan, expected 1.7657
E0000 00:00:1741732880.331103 1072320 buffer_comparator.cc:156] Difference at 165: -nan, expected 1.87991
E0000 00:00:1741732880.331106 1072320 buffer_comparator.cc:156] Difference at 166: -nan, expected 1.82301
E0000 00:00:1741732880.331108 1072320 buffer_comparator.cc:156] Difference at 167: -nan, expected 2.01534
E0000 00:00:1741732880.331111 1072320 buffer_comparator.cc:156] Difference at 168: -nan, expected 2.17682
E0000 00:00:1741732880.331114 1072320 buffer_comparator.cc:156] Difference at 169: -nan, expected 2.28508
2025-03-11 22:41:20.331118: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.333245 1072320 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1741732880.333256 1072320 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1741732880.333259 1072320 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1741732880.333261 1072320 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1741732880.333264 1072320 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1741732880.333267 1072320 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1741732880.333270 1072320 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1741732880.333272 1072320 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1741732880.333275 1072320 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1741732880.333278 1072320 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-03-11 22:41:20.333282: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741732880.335413 1072320 buffer_comparator.cc:156] Difference at 320: -nan, expected 2.66964
E0000 00:00:1741732880.335424 1072320 buffer_comparator.cc:156] Difference at 321: -nan, expected 1.74014
E0000 00:00:1741732880.335426 1072320 buffer_comparator.cc:156] Difference at 322: -nan, expected 2.52041
E0000 00:00:1741732880.335429 1072320 buffer_comparator.cc:156] Difference at 323: -nan, expected 2.17431
E0000 00:00:1741732880.335432 1072320 buffer_comparator.cc:156] Difference at 324: -nan, expected 1.93476
E0000 00:00:1741732880.335435 1072320 buffer_comparator.cc:156] Difference at 325: -nan, expected 2.28937
E0000 00:00:1741732880.335437 1072320 buffer_comparator.cc:156] Difference at 326: -nan, expected 2.2459
E0000 00:00:1741732880.335440 1072320 buffer_comparator.cc:156] Difference at 327: -nan, expected 2.22886
E0000 00:00:1741732880.335443 1072320 buffer_comparator.cc:156] Difference at 328: -nan, expected 2.35168
E0000 00:00:1741732880.335445 1072320 buffer_comparator.cc:156] Difference at 329: -nan, expected 2.40895
2025-03-11 22:41:20.335450: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[  1/ 50]	       MNIST	Time 35.91804s	Training Accuracy: 35.06%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.03871s	Training Accuracy: 34.18%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.04073s	Training Accuracy: 34.96%	Test Accuracy: 37.50%
[  2/ 50]	FashionMNIST	Time 0.07132s	Training Accuracy: 51.17%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.03580s	Training Accuracy: 36.52%	Test Accuracy: 28.12%
[  3/ 50]	FashionMNIST	Time 0.03345s	Training Accuracy: 57.23%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.04741s	Training Accuracy: 45.41%	Test Accuracy: 34.38%
[  4/ 50]	FashionMNIST	Time 0.03490s	Training Accuracy: 64.45%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.03310s	Training Accuracy: 50.78%	Test Accuracy: 46.88%
[  5/ 50]	FashionMNIST	Time 0.03279s	Training Accuracy: 70.02%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.03607s	Training Accuracy: 56.74%	Test Accuracy: 40.62%
[  6/ 50]	FashionMNIST	Time 0.02994s	Training Accuracy: 74.02%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.03922s	Training Accuracy: 62.40%	Test Accuracy: 40.62%
[  7/ 50]	FashionMNIST	Time 0.02892s	Training Accuracy: 78.81%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.03913s	Training Accuracy: 71.00%	Test Accuracy: 43.75%
[  8/ 50]	FashionMNIST	Time 0.02928s	Training Accuracy: 83.69%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.04293s	Training Accuracy: 74.61%	Test Accuracy: 43.75%
[  9/ 50]	FashionMNIST	Time 0.03059s	Training Accuracy: 83.11%	Test Accuracy: 65.62%
[ 10/ 50]	       MNIST	Time 0.03220s	Training Accuracy: 78.61%	Test Accuracy: 46.88%
[ 10/ 50]	FashionMNIST	Time 0.03007s	Training Accuracy: 88.09%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.03057s	Training Accuracy: 84.08%	Test Accuracy: 37.50%
[ 11/ 50]	FashionMNIST	Time 0.04358s	Training Accuracy: 90.33%	Test Accuracy: 59.38%
[ 12/ 50]	       MNIST	Time 0.03569s	Training Accuracy: 87.11%	Test Accuracy: 43.75%
[ 12/ 50]	FashionMNIST	Time 0.03765s	Training Accuracy: 91.80%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.03755s	Training Accuracy: 90.23%	Test Accuracy: 46.88%
[ 13/ 50]	FashionMNIST	Time 0.03633s	Training Accuracy: 94.63%	Test Accuracy: 62.50%
[ 14/ 50]	       MNIST	Time 0.03024s	Training Accuracy: 91.31%	Test Accuracy: 56.25%
[ 14/ 50]	FashionMNIST	Time 0.04041s	Training Accuracy: 94.82%	Test Accuracy: 59.38%
[ 15/ 50]	       MNIST	Time 0.03308s	Training Accuracy: 95.12%	Test Accuracy: 53.12%
[ 15/ 50]	FashionMNIST	Time 0.02890s	Training Accuracy: 95.90%	Test Accuracy: 62.50%
[ 16/ 50]	       MNIST	Time 0.02936s	Training Accuracy: 96.68%	Test Accuracy: 50.00%
[ 16/ 50]	FashionMNIST	Time 0.02841s	Training Accuracy: 95.90%	Test Accuracy: 59.38%
[ 17/ 50]	       MNIST	Time 0.03651s	Training Accuracy: 97.17%	Test Accuracy: 56.25%
[ 17/ 50]	FashionMNIST	Time 0.03268s	Training Accuracy: 97.85%	Test Accuracy: 62.50%
[ 18/ 50]	       MNIST	Time 0.03884s	Training Accuracy: 98.14%	Test Accuracy: 53.12%
[ 18/ 50]	FashionMNIST	Time 0.02806s	Training Accuracy: 97.66%	Test Accuracy: 62.50%
[ 19/ 50]	       MNIST	Time 0.03930s	Training Accuracy: 99.22%	Test Accuracy: 56.25%
[ 19/ 50]	FashionMNIST	Time 0.02913s	Training Accuracy: 99.32%	Test Accuracy: 59.38%
[ 20/ 50]	       MNIST	Time 0.03430s	Training Accuracy: 99.71%	Test Accuracy: 53.12%
[ 20/ 50]	FashionMNIST	Time 0.03048s	Training Accuracy: 99.32%	Test Accuracy: 59.38%
[ 21/ 50]	       MNIST	Time 0.03063s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 21/ 50]	FashionMNIST	Time 0.02979s	Training Accuracy: 99.51%	Test Accuracy: 59.38%
[ 22/ 50]	       MNIST	Time 0.03091s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 22/ 50]	FashionMNIST	Time 0.02789s	Training Accuracy: 99.51%	Test Accuracy: 59.38%
[ 23/ 50]	       MNIST	Time 0.03142s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 23/ 50]	FashionMNIST	Time 0.03848s	Training Accuracy: 99.61%	Test Accuracy: 62.50%
[ 24/ 50]	       MNIST	Time 0.03005s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 24/ 50]	FashionMNIST	Time 0.03746s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 25/ 50]	       MNIST	Time 0.03042s	Training Accuracy: 99.90%	Test Accuracy: 53.12%
[ 25/ 50]	FashionMNIST	Time 0.03885s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 26/ 50]	       MNIST	Time 0.03081s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 26/ 50]	FashionMNIST	Time 0.03063s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	       MNIST	Time 0.03073s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 27/ 50]	FashionMNIST	Time 0.02767s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	       MNIST	Time 0.02941s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 28/ 50]	FashionMNIST	Time 0.02802s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 29/ 50]	       MNIST	Time 0.03935s	Training Accuracy: 100.00%	Test Accuracy: 53.12%
[ 29/ 50]	FashionMNIST	Time 0.02998s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 30/ 50]	       MNIST	Time 0.03779s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 30/ 50]	FashionMNIST	Time 0.02995s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 31/ 50]	       MNIST	Time 0.03823s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 31/ 50]	FashionMNIST	Time 0.03050s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	       MNIST	Time 0.03040s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 32/ 50]	FashionMNIST	Time 0.03044s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 33/ 50]	       MNIST	Time 0.03360s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 33/ 50]	FashionMNIST	Time 0.02847s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	       MNIST	Time 0.02868s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 34/ 50]	FashionMNIST	Time 0.03713s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	       MNIST	Time 0.02848s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 35/ 50]	FashionMNIST	Time 0.04032s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	       MNIST	Time 0.02940s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 36/ 50]	FashionMNIST	Time 0.03713s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	       MNIST	Time 0.02845s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 37/ 50]	FashionMNIST	Time 0.03695s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 38/ 50]	       MNIST	Time 0.02979s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 38/ 50]	FashionMNIST	Time 0.02973s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	       MNIST	Time 0.02844s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 39/ 50]	FashionMNIST	Time 0.02848s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	       MNIST	Time 0.02931s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 40/ 50]	FashionMNIST	Time 0.02836s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	       MNIST	Time 0.03853s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 41/ 50]	FashionMNIST	Time 0.02855s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	       MNIST	Time 0.03617s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 42/ 50]	FashionMNIST	Time 0.02867s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	       MNIST	Time 0.03773s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 43/ 50]	FashionMNIST	Time 0.02900s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 44/ 50]	       MNIST	Time 0.02822s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 44/ 50]	FashionMNIST	Time 0.02832s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	       MNIST	Time 0.02983s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 45/ 50]	FashionMNIST	Time 0.02831s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	       MNIST	Time 0.02834s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 46/ 50]	FashionMNIST	Time 0.03677s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	       MNIST	Time 0.02830s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 47/ 50]	FashionMNIST	Time 0.03632s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	       MNIST	Time 0.02826s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 48/ 50]	FashionMNIST	Time 0.03593s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	       MNIST	Time 0.02867s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 49/ 50]	FashionMNIST	Time 0.02949s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	       MNIST	Time 0.02858s	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[ 50/ 50]	FashionMNIST	Time 0.02897s	Training Accuracy: 100.00%	Test Accuracy: 62.50%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 56.25%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%

Appendix

julia
using InteractiveUtils
# Report the Julia version and platform details.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# defined in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
        MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.