Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
      Printf, Random, Setfield, Statistics, Zygote

CUDA.allowscalar(false)

Loading Datasets

julia
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
        n_eval::Union{Nothing, Int}, batchsize::Int) where {dset}
    # `nothing` means "use the whole split"; indexing the dataset returns (features, targets)
    if n_train === nothing
        imgs, labels = dset(:train)[:]
    else
        imgs, labels = dset(:train)[1:n_train]
    end
    # `:` lets reshape infer the sample count, so this also works when n_train is `nothing`
    x_train, y_train = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    if n_eval === nothing
        imgs, labels = dset(:test)[:]
    else
        imgs, labels = dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(imgs, 28, 28, 1, :), onehotbatch(labels, 0:9)

    # Cap the batch size at the number of samples actually loaded
    return (
        DataLoader((x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
        DataLoader((x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
    )
end

function load_datasets(batchsize=256)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
load_datasets (generic function with 2 methods)
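
The data preparation above boils down to three steps: add a channel dimension, one-hot encode the labels, and split the samples into mini-batches. The following is a minimal plain-Julia sketch of those steps on made-up data; it does not use MLDatasets, OneHotArrays, or MLUtils, and the names `imgs`, `labels`, and `batches` are illustrative assumptions rather than part of the tutorial.

```julia
# Hypothetical stand-in for a tiny dataset: 5 grayscale 28x28 images with labels in 0:9.
imgs = rand(Float32, 28, 28, 5)
labels = [3, 0, 9, 3, 7]

# Add a singleton channel dimension (width, height, channel, sample).
# `:` lets reshape infer the number of samples.
x = reshape(imgs, 28, 28, 1, :)

# One-hot encode: row i corresponds to class i - 1, one column per sample.
classes = 0:9
y = Float32.(classes .== permutedims(labels))

# Split the sample indices into mini-batches of at most 2 samples each.
batches = [(x[:, :, :, idx], y[:, idx]) for idx in Iterators.partition(1:size(x, 4), 2)]

println(size(x), " ", size(y), " ", length(batches))
```

The last batch is smaller than the others because 5 is not a multiple of 2.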

Implement a HyperNet Layer

julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
              ComponentArray |>
              getaxes
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end
HyperNet (generic function with 1 method)
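
To make the mechanism concrete, here is a minimal plain-Julia sketch of what the layer above does: one component produces a flat parameter vector, and that vector is carved up into the weights and bias of a second component. It deliberately avoids Lux and ComponentArrays; `task_table`, `generate`, `unflatten`, and `hyper_forward` are hypothetical names, and the lookup table is only a stand-in for a trained weight generator.

```julia
# A "core" dense layer y = W * x .+ b with 3 inputs and 2 outputs needs 2*3 + 2 = 8 parameters.
n_in, n_out = 3, 2
n_params = n_out * n_in + n_out

# Hypothetical weight generator: one flat parameter vector (a column) per task index.
# In the tutorial this role is played by the Embedding + Dense chain.
task_table = Float32.(reshape(1:(2 * n_params), n_params, 2)) ./ 10f0
generate(task::Int) = task_table[:, task]

# Carve the flat vector into named pieces; ComponentArray does this job using `ca_axes`.
function unflatten(flat::AbstractVector)
    W = reshape(flat[1:(n_out * n_in)], n_out, n_in)
    b = flat[(n_out * n_in + 1):end]
    return (; W, b)
end

# Forward pass: generate parameters from the task index, then apply the core layer.
function hyper_forward(task::Int, x::AbstractVector)
    ps = unflatten(generate(task))
    return ps.W * x .+ ps.b
end

x = Float32[1, 0, -1]
y1 = hyper_forward(1, x)
y2 = hyper_forward(2, x)
println(y1, " ", y2)
```

The same input gives different outputs for the two task indices because the core layer has no parameters of its own: everything it uses comes from the generator.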

Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend it unless you are familiar with the internals. Here, we overload Lux.initialparameters so that only the weight_generator parameters are initialized; the core_network parameters are skipped because the weight generator produces them on every forward pass.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end

Create and Initialize the HyperNet

julia
function create_model()
    # Doesn't need to be an MLP; any Lux layer works
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(
        Embedding(2 => 32),
        Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network))
    )

    model = HyperNet(weight_generator, core_network)
    return model
end
create_model (generic function with 1 method)
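
It helps to see how large the generated parameter vector is. The sketch below redoes the arithmetic by hand, assuming each Dense layer has a weight matrix plus a bias vector, the Embedding layer stores one vector per index, and FlattenLayer has no parameters. The helper names `dense_params` and `embedding_params` are hypothetical and not part of Lux.

```julia
# Parameter count of a dense layer with bias: weights (out * in) plus biases (out).
dense_params(in_dims, out_dims) = out_dims * in_dims + out_dims

# An embedding table stores one vector of length `dims` per vocabulary entry.
embedding_params(vocab, dims) = vocab * dims

# Core network: Dense(784, 256) followed by Dense(256, 10).
core = dense_params(784, 256) + dense_params(256, 10)

# Weight generator: Embedding(2 => 32), Dense(32, 64), Dense(64, core).
generator = embedding_params(2, 32) + dense_params(32, 64) + dense_params(64, core)

println("core network parameters:     ", core)
println("weight generator parameters: ", generator)
```

Under these assumptions the last Dense layer of the weight generator dominates the count, since it must output one value per core-network parameter.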

Define Utility Functions

julia
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model((data_idx, x), ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
accuracy (generic function with 1 method)
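
As a rough illustration of what these two utilities compute, the following plain-Julia sketch evaluates a cross-entropy loss on raw logits and a classification accuracy for a tiny hand-written batch. It is not Lux's implementation: `logsoftmax_cols`, `cross_entropy`, and `argmax_cols` are hypothetical helpers, and averaging over the batch is an assumption about the reduction.

```julia
# Numerically stable log-softmax over each column (classes along dimension 1).
function logsoftmax_cols(logits::AbstractMatrix)
    shifted = logits .- maximum(logits; dims=1)
    return shifted .- log.(sum(exp.(shifted); dims=1))
end

# Cross-entropy between one-hot targets and logits, averaged over the batch.
cross_entropy(logits, onehot) = -sum(onehot .* logsoftmax_cols(logits)) / size(logits, 2)

# Column-wise argmax; plays the role of `onecold`.
argmax_cols(m::AbstractMatrix) = [argmax(view(m, :, j)) for j in axes(m, 2)]

# 3 classes, 4 samples: one column per sample.
logits = Float32[2 0 0 5; 0 3 1 0; 1 0 4 0]
onehot = Float32[1 0 0 0; 0 1 0 1; 0 0 1 0]

predicted = argmax_cols(logits)
target = argmax_cols(onehot)
acc = sum(predicted .== target) / length(target)
ce = cross_entropy(logits, onehot)
println("accuracy = ", acc, ", loss = ", ce)
```

Only the last sample is misclassified (predicted class 1, target class 2), so the accuracy is 3 out of 4.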

Training

julia
function train()
    model = create_model()
    dataloaders = load_datasets()

    dev = gpu_device()
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoZygote(), loss, ((data_idx, x), y), train_state)
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev
        train_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, train_dataloader, data_idx) * 100;
            digits=2)
        test_acc = round(
            accuracy(model, train_state.parameters,
                train_state.states, test_dataloader, data_idx) * 100;
            digits=2)

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
[  1/ 50]	       MNIST	Time 89.87371s	Training Accuracy: 57.23%	Test Accuracy: 50.00%
[  1/ 50]	FashionMNIST	Time 0.05731s	Training Accuracy: 53.12%	Test Accuracy: 50.00%
[  2/ 50]	       MNIST	Time 0.02819s	Training Accuracy: 68.36%	Test Accuracy: 62.50%
[  2/ 50]	FashionMNIST	Time 0.02842s	Training Accuracy: 62.89%	Test Accuracy: 50.00%
[  3/ 50]	       MNIST	Time 0.02884s	Training Accuracy: 74.22%	Test Accuracy: 56.25%
[  3/ 50]	FashionMNIST	Time 0.02907s	Training Accuracy: 56.93%	Test Accuracy: 53.12%
[  4/ 50]	       MNIST	Time 0.02596s	Training Accuracy: 77.34%	Test Accuracy: 62.50%
[  4/ 50]	FashionMNIST	Time 0.02039s	Training Accuracy: 62.30%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.03110s	Training Accuracy: 80.08%	Test Accuracy: 68.75%
[  5/ 50]	FashionMNIST	Time 0.03448s	Training Accuracy: 66.31%	Test Accuracy: 62.50%
[  6/ 50]	       MNIST	Time 0.02991s	Training Accuracy: 85.16%	Test Accuracy: 68.75%
[  6/ 50]	FashionMNIST	Time 0.02568s	Training Accuracy: 71.19%	Test Accuracy: 78.12%
[  7/ 50]	       MNIST	Time 0.01992s	Training Accuracy: 88.67%	Test Accuracy: 75.00%
[  7/ 50]	FashionMNIST	Time 0.02022s	Training Accuracy: 72.07%	Test Accuracy: 65.62%
[  8/ 50]	       MNIST	Time 0.13196s	Training Accuracy: 91.02%	Test Accuracy: 84.38%
[  8/ 50]	FashionMNIST	Time 0.02167s	Training Accuracy: 75.59%	Test Accuracy: 62.50%
[  9/ 50]	       MNIST	Time 0.02541s	Training Accuracy: 92.97%	Test Accuracy: 81.25%
[  9/ 50]	FashionMNIST	Time 0.02206s	Training Accuracy: 75.68%	Test Accuracy: 56.25%
[ 10/ 50]	       MNIST	Time 0.02689s	Training Accuracy: 95.70%	Test Accuracy: 81.25%
[ 10/ 50]	FashionMNIST	Time 0.02438s	Training Accuracy: 78.61%	Test Accuracy: 65.62%
[ 11/ 50]	       MNIST	Time 0.02680s	Training Accuracy: 96.29%	Test Accuracy: 84.38%
[ 11/ 50]	FashionMNIST	Time 0.02421s	Training Accuracy: 79.88%	Test Accuracy: 65.62%
[ 12/ 50]	       MNIST	Time 0.02495s	Training Accuracy: 97.27%	Test Accuracy: 87.50%
[ 12/ 50]	FashionMNIST	Time 0.03477s	Training Accuracy: 80.66%	Test Accuracy: 65.62%
[ 13/ 50]	       MNIST	Time 0.02354s	Training Accuracy: 98.24%	Test Accuracy: 87.50%
[ 13/ 50]	FashionMNIST	Time 0.02188s	Training Accuracy: 82.42%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.02191s	Training Accuracy: 99.02%	Test Accuracy: 87.50%
[ 14/ 50]	FashionMNIST	Time 0.02102s	Training Accuracy: 83.98%	Test Accuracy: 71.88%
[ 15/ 50]	       MNIST	Time 0.02096s	Training Accuracy: 99.22%	Test Accuracy: 87.50%
[ 15/ 50]	FashionMNIST	Time 0.02264s	Training Accuracy: 84.28%	Test Accuracy: 75.00%
[ 16/ 50]	       MNIST	Time 0.02114s	Training Accuracy: 99.41%	Test Accuracy: 87.50%
[ 16/ 50]	FashionMNIST	Time 0.02110s	Training Accuracy: 86.04%	Test Accuracy: 71.88%
[ 17/ 50]	       MNIST	Time 0.03495s	Training Accuracy: 99.51%	Test Accuracy: 87.50%
[ 17/ 50]	FashionMNIST	Time 0.02549s	Training Accuracy: 86.82%	Test Accuracy: 78.12%
[ 18/ 50]	       MNIST	Time 0.02230s	Training Accuracy: 99.80%	Test Accuracy: 87.50%
[ 18/ 50]	FashionMNIST	Time 0.02086s	Training Accuracy: 87.79%	Test Accuracy: 78.12%
[ 19/ 50]	       MNIST	Time 0.05698s	Training Accuracy: 99.90%	Test Accuracy: 87.50%
[ 19/ 50]	FashionMNIST	Time 0.07558s	Training Accuracy: 88.96%	Test Accuracy: 75.00%
[ 20/ 50]	       MNIST	Time 0.02034s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 20/ 50]	FashionMNIST	Time 0.02012s	Training Accuracy: 89.26%	Test Accuracy: 71.88%
[ 21/ 50]	       MNIST	Time 0.02018s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 21/ 50]	FashionMNIST	Time 0.06521s	Training Accuracy: 90.04%	Test Accuracy: 75.00%
[ 22/ 50]	       MNIST	Time 0.02011s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 22/ 50]	FashionMNIST	Time 0.02768s	Training Accuracy: 90.23%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.03188s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 23/ 50]	FashionMNIST	Time 0.02069s	Training Accuracy: 90.92%	Test Accuracy: 75.00%
[ 24/ 50]	       MNIST	Time 0.02140s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 24/ 50]	FashionMNIST	Time 0.02107s	Training Accuracy: 91.02%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.02118s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 25/ 50]	FashionMNIST	Time 0.02140s	Training Accuracy: 91.21%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.03259s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 26/ 50]	FashionMNIST	Time 0.02231s	Training Accuracy: 91.50%	Test Accuracy: 71.88%
[ 27/ 50]	       MNIST	Time 0.02168s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 27/ 50]	FashionMNIST	Time 0.02081s	Training Accuracy: 92.48%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.02134s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 28/ 50]	FashionMNIST	Time 0.02085s	Training Accuracy: 92.58%	Test Accuracy: 71.88%
[ 29/ 50]	       MNIST	Time 0.02097s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 29/ 50]	FashionMNIST	Time 0.02092s	Training Accuracy: 93.16%	Test Accuracy: 71.88%
[ 30/ 50]	       MNIST	Time 0.02178s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 30/ 50]	FashionMNIST	Time 0.03325s	Training Accuracy: 92.97%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.02205s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 31/ 50]	FashionMNIST	Time 0.02087s	Training Accuracy: 93.46%	Test Accuracy: 71.88%
[ 32/ 50]	       MNIST	Time 0.02121s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 32/ 50]	FashionMNIST	Time 0.02089s	Training Accuracy: 93.36%	Test Accuracy: 71.88%
[ 33/ 50]	       MNIST	Time 0.02140s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 33/ 50]	FashionMNIST	Time 0.02103s	Training Accuracy: 93.85%	Test Accuracy: 71.88%
[ 34/ 50]	       MNIST	Time 0.02138s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 34/ 50]	FashionMNIST	Time 0.02500s	Training Accuracy: 94.34%	Test Accuracy: 71.88%
[ 35/ 50]	       MNIST	Time 0.03562s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 35/ 50]	FashionMNIST	Time 0.02148s	Training Accuracy: 94.73%	Test Accuracy: 71.88%
[ 36/ 50]	       MNIST	Time 0.02079s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 36/ 50]	FashionMNIST	Time 0.02092s	Training Accuracy: 94.82%	Test Accuracy: 71.88%
[ 37/ 50]	       MNIST	Time 0.02153s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 37/ 50]	FashionMNIST	Time 0.02109s	Training Accuracy: 95.12%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.02080s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 38/ 50]	FashionMNIST	Time 0.02182s	Training Accuracy: 95.02%	Test Accuracy: 71.88%
[ 39/ 50]	       MNIST	Time 0.02098s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 39/ 50]	FashionMNIST	Time 0.03268s	Training Accuracy: 95.31%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.02135s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 40/ 50]	FashionMNIST	Time 0.02154s	Training Accuracy: 95.31%	Test Accuracy: 71.88%
[ 41/ 50]	       MNIST	Time 0.02120s	Training Accuracy: 100.00%	Test Accuracy: 87.50%
[ 41/ 50]	FashionMNIST	Time 0.02276s	Training Accuracy: 95.61%	Test Accuracy: 71.88%
[ 42/ 50]	       MNIST	Time 0.02087s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 42/ 50]	FashionMNIST	Time 0.02136s	Training Accuracy: 95.90%	Test Accuracy: 71.88%
[ 43/ 50]	       MNIST	Time 0.02088s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 43/ 50]	FashionMNIST	Time 0.02088s	Training Accuracy: 95.70%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.03568s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 44/ 50]	FashionMNIST	Time 0.02237s	Training Accuracy: 96.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.02152s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 45/ 50]	FashionMNIST	Time 0.02090s	Training Accuracy: 96.39%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.02163s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 46/ 50]	FashionMNIST	Time 0.02138s	Training Accuracy: 96.29%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.02088s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 47/ 50]	FashionMNIST	Time 0.02137s	Training Accuracy: 96.39%	Test Accuracy: 71.88%
[ 48/ 50]	       MNIST	Time 0.02171s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 48/ 50]	FashionMNIST	Time 0.03456s	Training Accuracy: 96.00%	Test Accuracy: 71.88%
[ 49/ 50]	       MNIST	Time 0.02158s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 49/ 50]	FashionMNIST	Time 0.02398s	Training Accuracy: 96.68%	Test Accuracy: 71.88%
[ 50/ 50]	       MNIST	Time 0.02082s	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[ 50/ 50]	FashionMNIST	Time 0.02131s	Training Accuracy: 96.39%	Test Accuracy: 71.88%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 84.38%
[FINAL]	FashionMNIST	Training Accuracy: 96.39%	Test Accuracy: 71.88%
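
The log alternates between the two datasets within every epoch. That follows from the loop header `for epoch in 1:nepochs, data_idx in 1:2`, which is Julia shorthand for two nested loops with the rightmost variable varying fastest. The small sketch below records the visiting order for two epochs; the `schedule` vector is an illustrative name, not part of the tutorial.

```julia
# Record the order in which (epoch, dataset) pairs are visited.
schedule = Tuple{Int, String}[]
for epoch in 1:2, data_idx in 1:2
    name = data_idx == 1 ? "MNIST" : "FashionMNIST"
    push!(schedule, (epoch, name))
end

foreach(println, schedule)
```

Because both tasks are visited in every epoch, the shared weight generator receives updates from both datasets throughout training rather than finishing one task before starting the other.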

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 2.232 GiB / 4.750 GiB available)

This page was generated using Literate.jl.