Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
    Reactant
Precompiling ComponentArrays...
    968.1 ms  ✓ ComponentArrays
  1 dependency successfully precompiled in 1 seconds. 45 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
    636.9 ms  ✓ MLDataDevices → MLDataDevicesComponentArraysExt
  1 dependency successfully precompiled in 1 seconds. 48 already precompiled.
Precompiling LuxComponentArraysExt...
    516.2 ms  ✓ ComponentArrays → ComponentArraysOptimisersExt
   1442.4 ms  ✓ Lux → LuxComponentArraysExt
   2113.0 ms  ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
  3 dependencies successfully precompiled in 2 seconds. 112 already precompiled.
Precompiling MLDatasets...
    390.9 ms  ✓ Glob
    419.0 ms  ✓ WorkerUtilities
    455.8 ms  ✓ BufferedStreams
    358.6 ms  ✓ SimpleBufferStream
    590.8 ms  ✓ URIs
    461.6 ms  ✓ CodecZlib
    333.5 ms  ✓ PackageExtensionCompat
    369.3 ms  ✓ BitFlags
    669.4 ms  ✓ GZip
    709.9 ms  ✓ ConcurrentUtilities
    598.0 ms  ✓ ZipFile
    544.7 ms  ✓ LoggingExtras
    813.0 ms  ✓ StructTypes
   1021.8 ms  ✓ MbedTLS
    590.1 ms  ✓ MPIPreferences
    356.6 ms  ✓ InternedStrings
    511.8 ms  ✓ ExceptionUnwrapping
   2172.3 ms  ✓ PeriodicTable
    580.7 ms  ✓ Chemfiles_jll
   2775.9 ms  ✓ UnitfulAtomic
    618.5 ms  ✓ libaec_jll
    478.4 ms  ✓ MicrosoftMPI_jll
    508.3 ms  ✓ InlineStrings → ParsersExt
    553.6 ms  ✓ StringEncodings
   1376.3 ms  ✓ Transducers → TransducersDataFramesExt
   1852.7 ms  ✓ ImageShow
   1619.2 ms  ✓ BangBang → BangBangDataFramesExt
    436.4 ms  ✓ StridedViews
   1645.0 ms  ✓ NPZ
   1850.9 ms  ✓ OpenSSL
   1117.1 ms  ✓ OpenMPI_jll
   1447.1 ms  ✓ MPICH_jll
   1120.6 ms  ✓ MPItrampoline_jll
    783.9 ms  ✓ WeakRefStrings
   2182.7 ms  ✓ AtomsBase
   2297.4 ms  ✓ Pickle
   1772.3 ms  ✓ HDF5_jll
   9757.5 ms  ✓ JSON3
   2320.1 ms  ✓ Chemfiles
   7327.9 ms  ✓ HDF5
   2346.2 ms  ✓ MAT
  18230.7 ms  ✓ HTTP
  16356.5 ms  ✓ CSV
   1808.5 ms  ✓ FileIO → HTTPExt
   2992.6 ms  ✓ DataDeps
   9032.5 ms  ✓ MLDatasets
  46 dependencies successfully precompiled in 43 seconds. 154 already precompiled.
Precompiling OneHotArrays...
   1000.5 ms  ✓ OneHotArrays
  1 dependency successfully precompiled in 1 seconds. 28 already precompiled.
Precompiling MLDataDevicesOneHotArraysExt...
    747.1 ms  ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
  1 dependency successfully precompiled in 1 seconds. 35 already precompiled.
Precompiling ComponentArraysReactantExt...
  12255.3 ms  ✓ ComponentArrays → ComponentArraysReactantExt
  1 dependency successfully precompiled in 13 seconds. 96 already precompiled.

Loading Datasets

julia
# Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.
#
# `n_train` / `n_eval` limit how many samples are taken from the front of the `:train` /
# `:test` split; passing `nothing` keeps the whole split. Features are reshaped to
# `28 x 28 x 1 x N` and targets are one-hot encoded over the labels `0:9`. The batch size
# is capped at the number of samples, incomplete trailing batches are dropped
# (`partial = false`), and only the training loader shuffles.
function load_dataset(
        ::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int
    ) where {dset}
    # Read the first `n` samples of `split` (all of them when `n === nothing`) and
    # convert them into the `(images, one-hot labels)` layout used by the loaders.
    function prepare_split(split::Symbol, n)
        data = dset(split)
        n_samples = n === nothing ? length(data) : n
        (; features, targets) = data[1:n_samples]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    x_train, y_train = prepare_split(:train, n_train)
    x_test, y_test = prepare_split(:test, n_eval)

    train_loader = DataLoader(
        (x_train, y_train);
        batchsize = min(batchsize, size(x_train, 4)), shuffle = true, partial = false
    )
    test_loader = DataLoader(
        (x_test, y_test);
        batchsize = min(batchsize, size(x_test, 4)), shuffle = false, partial = false
    )
    return (train_loader, test_loader)
