
Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we implement the model manually here, we recommend using GNNLux.jl directly instead.

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
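
For orientation, the standard Cora citation graph has 2,708 nodes, 1,433 bag-of-words features per node, and 7 classes, so the arrays returned above have the shapes sketched below (an illustrative snippet, assuming the MLDatasets Cora data used in `loadcora`):

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)  # (1433, 2708): feature-by-node matrix
size(targets)   # (7, 2708): one-hot labels per node
size(adj)       # (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)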

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
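
Note that `dense(x) * adj` propagates features with the raw adjacency matrix returned by `adjacency_matrix`. The original GCN formulation instead multiplies by a symmetrically normalized adjacency with self-loops, `D^{-1/2} (A + I) D^{-1/2}`. A minimal sketch of that normalization is shown below (a hypothetical helper that is not used elsewhere in this example; it requires LinearAlgebra):

julia
using LinearAlgebra

# Symmetrically normalize a dense adjacency matrix with added self-loops,
# as in the original GCN formulation: D^{-1/2} (A + I) D^{-1/2}.
function normalize_adjacency(A::AbstractMatrix)
    A_hat = A + I                       # add self-loops
    d = vec(sum(A_hat; dims=1))         # node degrees
    D_inv_sqrt = Diagonal(inv.(sqrt.(d)))
    return D_inv_sqrt * A_hat * D_inv_sqrt
end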

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
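
As a quick sanity check of these helpers, `onecold` inverts `onehotbatch`, so predictions that match the one-hot targets score 100% (an illustrative snippet with made-up labels):

julia
y = onehotbatch([1, 2, 3, 1], 1:3)   # 3×4 one-hot matrix
ŷ = Float32.(y)                      # pretend the model predicted the targets exactly
accuracy(ŷ, y)                       # 100.0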

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

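    # Compile the validation loss and the train/val forward passes ahead of time.
    # `with_config` sets the XLA precision used for `dot_general` and convolution
    # operations inside the compiled executables.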
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

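    # One training step per epoch on the full graph: `single_train_step!` computes
    # gradients with Enzyme (`AutoEnzyme()`), updates the parameters with the chosen
    # optimiser, and returns the loss; `return_gradients=Val(false)` avoids returning
    # the gradients from the call.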
    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
2025-08-05 23:25:53.732636: I external/xla/xla/service/service.cc:163] XLA service 0x13ad4c50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-05 23:25:53.732751: I external/xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754436353.733934  324692 se_gpu_pjrt_client.cc:1373] Using BFC allocator.
I0000 00:00:1754436353.734227  324692 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754436353.734352  324692 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-05 23:25:53.752855: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch   1	Train Loss: 16.434761	Train Acc: 20.0000%	Val Loss: 8.514094	Val Acc: 21.2000%
Epoch   2	Train Loss: 9.712398	Train Acc: 23.5714%	Val Loss: 4.395040	Val Acc: 29.4000%
Epoch   3	Train Loss: 4.724273	Train Acc: 37.8571%	Val Loss: 2.064650	Val Acc: 36.4000%
Epoch   4	Train Loss: 2.234418	Train Acc: 44.2857%	Val Loss: 2.274232	Val Acc: 39.2000%
Epoch   5	Train Loss: 2.302605	Train Acc: 47.1429%	Val Loss: 2.332959	Val Acc: 37.4000%
Epoch   6	Train Loss: 2.403902	Train Acc: 62.8571%	Val Loss: 1.965535	Val Acc: 45.4000%
Epoch   7	Train Loss: 1.610807	Train Acc: 71.4286%	Val Loss: 1.653510	Val Acc: 54.2000%
Epoch   8	Train Loss: 1.336507	Train Acc: 72.1429%	Val Loss: 1.524905	Val Acc: 61.4000%
Epoch   9	Train Loss: 1.174518	Train Acc: 74.2857%	Val Loss: 1.492613	Val Acc: 62.6000%
Epoch  10	Train Loss: 1.125156	Train Acc: 76.4286%	Val Loss: 1.480971	Val Acc: 64.4000%
Epoch  11	Train Loss: 1.222813	Train Acc: 77.8571%	Val Loss: 1.459094	Val Acc: 66.0000%
Epoch  12	Train Loss: 1.021737	Train Acc: 79.2857%	Val Loss: 1.445725	Val Acc: 65.8000%
Epoch  13	Train Loss: 0.973203	Train Acc: 79.2857%	Val Loss: 1.447371	Val Acc: 65.8000%
Epoch  14	Train Loss: 0.825098	Train Acc: 79.2857%	Val Loss: 1.460149	Val Acc: 66.4000%
Epoch  15	Train Loss: 0.740183	Train Acc: 85.0000%	Val Loss: 1.488243	Val Acc: 66.4000%
Epoch  16	Train Loss: 1.266764	Train Acc: 83.5714%	Val Loss: 1.495189	Val Acc: 66.4000%
Epoch  17	Train Loss: 2.369039	Train Acc: 85.0000%	Val Loss: 1.481928	Val Acc: 67.4000%
Epoch  18	Train Loss: 0.696506	Train Acc: 84.2857%	Val Loss: 1.483977	Val Acc: 66.0000%
Epoch  19	Train Loss: 0.660127	Train Acc: 85.0000%	Val Loss: 1.489409	Val Acc: 66.0000%
Epoch  20	Train Loss: 1.067086	Train Acc: 85.0000%	Val Loss: 1.526611	Val Acc: 66.6000%
Epoch  21	Train Loss: 0.606778	Train Acc: 85.0000%	Val Loss: 1.585117	Val Acc: 66.0000%
Epoch  22	Train Loss: 0.648352	Train Acc: 85.7143%	Val Loss: 1.648506	Val Acc: 65.6000%
Epoch  23	Train Loss: 0.743714	Train Acc: 86.4286%	Val Loss: 1.711921	Val Acc: 65.8000%
Epoch  24	Train Loss: 0.716126	Train Acc: 87.1429%	Val Loss: 1.748055	Val Acc: 65.8000%
Epoch  25	Train Loss: 0.749043	Train Acc: 87.1429%	Val Loss: 1.765788	Val Acc: 65.6000%
Epoch  26	Train Loss: 0.678848	Train Acc: 88.5714%	Val Loss: 1.757170	Val Acc: 67.0000%
Epoch  27	Train Loss: 0.847598	Train Acc: 89.2857%	Val Loss: 1.742121	Val Acc: 67.6000%
Epoch  28	Train Loss: 0.673782	Train Acc: 88.5714%	Val Loss: 1.710881	Val Acc: 67.4000%
Epoch  29	Train Loss: 0.633077	Train Acc: 90.0000%	Val Loss: 1.686995	Val Acc: 68.0000%
Epoch  30	Train Loss: 0.517133	Train Acc: 91.4286%	Val Loss: 1.662811	Val Acc: 67.4000%
Epoch  31	Train Loss: 0.574781	Train Acc: 90.0000%	Val Loss: 1.641438	Val Acc: 67.4000%
Epoch  32	Train Loss: 0.551590	Train Acc: 90.0000%	Val Loss: 1.624909	Val Acc: 67.8000%
Early Stopping at Epoch 32
Test Loss: 1.385764	Test Acc: 71.1000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.