
Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we implement the graph convolutional network manually here, we recommend using GNNLux.jl directly instead.

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use fixed index ranges for the train/val/test splits since Reactant
        # doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
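
`loadcora` returns the node features, the one-hot encoded targets, a dense adjacency matrix, and the index ranges of the train/validation/test splits. A quick shape check (the sizes in the comments assume the standard Cora statistics: 2708 nodes, 1433 features, 7 classes):

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)                            # (1433, 2708): feature_dim × num_nodes
size(targets)                             # (7, 2708): num_classes × num_nodes
size(adj)                                 # (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))   # (140, 500, 1000)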

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
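
Because the features are stored one node per column, each `GCNLayer` computes `dense(x) * adj`: the dense layer transforms every node's features, and right-multiplying by the adjacency matrix sums the transformed features over each node's neighbours. A minimal instantiation sketch (the sizes are illustrative and mirror the call in `main` below):

julia
rng = Random.default_rng()
model = GCN(1433, 64, 7; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)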

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
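
`accuracy` compares the arg-max class of each prediction column against the one-hot targets. A tiny illustration with made-up scores:

julia
y = onehotbatch([1, 2, 3], 1:3)
y_pred = [0.9 0.1 0.2; 0.05 0.8 0.3; 0.05 0.1 0.5]  # one column of class scores per sample
accuracy(y_pred, y)  # 100.0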

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
Total Trainable Parameters: 0.0964 M
Epoch   1	Train Loss: 16.900967	Train Acc: 20.7143%	Val Loss: 7.453496	Val Acc: 26.8000%
Epoch   2	Train Loss: 8.828428	Train Acc: 23.5714%	Val Loss: 4.012742	Val Acc: 31.2000%
Epoch   3	Train Loss: 4.781260	Train Acc: 38.5714%	Val Loss: 2.289675	Val Acc: 34.8000%
Epoch   4	Train Loss: 2.112979	Train Acc: 47.8571%	Val Loss: 2.407374	Val Acc: 34.8000%
Epoch   5	Train Loss: 2.381186	Train Acc: 50.7143%	Val Loss: 2.453577	Val Acc: 37.8000%
Epoch   6	Train Loss: 1.831160	Train Acc: 57.8571%	Val Loss: 2.267815	Val Acc: 43.2000%
Epoch   7	Train Loss: 1.600700	Train Acc: 68.5714%	Val Loss: 1.987743	Val Acc: 48.6000%
Epoch   8	Train Loss: 1.359828	Train Acc: 71.4286%	Val Loss: 1.756406	Val Acc: 56.2000%
Epoch   9	Train Loss: 1.329730	Train Acc: 75.7143%	Val Loss: 1.649770	Val Acc: 59.6000%
Epoch  10	Train Loss: 1.151535	Train Acc: 78.5714%	Val Loss: 1.605249	Val Acc: 61.6000%
Epoch  11	Train Loss: 1.448032	Train Acc: 77.1429%	Val Loss: 1.601787	Val Acc: 63.0000%
Epoch  12	Train Loss: 1.011211	Train Acc: 77.1429%	Val Loss: 1.602901	Val Acc: 63.0000%
Epoch  13	Train Loss: 0.960311	Train Acc: 77.8571%	Val Loss: 1.599380	Val Acc: 63.8000%
Epoch  14	Train Loss: 1.057467	Train Acc: 78.5714%	Val Loss: 1.591251	Val Acc: 64.0000%
Epoch  15	Train Loss: 1.806404	Train Acc: 80.0000%	Val Loss: 1.581501	Val Acc: 64.6000%
Epoch  16	Train Loss: 0.860397	Train Acc: 79.2857%	Val Loss: 1.595711	Val Acc: 65.2000%
Epoch  17	Train Loss: 0.734396	Train Acc: 80.7143%	Val Loss: 1.618162	Val Acc: 65.6000%
Epoch  18	Train Loss: 0.673426	Train Acc: 82.8571%	Val Loss: 1.651589	Val Acc: 65.4000%
Epoch  19	Train Loss: 0.730562	Train Acc: 82.8571%	Val Loss: 1.689287	Val Acc: 64.4000%
Epoch  20	Train Loss: 0.744401	Train Acc: 83.5714%	Val Loss: 1.711970	Val Acc: 64.6000%
Epoch  21	Train Loss: 0.724164	Train Acc: 84.2857%	Val Loss: 1.712736	Val Acc: 64.6000%
Epoch  22	Train Loss: 0.721066	Train Acc: 85.0000%	Val Loss: 1.695096	Val Acc: 65.0000%
Epoch  23	Train Loss: 0.651210	Train Acc: 85.7143%	Val Loss: 1.672541	Val Acc: 64.8000%
Epoch  24	Train Loss: 0.717768	Train Acc: 85.7143%	Val Loss: 1.668745	Val Acc: 65.4000%
Epoch  25	Train Loss: 0.813148	Train Acc: 87.1429%	Val Loss: 1.648924	Val Acc: 65.4000%
Epoch  26	Train Loss: 0.546969	Train Acc: 87.8571%	Val Loss: 1.624019	Val Acc: 64.8000%
Epoch  27	Train Loss: 0.516378	Train Acc: 89.2857%	Val Loss: 1.598994	Val Acc: 65.0000%
Epoch  28	Train Loss: 0.479124	Train Acc: 89.2857%	Val Loss: 1.581874	Val Acc: 64.6000%
Epoch  29	Train Loss: 0.469026	Train Acc: 90.0000%	Val Loss: 1.567386	Val Acc: 64.8000%
Epoch  30	Train Loss: 0.490631	Train Acc: 90.0000%	Val Loss: 1.559431	Val Acc: 66.0000%
Epoch  31	Train Loss: 0.480068	Train Acc: 90.0000%	Val Loss: 1.558629	Val Acc: 66.4000%
Epoch  32	Train Loss: 0.474387	Train Acc: 90.0000%	Val Loss: 1.559061	Val Acc: 66.4000%
Epoch  33	Train Loss: 0.473973	Train Acc: 91.4286%	Val Loss: 1.566357	Val Acc: 66.4000%
Epoch  34	Train Loss: 0.590703	Train Acc: 92.1429%	Val Loss: 1.578132	Val Acc: 66.0000%
Epoch  35	Train Loss: 0.494798	Train Acc: 92.8571%	Val Loss: 1.590551	Val Acc: 66.2000%
Epoch  36	Train Loss: 0.493164	Train Acc: 93.5714%	Val Loss: 1.595978	Val Acc: 66.0000%
Epoch  37	Train Loss: 0.415880	Train Acc: 93.5714%	Val Loss: 1.601474	Val Acc: 65.8000%
Epoch  38	Train Loss: 0.422072	Train Acc: 93.5714%	Val Loss: 1.607498	Val Acc: 65.6000%
Epoch  39	Train Loss: 0.434265	Train Acc: 92.8571%	Val Loss: 1.617870	Val Acc: 65.8000%
Epoch  40	Train Loss: 0.374637	Train Acc: 92.8571%	Val Loss: 1.631233	Val Acc: 65.6000%
Epoch  41	Train Loss: 0.402360	Train Acc: 92.8571%	Val Loss: 1.645928	Val Acc: 66.0000%
Epoch  42	Train Loss: 0.465095	Train Acc: 93.5714%	Val Loss: 1.671271	Val Acc: 67.0000%
Epoch  43	Train Loss: 0.423041	Train Acc: 93.5714%	Val Loss: 1.719895	Val Acc: 67.0000%
Epoch  44	Train Loss: 0.720114	Train Acc: 94.2857%	Val Loss: 1.827938	Val Acc: 67.0000%
Epoch  45	Train Loss: 0.347155	Train Acc: 92.1429%	Val Loss: 1.936710	Val Acc: 66.6000%
Epoch  46	Train Loss: 0.475836	Train Acc: 91.4286%	Val Loss: 2.019416	Val Acc: 65.8000%
Epoch  47	Train Loss: 0.463268	Train Acc: 91.4286%	Val Loss: 2.064098	Val Acc: 65.4000%
Epoch  48	Train Loss: 0.397085	Train Acc: 90.7143%	Val Loss: 2.081776	Val Acc: 65.2000%
Epoch  49	Train Loss: 0.447446	Train Acc: 91.4286%	Val Loss: 2.064368	Val Acc: 65.4000%
Epoch  50	Train Loss: 0.476110	Train Acc: 93.5714%	Val Loss: 2.029231	Val Acc: 65.6000%
Epoch  51	Train Loss: 0.459895	Train Acc: 95.0000%	Val Loss: 1.977381	Val Acc: 66.2000%
Early Stopping at Epoch 51
Test Loss: 1.904060	Test Acc: 66.1000%
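
The run above uses the default configuration. Since `main` exposes the hyperparameters as keyword arguments, a different configuration is just a different call; the values below are purely illustrative:

julia
main(; hidden_dim=128, nb_layers=3, dropout=0.2, lr=5.0e-4, weight_decay=5.0e-4, epochs=100)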

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.