end

# Create the `(train_loader, test_loader)` pairs for MNIST and FashionMNIST, in that
# order. When the `CI` environment variable parses to `true`, only 1024 training and 32
# evaluation samples are used per dataset; otherwise the full splits are loaded.
function load_datasets(batchsize = 32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return map((MNIST, FashionMNIST)) do dset
        load_dataset(dset, n_train, n_eval, batchsize)
    end
end
load_datasets (generic function with 2 methods)

Implement a HyperNet Layer

julia
# Construct a hypernetwork layer. The layer takes a tuple `(x, y)`: `x` is fed to
# `weight_generator`, whose flattened output is interpreted as the parameters of
# `core_network`, and `core_network` is then applied to `y` with those parameters.
function HyperNet(
        weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer
    )
    # Only the axes (i.e. the layout) of the core network's parameters are kept; they are
    # used to give structure to the flat vector produced by the weight generator.
    core_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(core_ps))
    return @compact(; ca_axes, weight_generator, core_network, dispatch = :HyperNet) do (x, y)
        # Generate the flat weight vector and view it with the core network's layout
        flat_weights = vec(weight_generator(x))
        generated_ps = ComponentArray(flat_weights, ca_axes)
        @return core_network(y, generated_ps)
    end
end
HyperNet (generic function with 1 method)

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips the initialization of the core_network parameters, since those are generated by the weight_generator.

julia
# Parameter initialization for the HyperNet layer: only the weight generator's
# parameters are created. The core network's parameters are deliberately left out.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator_ps = Lux.initialparameters(rng, hn.layers.weight_generator)
    return (; weight_generator = generator_ps)
end

Create and Initialize the HyperNet

julia
# Assemble the HyperNet used in this tutorial: a small convolutional classifier as the
# core network, and an embedding + MLP weight generator whose output length equals the
# number of parameters of that classifier.
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride = 2),
        Conv((3, 3), 16 => 32, relu; stride = 2),
        Conv((3, 3), 32 => 64, relu; stride = 2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10)
    )

    # The generator must emit exactly one value per core-network parameter.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
create_model (generic function with 1 method)

Define Utility Functions

julia
# Fraction of samples in `dataloader` that `model` classifies correctly.
#
# The model is called as `model((data_idx, x), ps, st)` with `st` switched to test mode;
# targets and predictions are moved to the CPU and decoded with `onecold` before they
# are compared.
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st_test = Lux.testmode(st)
    n_correct, n_seen = 0, 0
    for (x, y) in dataloader
        # The first element of the model output holds the class scores
        scores = first(model((data_idx, x), ps, st_test))
        labels = onecold(cdev(y))
        predictions = onecold(cdev(scores))
        n_correct += count(labels .== predictions)
        n_seen += length(labels)
    end
    return n_correct / n_seen
end
accuracy (generic function with 1 method)

Training

julia
# Display name of the dataset selected by `data_idx` (1 is MNIST, anything else is
# FashionMNIST); used only for progress output.
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

# Accuracy of `model_compiled` on `dataloader` as a percentage rounded to two digits,
# evaluated with the parameters and states currently stored in `train_state`.
function _accuracy_percent(model_compiled, train_state, dataloader, data_idx)
    acc = accuracy(
        model_compiled, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits = 2)
end

