Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    (; features, targets) = if n_train === nothing
        tmp = dset(:train)
        tmp[1:length(tmp)]
    else
        dset(:train)[1:n_train]
    end
    x_train, y_train = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    (; features, targets) = if n_eval === nothing
        tmp = dset(:test)
        tmp[1:length(tmp)]
    else
        dset(:test)[1:n_eval]
    end
    x_test, y_test = reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)

    return (
        DataLoader(
            (x_train, y_train);
            batchsize=min(batchsize, size(x_train, 4)),
            shuffle=true,
            partial=false,
        ),
        DataLoader(
            (x_test, y_test);
            batchsize=min(batchsize, size(x_test, 4)),
            shuffle=false,
            partial=false,
        ),
    )
end

function load_datasets(batchsize=32)
    n_train = parse(Bool, get(ENV, "CI", "false")) ? 1024 : nothing
    n_eval = parse(Bool, get(ENV, "CI", "false")) ? 32 : nothing
    return load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize)
end
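
Because both loaders use `partial=false`, a trailing incomplete batch is dropped. The following is a small sketch of the resulting batch counts, in plain Julia without MLUtils. The helper `n_full_batches` is hypothetical, and the 60,000 / 10,000 split sizes are the standard MNIST and FashionMNIST sizes rather than values stated on this page.

```julia
# Number of batches a loader yields when incomplete trailing batches are dropped.
# Mirrors `batchsize=min(batchsize, n)` combined with `partial=false`.
n_full_batches(n_samples, batchsize) = div(n_samples, min(batchsize, n_samples))

ci_train_batches = n_full_batches(1024, 32)      # CI subset of the training split
ci_eval_batches = n_full_batches(32, 32)         # CI subset of the test split
full_train_batches = n_full_batches(60_000, 32)  # assumed full training split
full_eval_batches = n_full_batches(10_000, 32)   # assumed full test split

# Samples that never reach the model because the last batch is incomplete.
dropped_eval = 10_000 - full_eval_batches * 32

# With only 32 evaluation samples, accuracy moves in steps of 100 / 32 percent.
accuracy_step = 100 / 32

println((ci_train_batches, ci_eval_batches, full_train_batches, full_eval_batches))
println("dropped eval samples: ", dropped_eval, ", accuracy step: ", accuracy_step)
```

When `CI` is set, the test loader holds a single batch of 32 samples, so every reported test accuracy is a multiple of 3.125%.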

Implement a HyperNet Layer

julia
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        @return core_network(y, ps_new)
    end
end

Defining methods on a CompactLuxLayer requires some understanding of how the layer is structured, so we don't recommend it unless you are familiar with the internals. Here, we overload Lux.initialparameters so that only the weight_generator parameters are initialized. The core_network parameters are skipped because the weight generator produces them on every forward pass.

julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator))
end
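
The HyperNet layer above relies on one idea: a flat vector produced by one network is reinterpreted as the structured parameters of another. The following is a minimal sketch of that idea in plain Julia, without Lux or ComponentArrays. The names `unflatten`, `generator`, and `core` and the toy shapes are hypothetical stand-ins, not part of the tutorial's API.

```julia
# Split a flat vector into named arrays according to `shapes`.
# This plays the role of `ComponentArray(vec(...), ca_axes)` in the layer above.
function unflatten(flat::AbstractVector, shapes::NamedTuple)
    names = keys(shapes)
    arrays = Vector{Any}(undef, length(names))
    offset = 0
    for (i, shp) in enumerate(values(shapes))
        n = prod(shp)
        arrays[i] = reshape(flat[(offset + 1):(offset + n)], shp)
        offset += n
    end
    offset == length(flat) || error("flat vector length does not match the shapes")
    return NamedTuple{names}(Tuple(arrays))
end

shapes = (; weight=(2, 3), bias=(2,))

# Stand-in for `weight_generator`: maps a task index to 8 parameter values.
generator(task_id) = Float32.(task_id .* (1:8))

# Stand-in for `core_network`: a single affine layer that owns no parameters.
core(x, ps) = ps.weight * x .+ ps.bias

x = Float32[1, 1, 1]
y1 = core(x, unflatten(generator(1), shapes))
y2 = core(x, unflatten(generator(2), shapes))
println(y1, " ", y2)
```

The same input yields different outputs for the two task indices because `core` receives a fresh parameter set on every call. This is also why the overload above skips initializing parameters for `core_network`.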

Create and Initialize the HyperNet

julia
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )
    return HyperNet(
        Chain(
            Embedding(2 => 32),
            Dense(32, 64, relu),
            Dense(64, Lux.parameterlength(core_network)),
        ),
        core_network,
    )
end
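
To get a feel for the sizes involved, the parameter counts can be worked out by hand. The sketch below assumes every `Conv` and `Dense` layer carries a weight and a bias, that `Embedding` carries only a weight table, and that the pooling and flatten layers have no parameters. The helper names are hypothetical.

```julia
# Parameter counts under the assumptions stated above.
conv_params(k, cin, cout) = k * k * cin * cout + cout
dense_params(nin, nout) = nin * nout + nout

# Size of the flat vector the weight generator must emit for the core network.
core_params = conv_params(3, 1, 16) + conv_params(3, 16, 32) +
              conv_params(3, 32, 64) + dense_params(64, 10)

# Trainable parameters of the weight generator itself:
# Embedding(2 => 32), Dense(32, 64), Dense(64, core_params).
generator_params = 2 * 32 + dense_params(32, 64) + dense_params(64, core_params)

# Spatial size after each unpadded 3x3 convolution with stride 2, starting at 28.
conv_out(n, k, s) = div(n - k, s) + 1
spatial_sizes = accumulate((n, _) -> conv_out(n, 3, 2), 1:3; init=28)

println("core network parameters:     ", core_params)
println("weight generator parameters: ", generator_params)
println("spatial sizes:               ", spatial_sizes)
```

Under these assumptions the core network needs 23,946 values per forward pass, and almost all trainable parameters sit in the final `Dense` layer of the weight generator.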

Define Utility Functions

julia
function accuracy(model, ps, st, dataloader, data_idx)
    total_correct, total = 0, 0
    cdev = cpu_device()
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
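
The accuracy computation compares the arg-max class of each prediction column with the arg-max of the corresponding one-hot target column. The following is a minimal sketch of that comparison on a toy batch, in plain Julia. `onecold_plain` and `batch_accuracy` are hypothetical stand-ins and do not use OneHotArrays.

```julia
# Index of the largest entry in each column (a stand-in for `onecold`).
onecold_plain(m::AbstractMatrix) = [argmax(col) for col in eachcol(m)]

# Fraction of columns where prediction and target agree.
function batch_accuracy(logits::AbstractMatrix, targets::AbstractMatrix)
    predicted = onecold_plain(logits)
    expected = onecold_plain(targets)
    return sum(predicted .== expected) / length(expected)
end

# 3 classes (rows) by 3 samples (columns).
logits = Float32[2.0 0.1 0.3; 0.5 1.5 0.2; 0.1 0.4 0.1]
targets = Float32[1 0 0; 0 1 0; 0 0 1]

println(onecold_plain(logits))
println(batch_accuracy(logits, targets))
```

Here the first two samples are classified correctly and the third is not, so the accuracy is 2/3.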

Training

julia
function train()
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    x = first(first(dataloaders[1][1]))
    data_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((data_idx, x), ps, Lux.testmode(st))
    end

    ### Let's train the model
    nepochs = 50
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                train_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )
        test_acc = round(
            accuracy(
                model_compiled,
                train_state.parameters,
                train_state.states,
                test_dataloader,
                concrete_data_idx,
            ) * 100;
            digits=2,
        )

        data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

test_acc_list = train()
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")
[  1/ 50]	       MNIST	Time 47.25133s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.19298s	Training Accuracy: 32.52%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.22306s	Training Accuracy: 36.33%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.23245s	Training Accuracy: 46.19%	Test Accuracy: 46.88%
[  3/ 50]	       MNIST	Time 0.16945s	Training Accuracy: 42.68%	Test Accuracy: 28.12%
[  3/ 50]	FashionMNIST	Time 0.22751s	Training Accuracy: 56.74%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.25773s	Training Accuracy: 51.27%	Test Accuracy: 37.50%
[  4/ 50]	FashionMNIST	Time 0.21111s	Training Accuracy: 64.55%	Test Accuracy: 56.25%
[  5/ 50]	       MNIST	Time 0.23107s	Training Accuracy: 57.03%	Test Accuracy: 40.62%
[  5/ 50]	FashionMNIST	Time 0.23048s	Training Accuracy: 71.19%	Test Accuracy: 56.25%
[  6/ 50]	       MNIST	Time 0.18307s	Training Accuracy: 62.70%	Test Accuracy: 34.38%
[  6/ 50]	FashionMNIST	Time 0.16869s	Training Accuracy: 75.39%	Test Accuracy: 56.25%
[  7/ 50]	       MNIST	Time 0.16684s	Training Accuracy: 69.04%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.16792s	Training Accuracy: 75.88%	Test Accuracy: 62.50%
[  8/ 50]	       MNIST	Time 0.17440s	Training Accuracy: 73.93%	Test Accuracy: 46.88%
[  8/ 50]	FashionMNIST	Time 0.16443s	Training Accuracy: 81.25%	Test Accuracy: 65.62%
[  9/ 50]	       MNIST	Time 0.19038s	Training Accuracy: 79.59%	Test Accuracy: 59.38%
[  9/ 50]	FashionMNIST	Time 0.16802s	Training Accuracy: 84.57%	Test Accuracy: 65.62%
[ 10/ 50]	       MNIST	Time 0.17108s	Training Accuracy: 83.20%	Test Accuracy: 53.12%
[ 10/ 50]	FashionMNIST	Time 0.17530s	Training Accuracy: 87.70%	Test Accuracy: 62.50%
[ 11/ 50]	       MNIST	Time 0.16764s	Training Accuracy: 86.13%	Test Accuracy: 53.12%
[ 11/ 50]	FashionMNIST	Time 0.16947s	Training Accuracy: 88.18%	Test Accuracy: 68.75%
[ 12/ 50]	       MNIST	Time 0.21124s	Training Accuracy: 90.23%	Test Accuracy: 50.00%
[ 12/ 50]	FashionMNIST	Time 0.17310s	Training Accuracy: 90.92%	Test Accuracy: 75.00%
[ 13/ 50]	       MNIST	Time 0.17504s	Training Accuracy: 94.34%	Test Accuracy: 59.38%
[ 13/ 50]	FashionMNIST	Time 0.17976s	Training Accuracy: 92.87%	Test Accuracy: 68.75%
[ 14/ 50]	       MNIST	Time 0.18228s	Training Accuracy: 95.12%	Test Accuracy: 53.12%
[ 14/ 50]	FashionMNIST	Time 0.17146s	Training Accuracy: 94.24%	Test Accuracy: 68.75%
[ 15/ 50]	       MNIST	Time 0.17092s	Training Accuracy: 96.29%	Test Accuracy: 68.75%
[ 15/ 50]	FashionMNIST	Time 0.22091s	Training Accuracy: 94.63%	Test Accuracy: 71.88%
[ 16/ 50]	       MNIST	Time 0.16761s	Training Accuracy: 98.14%	Test Accuracy: 65.62%
[ 16/ 50]	FashionMNIST	Time 0.20761s	Training Accuracy: 96.39%	Test Accuracy: 71.88%
[ 17/ 50]	       MNIST	Time 0.18366s	Training Accuracy: 99.61%	Test Accuracy: 65.62%
[ 17/ 50]	FashionMNIST	Time 0.19053s	Training Accuracy: 97.27%	Test Accuracy: 75.00%
[ 18/ 50]	       MNIST	Time 0.17661s	Training Accuracy: 99.71%	Test Accuracy: 68.75%
[ 18/ 50]	FashionMNIST	Time 0.22203s	Training Accuracy: 96.68%	Test Accuracy: 71.88%
[ 19/ 50]	       MNIST	Time 0.20205s	Training Accuracy: 99.80%	Test Accuracy: 62.50%
[ 19/ 50]	FashionMNIST	Time 0.17773s	Training Accuracy: 99.02%	Test Accuracy: 75.00%
[ 20/ 50]	       MNIST	Time 0.17589s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.17808s	Training Accuracy: 99.12%	Test Accuracy: 75.00%
[ 21/ 50]	       MNIST	Time 0.17368s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 21/ 50]	FashionMNIST	Time 0.20439s	Training Accuracy: 98.83%	Test Accuracy: 75.00%
[ 22/ 50]	       MNIST	Time 0.18979s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 22/ 50]	FashionMNIST	Time 0.18065s	Training Accuracy: 99.51%	Test Accuracy: 75.00%
[ 23/ 50]	       MNIST	Time 0.18448s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 23/ 50]	FashionMNIST	Time 0.17725s	Training Accuracy: 99.71%	Test Accuracy: 71.88%
[ 24/ 50]	       MNIST	Time 0.18730s	Training Accuracy: 99.90%	Test Accuracy: 68.75%
[ 24/ 50]	FashionMNIST	Time 0.16720s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 25/ 50]	       MNIST	Time 0.17303s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 25/ 50]	FashionMNIST	Time 0.17709s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 26/ 50]	       MNIST	Time 0.17719s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 26/ 50]	FashionMNIST	Time 0.17276s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 27/ 50]	       MNIST	Time 0.16721s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.17238s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 28/ 50]	       MNIST	Time 0.16904s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.17979s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 29/ 50]	       MNIST	Time 0.17906s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 29/ 50]	FashionMNIST	Time 0.19616s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 30/ 50]	       MNIST	Time 0.16879s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 30/ 50]	FashionMNIST	Time 0.18608s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 31/ 50]	       MNIST	Time 0.17991s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 31/ 50]	FashionMNIST	Time 0.19045s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 32/ 50]	       MNIST	Time 0.17816s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.17921s	Training Accuracy: 100.00%	Test Accuracy: 78.12%
[ 33/ 50]	       MNIST	Time 0.18032s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 33/ 50]	FashionMNIST	Time 0.17738s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 34/ 50]	       MNIST	Time 0.17516s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 34/ 50]	FashionMNIST	Time 0.17145s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 35/ 50]	       MNIST	Time 0.18243s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 35/ 50]	FashionMNIST	Time 0.18124s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 36/ 50]	       MNIST	Time 0.19813s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 36/ 50]	FashionMNIST	Time 0.17606s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 37/ 50]	       MNIST	Time 0.18434s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 37/ 50]	FashionMNIST	Time 0.16323s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 38/ 50]	       MNIST	Time 0.17885s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 38/ 50]	FashionMNIST	Time 0.17277s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 39/ 50]	       MNIST	Time 0.18511s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 39/ 50]	FashionMNIST	Time 0.17106s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 40/ 50]	       MNIST	Time 0.23123s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 40/ 50]	FashionMNIST	Time 0.17562s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 41/ 50]	       MNIST	Time 0.17778s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 41/ 50]	FashionMNIST	Time 0.18432s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 42/ 50]	       MNIST	Time 0.18230s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 42/ 50]	FashionMNIST	Time 0.19206s	Training Accuracy: 100.00%	Test Accuracy: 75.00%
[ 43/ 50]	       MNIST	Time 0.17065s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 43/ 50]	FashionMNIST	Time 0.17754s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 44/ 50]	       MNIST	Time 0.17320s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 44/ 50]	FashionMNIST	Time 0.17923s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 45/ 50]	       MNIST	Time 0.16997s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 45/ 50]	FashionMNIST	Time 0.17484s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 46/ 50]	       MNIST	Time 0.17846s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 46/ 50]	FashionMNIST	Time 0.19340s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 47/ 50]	       MNIST	Time 0.16674s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 47/ 50]	FashionMNIST	Time 0.18571s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 48/ 50]	       MNIST	Time 0.17814s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 48/ 50]	FashionMNIST	Time 0.17444s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 49/ 50]	       MNIST	Time 0.17209s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 49/ 50]	FashionMNIST	Time 0.17431s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 50/ 50]	       MNIST	Time 0.18021s	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[ 50/ 50]	FashionMNIST	Time 0.16950s	Training Accuracy: 100.00%	Test Accuracy: 71.88%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 59.38%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 71.88%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.