
Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we implement the graph convolution manually here, we recommend using GNNLux.jl directly; a rough sketch of that route is shown below.
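
The following is an untested sketch, assuming GNNLux.jl exports a GCNConv layer with the GraphNeuralNetworks.jl-style constructor and the usual Lux `(graph, features, ps, st)` calling convention; consult the GNNLux.jl documentation for the exact API.

julia
using GNNLux, GNNGraphs, Lux, Random  # assumed packages; see the GNNLux.jl docs

rng = Random.default_rng()
g = rand_graph(10, 30)                # random graph with 10 nodes and 30 edges
x = rand(rng, Float32, 3, 10)         # 3 features per node

layer = GCNConv(3 => 5, relu)         # assumed constructor, mirroring GraphNeuralNetworks.jl
ps, st = Lux.setup(rng, layer)
y, st = layer(g, x, ps, st)           # assumed (graph, features, ps, st) calling convention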

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use plain index ranges for the train/val/test splits since Reactant
        # doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
loadcora (generic function with 1 method)
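
As a quick sanity check (not part of the original example), the loader can be called on its own and the shapes inspected; Cora has 2,708 nodes, 1,433 bag-of-words features, and 7 classes, so we expect:

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)   # (1433, 2708): one feature column per node
size(targets)    # (7, 2708): one-hot class labels
size(adj)        # (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000): standard Planetoid split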

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
GCN (generic function with 1 method)
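
Each GCNLayer computes dense(x) * adj: a linear transform of the node features followed by neighborhood aggregation through the (here unnormalized) adjacency matrix. A minimal shape-level smoke test on random data, assuming the definitions above (names and sizes are illustrative), might look like:

julia
rng = Random.default_rng()

model = GCN(8, 16, 3)                       # 8 input features, 16 hidden, 3 classes
ps, st = Lux.setup(rng, model)

x   = randn(rng, Float32, 8, 5)             # 5 nodes with 8 features each
adj = Int32[0 1 0 0 1; 1 0 1 0 0; 0 1 0 1 0; 0 0 1 0 1; 1 0 0 1 0]  # ring graph on 5 nodes
mask = 1:3                                  # score only the first three nodes

y, _ = model((x, adj, mask), ps, Lux.testmode(st))
size(y)  # (3, 3): 3 class logits for the 3 masked nodes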

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
accuracy (generic function with 1 method)
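
The helpers can be exercised on a toy batch; the values below are illustrative and assume the definitions above.

julia
y_true = onehotbatch([1, 2, 3, 1], 1:3)     # 3 classes, 4 samples
y_pred = Float32[0.9  0.1 0.2 0.3;          # logits: columns are samples
                 0.05 0.8 0.1 0.6;
                 0.05 0.1 0.7 0.1]

accuracy(y_pred, y_true)                    # 75.0: the last sample is misclassified

# loss_function wraps CrossEntropyLoss with logit inputs; the same loss can be
# computed directly on a masked set of columns, e.g.
CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred[:, 1:3], y_true[:, 1:3])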

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
Precompiling EnzymeBFloat16sExt...
   6518.8 ms  ✓ Enzyme → EnzymeBFloat16sExt
  1 dependency successfully precompiled in 7 seconds. 47 already precompiled.
2025-07-14 00:06:50.243037: I external/xla/xla/service/service.cc:153] XLA service 0x4934d7a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-14 00:06:50.243065: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752451610.243902 2725205 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752451610.243982 2725205 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752451610.244031 2725205 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752451610.255571 2725205 cuda_dnn.cc:471] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch   1	Train Loss: 14.796565	Train Acc: 20.0000%	Val Loss: 7.816453	Val Acc: 22.0000%
Epoch   2	Train Loss: 9.206785	Train Acc: 24.2857%	Val Loss: 3.503649	Val Acc: 27.8000%
Epoch   3	Train Loss: 4.558276	Train Acc: 45.0000%	Val Loss: 1.994527	Val Acc: 42.2000%
Epoch   4	Train Loss: 2.158406	Train Acc: 50.7143%	Val Loss: 2.134480	Val Acc: 43.2000%
Epoch   5	Train Loss: 1.957148	Train Acc: 56.4286%	Val Loss: 2.066251	Val Acc: 44.8000%
Epoch   6	Train Loss: 2.038639	Train Acc: 63.5714%	Val Loss: 1.778163	Val Acc: 52.2000%
Epoch   7	Train Loss: 1.541075	Train Acc: 71.4286%	Val Loss: 1.519329	Val Acc: 59.4000%
Epoch   8	Train Loss: 1.301353	Train Acc: 72.8571%	Val Loss: 1.470120	Val Acc: 62.2000%
Epoch   9	Train Loss: 1.274294	Train Acc: 75.0000%	Val Loss: 1.504793	Val Acc: 63.0000%
Epoch  10	Train Loss: 1.211376	Train Acc: 75.0000%	Val Loss: 1.541491	Val Acc: 64.4000%
Epoch  11	Train Loss: 1.034225	Train Acc: 77.8571%	Val Loss: 1.557657	Val Acc: 64.4000%
Epoch  12	Train Loss: 1.256145	Train Acc: 79.2857%	Val Loss: 1.554844	Val Acc: 63.8000%
Epoch  13	Train Loss: 1.013252	Train Acc: 80.0000%	Val Loss: 1.549081	Val Acc: 64.4000%
Epoch  14	Train Loss: 0.931371	Train Acc: 81.4286%	Val Loss: 1.545484	Val Acc: 64.4000%
Epoch  15	Train Loss: 0.912301	Train Acc: 82.8571%	Val Loss: 1.550539	Val Acc: 64.6000%
Epoch  16	Train Loss: 0.822234	Train Acc: 84.2857%	Val Loss: 1.574267	Val Acc: 64.8000%
Epoch  17	Train Loss: 1.230084	Train Acc: 85.0000%	Val Loss: 1.617842	Val Acc: 64.6000%
Epoch  18	Train Loss: 0.802377	Train Acc: 85.7143%	Val Loss: 1.697485	Val Acc: 63.0000%
Epoch  19	Train Loss: 0.604222	Train Acc: 84.2857%	Val Loss: 1.782103	Val Acc: 63.4000%
Epoch  20	Train Loss: 0.863587	Train Acc: 84.2857%	Val Loss: 1.810278	Val Acc: 63.8000%
Epoch  21	Train Loss: 0.775717	Train Acc: 85.0000%	Val Loss: 1.755403	Val Acc: 64.4000%
Epoch  22	Train Loss: 0.668179	Train Acc: 87.1429%	Val Loss: 1.671531	Val Acc: 65.6000%
Epoch  23	Train Loss: 0.579810	Train Acc: 88.5714%	Val Loss: 1.615436	Val Acc: 66.0000%
Epoch  24	Train Loss: 0.529690	Train Acc: 88.5714%	Val Loss: 1.590291	Val Acc: 66.8000%
Epoch  25	Train Loss: 0.519709	Train Acc: 88.5714%	Val Loss: 1.583071	Val Acc: 66.6000%
Epoch  26	Train Loss: 0.504758	Train Acc: 88.5714%	Val Loss: 1.582281	Val Acc: 66.4000%
Epoch  27	Train Loss: 0.529384	Train Acc: 89.2857%	Val Loss: 1.590776	Val Acc: 67.2000%
Epoch  28	Train Loss: 0.492651	Train Acc: 89.2857%	Val Loss: 1.604312	Val Acc: 67.4000%
Early Stopping at Epoch 28
Test Loss: 1.365980	Test Acc: 72.2000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.