# Train the HyperNet on MNIST and FashionMNIST, alternating between the two datasets in
# every epoch, and print training/test accuracy after each pass.
#
# Returns a two-element vector holding the final test accuracy (in percent) for MNIST
# and FashionMNIST, in that order.
function train()
    dev = reactant_device(; force = true)

    model = create_model()
    dataloaders = load_datasets() |> dev

    Random.seed!(1234)
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using one MNIST training batch and a traced
    # dataset index as example inputs; the compiled model is reused for every evaluation.
    x_sample = first(first(dataloaders[1][1]))
    idx_sample = ConcreteRNumber(1)
    model_compiled = @compile model((idx_sample, x_sample), ps, Lux.testmode(st))

    # Loop-invariant pieces of the training step
    backend = AutoEnzyme()
    loss_fn = CrossEntropyLoss(; logits = Val(true))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        # Time one full pass over the training data of the current dataset
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                backend, loss_fn,
                ((concrete_data_idx, x), y), train_state; return_gradients = Val(false)
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of the trained model on both datasets
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the training loop; keeps the final test accuracies (in percent) of both datasets.
test_acc_list = train()
2025-03-08 15:39:37.638068: I external/xla/xla/service/service.cc:152] XLA service 0xee76fb0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-08 15:39:37.638111: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741448377.638985 3567203 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741448377.639060 3567203 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741448377.639111 3567203 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741448377.650453 3567203 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1741448502.810384 3567203 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741448502.810444 3567203 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741448502.810452 3567203 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741448502.810459 3567203 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741448502.810466 3567203 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741448502.810472 3567203 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741448502.810479 3567203 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741448502.810486 3567203 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741448502.810492 3567203 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741448502.810499 3567203 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-08 15:41:42.810515: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448502.813414 3567203 buffer_comparator.cc:156] Difference at 160: 0, expected 3.2697
E0000 00:00:1741448502.813449 3567203 buffer_comparator.cc:156] Difference at 161: 0, expected 4.0983
E0000 00:00:1741448502.813457 3567203 buffer_comparator.cc:156] Difference at 162: 0, expected 4.10357
E0000 00:00:1741448502.813464 3567203 buffer_comparator.cc:156] Difference at 163: 0, expected 3.24249
E0000 00:00:1741448502.813471 3567203 buffer_comparator.cc:156] Difference at 164: 0, expected 4.43023
E0000 00:00:1741448502.813478 3567203 buffer_comparator.cc:156] Difference at 165: 0, expected 3.93868
E0000 00:00:1741448502.813485 3567203 buffer_comparator.cc:156] Difference at 166: 0, expected 3.76497
E0000 00:00:1741448502.813491 3567203 buffer_comparator.cc:156] Difference at 167: 0, expected 3.88143
E0000 00:00:1741448502.813498 3567203 buffer_comparator.cc:156] Difference at 168: 0, expected 3.84474
E0000 00:00:1741448502.813505 3567203 buffer_comparator.cc:156] Difference at 169: 0, expected 3.90903
2025-03-08 15:41:42.813516: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.637669 3567203 buffer_comparator.cc:156] Difference at 32: 0, expected 0.585358
E0000 00:00:1741448528.637730 3567203 buffer_comparator.cc:156] Difference at 33: 0, expected 0.881101
E0000 00:00:1741448528.637740 3567203 buffer_comparator.cc:156] Difference at 34: 0, expected 0.738887
E0000 00:00:1741448528.637746 3567203 buffer_comparator.cc:156] Difference at 35: 0, expected 0.603294
E0000 00:00:1741448528.637753 3567203 buffer_comparator.cc:156] Difference at 36: 0, expected 1.04006
E0000 00:00:1741448528.637759 3567203 buffer_comparator.cc:156] Difference at 37: 0.364706, expected 0.740676
E0000 00:00:1741448528.637766 3567203 buffer_comparator.cc:156] Difference at 38: 0.992157, expected 0.766031
E0000 00:00:1741448528.637772 3567203 buffer_comparator.cc:156] Difference at 40: 0.611765, expected 0.431557
E0000 00:00:1741448528.637781 3567203 buffer_comparator.cc:156] Difference at 41: 0, expected 0.672244
E0000 00:00:1741448528.637788 3567203 buffer_comparator.cc:156] Difference at 42: 0, expected 0.692783
2025-03-08 15:42:08.637803: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.641748 3567203 buffer_comparator.cc:156] Difference at 160: 0.704499, expected 2.24976
E0000 00:00:1741448528.641764 3567203 buffer_comparator.cc:156] Difference at 161: 0.996866, expected 1.51732
E0000 00:00:1741448528.641769 3567203 buffer_comparator.cc:156] Difference at 162: 0.776841, expected 2.23336
E0000 00:00:1741448528.641773 3567203 buffer_comparator.cc:156] Difference at 163: 0.660587, expected 1.99652
E0000 00:00:1741448528.641777 3567203 buffer_comparator.cc:156] Difference at 164: 1.03633, expected 1.7657
E0000 00:00:1741448528.641781 3567203 buffer_comparator.cc:156] Difference at 165: 0.912078, expected 1.87991
E0000 00:00:1741448528.641785 3567203 buffer_comparator.cc:156] Difference at 166: 0.809391, expected 1.82301
E0000 00:00:1741448528.641789 3567203 buffer_comparator.cc:156] Difference at 167: 0.822743, expected 2.01534
E0000 00:00:1741448528.641793 3567203 buffer_comparator.cc:156] Difference at 168: 0.422923, expected 2.17682
E0000 00:00:1741448528.641797 3567203 buffer_comparator.cc:156] Difference at 169: 0.858019, expected 2.28508
2025-03-08 15:42:08.641803: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.644027 3567203 buffer_comparator.cc:156] Difference at 160: 0.704499, expected 2.24976
E0000 00:00:1741448528.644043 3567203 buffer_comparator.cc:156] Difference at 161: 0.996866, expected 1.51732
E0000 00:00:1741448528.644047 3567203 buffer_comparator.cc:156] Difference at 162: 0.776841, expected 2.23336
E0000 00:00:1741448528.644051 3567203 buffer_comparator.cc:156] Difference at 163: 0.660587, expected 1.99652
E0000 00:00:1741448528.644055 3567203 buffer_comparator.cc:156] Difference at 164: 1.03633, expected 1.7657
E0000 00:00:1741448528.644059 3567203 buffer_comparator.cc:156] Difference at 165: 0.912078, expected 1.87991
E0000 00:00:1741448528.644063 3567203 buffer_comparator.cc:156] Difference at 166: 0.809391, expected 1.82301
E0000 00:00:1741448528.644067 3567203 buffer_comparator.cc:156] Difference at 167: 0.822743, expected 2.01534
E0000 00:00:1741448528.644071 3567203 buffer_comparator.cc:156] Difference at 168: 0.422923, expected 2.17682
E0000 00:00:1741448528.644075 3567203 buffer_comparator.cc:156] Difference at 169: 0.858019, expected 2.28508
2025-03-08 15:42:08.644081: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.646290 3567203 buffer_comparator.cc:156] Difference at 320: 0.635997, expected 2.66964
E0000 00:00:1741448528.646305 3567203 buffer_comparator.cc:156] Difference at 321: 0.937885, expected 1.74014
E0000 00:00:1741448528.646309 3567203 buffer_comparator.cc:156] Difference at 322: 0.613587, expected 2.52041
E0000 00:00:1741448528.646313 3567203 buffer_comparator.cc:156] Difference at 323: 0.704588, expected 2.17431
E0000 00:00:1741448528.646317 3567203 buffer_comparator.cc:156] Difference at 324: 0.827663, expected 1.93476
E0000 00:00:1741448528.646321 3567203 buffer_comparator.cc:156] Difference at 325: 0.78249, expected 2.28937
E0000 00:00:1741448528.646325 3567203 buffer_comparator.cc:156] Difference at 326: 0.937569, expected 2.2459
E0000 00:00:1741448528.646329 3567203 buffer_comparator.cc:156] Difference at 327: 0.916275, expected 2.22886
E0000 00:00:1741448528.646333 3567203 buffer_comparator.cc:156] Difference at 328: 0.879483, expected 2.35168
E0000 00:00:1741448528.646337 3567203 buffer_comparator.cc:156] Difference at 329: 0.745365, expected 2.40895
2025-03-08 15:42:08.646345: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741448528.648554 3567203 buffer_comparator.cc:156] Difference at 320: 0.635997, expected 2.66964
E0000 00:00:1741448528.648569 3567203 buffer_comparator.cc:156] Difference at 321: 0.937885, expected 1.74014
E0000 00:00:1741448528.648574 3567203 buffer_comparator.cc:156] Difference at 322: 0.613587, expected 2.52041
E0000 00:00:1741448528.648578 3567203 buffer_comparator.cc:156] Difference at 323: 0.704588, expected 2.17431
E0000 00:00:1741448528.648582 3567203 buffer_comparator.cc:156] Difference at 324: 0.827663, expected 1.93476
E0000 00:00:1741448528.648586 3567203 buffer_comparator.cc:156] Difference at 325: 0.78249, expected 2.28937
E0000 00:00:1741448528.648590 3567203 buffer_comparator.cc:156] Difference at 326: 0.937569, expected 2.2459
E0000 00:00:1741448528.648594 3567203 buffer_comparator.cc:156] Difference at 327: 0.916275, expected 2.22886
E0000 00:00:1741448528.648598 3567203 buffer_comparator.cc:156] Difference at 328: 0.879483, expected 2.35168
E0000 00:00:1741448528.648602 3567203 buffer_comparator.cc:156] Difference at 329: 0.745365, expected 2.40895
2025-03-08 15:42:08.648608: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
[  1/ 50]	       MNIST	Time 28.54848s	Training Accuracy: 35.16%	Test Accuracy: 40.62%
[  1/ 50]	FashionMNIST	Time 0.03629s	Training Accuracy: 32.42%	Test Accuracy: 46.88%
[  2/ 50]	       MNIST	Time 0.03454s	Training Accuracy: 33.69%	Test Accuracy: 31.25%
[  2/ 50]	FashionMNIST	Time 0.08035s	Training Accuracy: 47.75%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.03006s	Training Accuracy: 41.80%	Test Accuracy: 46.88%
[  3/ 50]	FashionMNIST	Time 0.03112s	Training Accuracy: 54.88%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.02650s	Training Accuracy: 46.97%	Test Accuracy: 37.50%
[  4/ 50]	FashionMNIST	Time 0.02468s	Training Accuracy: 60.55%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.02426s	Training Accuracy: 54.49%	Test Accuracy: 46.88%
[  5/ 50]	FashionMNIST	Time 0.06597s	Training Accuracy: 65.92%	Test Accuracy: 59.38%
[  6/ 50]	       MNIST	Time 0.02475s	Training Accuracy: 59.77%	Test Accuracy: 50.00%
[  6/ 50]	FashionMNIST	Time 0.02469s	Training Accuracy: 70.31%	Test Accuracy: 62.50%
[  7/ 50]	       MNIST	Time 0.02462s	Training Accuracy: 65.33%	Test Accuracy: 56.25%
[  7/ 50]	FashionMNIST	Time 0.02478s	Training Accuracy: 75.59%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.02465s	Training Accuracy: 71.09%	Test Accuracy: 43.75%
[  8/ 50]	FashionMNIST	Time 0.05393s	Training Accuracy: 80.27%	Test Accuracy: 59.38%
[  9/ 50]	       MNIST	Time 0.02572s	Training Accuracy: 75.88%	Test Accuracy: 43.75%
[  9/ 50]	FashionMNIST	Time 0.02551s	Training Accuracy: 81.54%	Test Accuracy: 59.38%
[ 10/ 50]	       MNIST	Time 0.02569s	Training Accuracy: 80.66%	Test Accuracy: 50.00%
[ 10/ 50]	FashionMNIST	Time 0.04109s	Training Accuracy: 87.11%	Test Accuracy: 59.38%
[ 11/ 50]	       MNIST	Time 0.02470s	Training Accuracy: 85.06%	Test Accuracy: 46.88%
[ 11/ 50]	FashionMNIST	Time 0.02467s	Training Accuracy: 88.09%	Test Accuracy: 71.88%
[ 12/ 50]	       MNIST	Time 0.02448s	Training Accuracy: 87.60%	Test Accuracy: 43.75%
[ 12/ 50]	FashionMNIST	Time 0.03749s	Training Accuracy: 90.23%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.02538s	Training Accuracy: 89.16%	Test Accuracy: 46.88%
[ 13/ 50]	FashionMNIST	Time 0.02459s	Training Accuracy: 93.46%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.02439s	Training Accuracy: 93.36%	Test Accuracy: 50.00%
[ 14/ 50]	FashionMNIST	Time 0.02466s	Training Accuracy: 94.53%	Test Accuracy: 71.88%
[ 15/ 50]	       MNIST	Time 0.02458s	Training Accuracy: 94.34%	Test Accuracy: 46.88%
[ 15/ 50]	FashionMNIST	Time 0.02491s	Training Accuracy: 95.31%	Test Accuracy: 62.50%
[ 16/ 50]	       MNIST	Time 0.02456s	Training Accuracy: 96.97%	Test Accuracy: 56.25%
[ 16/ 50]	FashionMNIST	Time 0.02503s	Training Accuracy: 97.27%	Test Accuracy: 65.62%
[ 17/ 50]	       MNIST	Time 0.02466s	Training Accuracy: 98.34%	Test Accuracy: 53.12%
[ 17/ 50]	FashionMNIST	Time 0.02532s	Training Accuracy: 97.36%	Test Accuracy: 71.88%
[ 18/ 50]	       MNIST	Time 0.03622s	Training Accuracy: 98.83%	Test Accuracy: 59.38%
[ 18/ 50]	FashionMNIST	Time 0.02458s	Training Accuracy: 98.34%	Test Accuracy: 68.75%
[ 19/ 50]	       MNIST	Time 0.02458s	Training Accuracy: 99.51%	Test Accuracy: 56.25%
[ 19/ 50]	FashionMNIST	Time 0.02460s	Training Accuracy: 99.02%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.03623s	Training Accuracy: 99.90%	Test Accuracy: 56.25%
[ 20/ 50]	FashionMNIST	Time 0.02473s	Training Accuracy: 99.22%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.02575s	Training Accuracy: 99.90%	Test Accuracy: 56.25%
[ 21/ 50]	FashionMNIST	Time 0.02552s	Training Accuracy: 99.41%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.02480s	Training Accuracy: 99.90%	Test Accuracy: 56.25%
[ 22/ 50]	FashionMNIST	Time 0.02492s	Training Accuracy: 99.71%	Test Accuracy: 65.62%
[ 23/ 50]	       MNIST	Time 0.02481s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 23/ 50]	FashionMNIST	Time 0.02462s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.02482s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 24/ 50]	FashionMNIST	Time 0.02573s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.02607s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 25/ 50]	FashionMNIST	Time 0.03856s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.02559s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 26/ 50]	FashionMNIST	Time 0.02586s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 27/ 50]	       MNIST	Time 0.02484s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 27/ 50]	FashionMNIST	Time 0.02403s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.02475s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 28/ 50]	FashionMNIST	Time 0.02506s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.02460s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 29/ 50]	FashionMNIST	Time 0.02485s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 30/ 50]	       MNIST	Time 0.02459s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 30/ 50]	FashionMNIST	Time 0.02460s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.03671s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 31/ 50]	FashionMNIST	Time 0.02446s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 32/ 50]	       MNIST	Time 0.02446s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 32/ 50]	FashionMNIST	Time 0.02643s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 33/ 50]	       MNIST	Time 0.03690s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 33/ 50]	FashionMNIST	Time 0.02458s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.02463s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.02466s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 35/ 50]	       MNIST	Time 0.02674s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	FashionMNIST	Time 0.02461s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 36/ 50]	       MNIST	Time 0.02461s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.02500s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.02464s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	FashionMNIST	Time 0.02452s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.02491s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 38/ 50]	FashionMNIST	Time 0.03697s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 39/ 50]	       MNIST	Time 0.02473s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	FashionMNIST	Time 0.02454s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 40/ 50]	       MNIST	Time 0.02459s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	FashionMNIST	Time 0.04066s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 41/ 50]	       MNIST	Time 0.02638s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.02535s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 42/ 50]	       MNIST	Time 0.02436s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.02486s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 43/ 50]	       MNIST	Time 0.02606s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.02600s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.02477s	Training Accuracy: 100.00%	Test Accuracy: 65.62%
[ 44/ 50]	FashionMNIST	Time 0.02505s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 45/ 50]	       MNIST	Time 0.02749s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.02627s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 46/ 50]	       MNIST	Time 0.03704s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.02567s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 47/ 50]	       MNIST	Time 0.02569s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	FashionMNIST	Time 0.02553s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 48/ 50]	       MNIST	Time 0.02431s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	FashionMNIST	Time 0.02501s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 49/ 50]	       MNIST	Time 0.02633s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.02652s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 50/ 50]	       MNIST	Time 0.02583s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.02661s	Training Accuracy: 100.00%	Test Accuracy: 75.00%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 75.00%

Appendix

julia
using InteractiveUtils
# Report the Julia version and platform details of the current session.
InteractiveUtils.versioninfo()

# Additionally report GPU backend details, but only for a backend whose package is
# loaded in this session and which MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
        MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
        MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.