Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we build the model manually here, we recommend using GNNLux.jl directly.

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We hard-code the train/val/test index ranges since Reactant doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
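
If you want to check what `loadcora` returns, the following is a minimal sketch; the shapes quoted in the comments are the standard Cora statistics (2708 nodes, 1433 bag-of-words features, 7 classes), not something computed by the tutorial itself.

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)   # (1433, 2708): 1433 bag-of-words features per paper
size(targets)    # (7, 2708): one-hot labels over 7 classes
size(adj)        # (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000) nodes per split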

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end
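
The multiplication by `adj` is what makes this a graph convolution: each column of `dense(x)` holds one node's transformed features, and right-multiplying by the adjacency matrix sums those columns over each node's neighborhood. A small illustrative sketch with toy arrays (using the identity in place of `dense`):

julia
x = rand(Float32, 4, 3)              # 4 features for 3 nodes
adj = Float32[1 1 0; 1 1 1; 0 1 1]   # symmetric adjacency with self-loops
h = x                                # pretend `dense` is the identity
# Column 1 of `h * adj` is the sum of the features of node 1's neighborhood,
# i.e. node 1 itself plus node 2:
(h * adj)[:, 1] ≈ h[:, 1] .+ h[:, 2]  # true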

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
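
As a minimal sketch (dummy sizes, not the tutorial's training code), the model defined above takes a `(features, adjacency, mask)` tuple and returns logits only for the masked nodes:

julia
rng = Random.default_rng()
model = GCN(16, 8, 3; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 16, 10)            # 16 features for 10 nodes
adj = Float32.(rand(Bool, 10, 10))   # toy adjacency matrix
mask = 1:4                           # restrict the output to the first 4 nodes
y, st = model((x, adj, mask), ps, st)
size(y)                              # (3, 4): class logits for the masked nodes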

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
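
The loss closes over the mask so only the selected nodes contribute, and `accuracy` compares arg-max predictions against one-hot targets. A tiny made-up example of the latter:

julia
y_pred = Float32[0.9 0.2; 0.1 0.8]   # predicts class 1, then class 2
y = onehotbatch([1, 1], 1:2)         # true classes: 1 and 1
accuracy(y_pred, y)                  # 50.0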

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1758301474.460729 1195006 service.cc:158] XLA service 0x43f3d290 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758301474.460797 1195006 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1758301474.461718 1195006 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1758301474.461756 1195006 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1758301474.461800 1195006 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1758301474.472947 1195006 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-11/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-11/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch   1	Train Loss: 14.279928	Train Acc: 20.0000%	Val Loss: 7.631112	Val Acc: 22.4000%
Epoch   2	Train Loss: 9.303165	Train Acc: 23.5714%	Val Loss: 3.934845	Val Acc: 29.6000%
Epoch   3	Train Loss: 4.114150	Train Acc: 39.2857%	Val Loss: 2.177299	Val Acc: 35.4000%
Epoch   4	Train Loss: 2.119196	Train Acc: 53.5714%	Val Loss: 1.952470	Val Acc: 40.4000%
Epoch   5	Train Loss: 2.029362	Train Acc: 60.7143%	Val Loss: 1.825472	Val Acc: 45.8000%
Epoch   6	Train Loss: 1.618327	Train Acc: 63.5714%	Val Loss: 1.822781	Val Acc: 49.4000%
Epoch   7	Train Loss: 1.492998	Train Acc: 67.8571%	Val Loss: 1.757256	Val Acc: 54.4000%
Epoch   8	Train Loss: 1.486185	Train Acc: 72.8571%	Val Loss: 1.699004	Val Acc: 57.0000%
Epoch   9	Train Loss: 1.312258	Train Acc: 75.7143%	Val Loss: 1.622224	Val Acc: 59.8000%
Epoch  10	Train Loss: 1.150110	Train Acc: 76.4286%	Val Loss: 1.570092	Val Acc: 63.8000%
Epoch  11	Train Loss: 1.230214	Train Acc: 76.4286%	Val Loss: 1.536865	Val Acc: 65.0000%
Epoch  12	Train Loss: 1.196295	Train Acc: 77.8571%	Val Loss: 1.510300	Val Acc: 65.2000%
Epoch  13	Train Loss: 1.054720	Train Acc: 80.0000%	Val Loss: 1.496862	Val Acc: 65.2000%
Epoch  14	Train Loss: 0.952108	Train Acc: 81.4286%	Val Loss: 1.489314	Val Acc: 65.0000%
Epoch  15	Train Loss: 1.367762	Train Acc: 81.4286%	Val Loss: 1.479288	Val Acc: 65.6000%
Epoch  16	Train Loss: 1.090928	Train Acc: 83.5714%	Val Loss: 1.514651	Val Acc: 66.2000%
Epoch  17	Train Loss: 0.854440	Train Acc: 83.5714%	Val Loss: 1.593374	Val Acc: 65.8000%
Epoch  18	Train Loss: 0.711868	Train Acc: 81.4286%	Val Loss: 1.675784	Val Acc: 64.4000%
Epoch  19	Train Loss: 0.863184	Train Acc: 81.4286%	Val Loss: 1.719823	Val Acc: 64.2000%
Epoch  20	Train Loss: 0.930976	Train Acc: 82.1429%	Val Loss: 1.704442	Val Acc: 65.2000%
Epoch  21	Train Loss: 0.861661	Train Acc: 84.2857%	Val Loss: 1.646146	Val Acc: 65.6000%
Epoch  22	Train Loss: 0.722063	Train Acc: 86.4286%	Val Loss: 1.581254	Val Acc: 66.4000%
Epoch  23	Train Loss: 0.777586	Train Acc: 88.5714%	Val Loss: 1.521695	Val Acc: 67.0000%
Epoch  24	Train Loss: 0.665991	Train Acc: 89.2857%	Val Loss: 1.485319	Val Acc: 69.4000%
Epoch  25	Train Loss: 0.526629	Train Acc: 89.2857%	Val Loss: 1.469275	Val Acc: 69.8000%
Epoch  26	Train Loss: 0.627344	Train Acc: 88.5714%	Val Loss: 1.468050	Val Acc: 69.4000%
Epoch  27	Train Loss: 0.550234	Train Acc: 88.5714%	Val Loss: 1.478206	Val Acc: 69.0000%
Epoch  28	Train Loss: 0.468360	Train Acc: 88.5714%	Val Loss: 1.492301	Val Acc: 69.0000%
Epoch  29	Train Loss: 0.723738	Train Acc: 88.5714%	Val Loss: 1.506597	Val Acc: 68.6000%
Epoch  30	Train Loss: 0.563018	Train Acc: 90.0000%	Val Loss: 1.520076	Val Acc: 68.8000%
Epoch  31	Train Loss: 0.460821	Train Acc: 90.0000%	Val Loss: 1.539181	Val Acc: 68.4000%
Epoch  32	Train Loss: 0.493969	Train Acc: 90.7143%	Val Loss: 1.563661	Val Acc: 68.8000%
Epoch  33	Train Loss: 0.434047	Train Acc: 91.4286%	Val Loss: 1.592195	Val Acc: 68.0000%
Epoch  34	Train Loss: 0.489974	Train Acc: 92.1429%	Val Loss: 1.622337	Val Acc: 67.0000%
Epoch  35	Train Loss: 0.387495	Train Acc: 92.8571%	Val Loss: 1.655083	Val Acc: 65.8000%
Epoch  36	Train Loss: 0.411174	Train Acc: 92.8571%	Val Loss: 1.688826	Val Acc: 65.0000%
Epoch  37	Train Loss: 0.416674	Train Acc: 92.8571%	Val Loss: 1.724461	Val Acc: 65.0000%
Epoch  38	Train Loss: 0.401107	Train Acc: 92.8571%	Val Loss: 1.752678	Val Acc: 64.8000%
Epoch  39	Train Loss: 0.364621	Train Acc: 93.5714%	Val Loss: 1.772014	Val Acc: 64.8000%
Epoch  40	Train Loss: 0.363665	Train Acc: 94.2857%	Val Loss: 1.782372	Val Acc: 64.8000%
Epoch  41	Train Loss: 0.399835	Train Acc: 94.2857%	Val Loss: 1.781208	Val Acc: 64.8000%
Epoch  42	Train Loss: 0.360145	Train Acc: 94.2857%	Val Loss: 1.773606	Val Acc: 65.0000%
Epoch  43	Train Loss: 0.379509	Train Acc: 95.0000%	Val Loss: 1.764072	Val Acc: 65.6000%
Epoch  44	Train Loss: 0.329097	Train Acc: 95.0000%	Val Loss: 1.756897	Val Acc: 66.2000%
Epoch  45	Train Loss: 0.358217	Train Acc: 94.2857%	Val Loss: 1.754727	Val Acc: 66.6000%
Epoch  46	Train Loss: 0.352214	Train Acc: 94.2857%	Val Loss: 1.746168	Val Acc: 67.2000%
Early Stopping at Epoch 46
Test Loss: 1.545399	Test Acc: 68.3000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.