
Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we build the model manually here, for real workloads we recommend using GNNLux.jl directly.
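
In matrix form, a single graph-convolution step of this model computes roughly H' = σ(W H A), where H holds the node features column-wise, W is a dense layer's weight matrix, and A is the adjacency matrix. Note that, as a simplification, the code below uses the raw (unnormalized) adjacency matrix rather than the symmetrically normalized propagation rule from the original GCN paper.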

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
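
As a quick sanity check (not part of the original example), the returned shapes should match the standard Cora statistics: 2708 nodes, 1433 bag-of-words features, 7 classes, and a 140/500/1000 train/val/test split.

julia
# Hypothetical sanity check of the shapes returned by loadcora() (CPU, standard Cora split)
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
@assert size(features) == (1433, 2708)  # bag-of-words features × nodes
@assert size(targets) == (7, 2708)      # one-hot classes × nodes
@assert size(adj) == (2708, 2708)       # dense adjacency matrix
@assert (length(train_idx), length(val_idx), length(test_idx)) == (140, 500, 1000)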

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
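
Each GCNLayer applies a dense transform to every node's feature column and then aggregates over the graph by right-multiplying with the adjacency matrix. Below is a minimal usage sketch for a single layer on random data (the sizes are made up and are not part of the tutorial).

julia
# Minimal sketch: apply one GCNLayer to a tiny random graph (hypothetical sizes)
rng = Random.default_rng()
layer = GCNLayer(3 => 4)
ps, st = Lux.setup(rng, layer)
x = rand(Float32, 3, 5)                        # 3 features for 5 nodes (columns are nodes)
adj = Float32[i != j for i in 1:5, j in 1:5]   # dense adjacency of a fully connected graph
y, _ = layer((x, adj), ps, st)
@assert size(y) == (4, 5)                      # transformed features, aggregated over neighbours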

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
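
The accuracy helper relies on onecold from OneHotArrays, which maps each column of predictions (or one-hot targets) back to a class index. A small illustrative example with made-up scores, assuming 3 classes and 4 samples:

julia
# Illustrative only: 3 of 4 predictions are correct, so accuracy is 75%
y_true = onehotbatch([1, 2, 3, 1], 1:3)
y_pred = Float32[0.90 0.10 0.20 0.10;   # scores, columns are samples
                 0.05 0.80 0.10 0.70;
                 0.05 0.10 0.70 0.20]
accuracy(y_pred, y_true)  # 75.0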

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1759714263.771802 1329998 service.cc:158] XLA service 0x4aa89970 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1759714263.772026 1329998 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1759714263.774609 1329998 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1759714263.774723 1329998 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1759714263.775002 1329998 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1759714263.795361 1329998 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch   1	Train Loss: 16.121910	Train Acc: 22.1429%	Val Loss: 6.888684	Val Acc: 25.4000%
Epoch   2	Train Loss: 9.096497	Train Acc: 25.7143%	Val Loss: 2.945881	Val Acc: 30.4000%
Epoch   3	Train Loss: 3.630679	Train Acc: 43.5714%	Val Loss: 1.899566	Val Acc: 44.0000%
Epoch   4	Train Loss: 2.384638	Train Acc: 57.8571%	Val Loss: 1.668179	Val Acc: 47.8000%
Epoch   5	Train Loss: 1.543190	Train Acc: 62.8571%	Val Loss: 1.793227	Val Acc: 46.8000%
Epoch   6	Train Loss: 1.546407	Train Acc: 68.5714%	Val Loss: 1.688540	Val Acc: 51.4000%
Epoch   7	Train Loss: 1.457959	Train Acc: 75.0000%	Val Loss: 1.544119	Val Acc: 58.4000%
Epoch   8	Train Loss: 1.166498	Train Acc: 75.7143%	Val Loss: 1.479228	Val Acc: 62.4000%
Epoch   9	Train Loss: 1.066184	Train Acc: 79.2857%	Val Loss: 1.436001	Val Acc: 64.4000%
Epoch  10	Train Loss: 1.243302	Train Acc: 77.8571%	Val Loss: 1.426260	Val Acc: 64.6000%
Epoch  11	Train Loss: 1.082464	Train Acc: 79.2857%	Val Loss: 1.426940	Val Acc: 65.4000%
Epoch  12	Train Loss: 1.000281	Train Acc: 80.0000%	Val Loss: 1.435458	Val Acc: 66.8000%
Epoch  13	Train Loss: 1.358913	Train Acc: 78.5714%	Val Loss: 1.467632	Val Acc: 65.8000%
Epoch  14	Train Loss: 0.729061	Train Acc: 82.1429%	Val Loss: 1.528814	Val Acc: 64.6000%
Epoch  15	Train Loss: 0.828481	Train Acc: 81.4286%	Val Loss: 1.617502	Val Acc: 63.2000%
Epoch  16	Train Loss: 0.845278	Train Acc: 80.7143%	Val Loss: 1.706484	Val Acc: 62.2000%
Epoch  17	Train Loss: 0.812649	Train Acc: 81.4286%	Val Loss: 1.754158	Val Acc: 62.2000%
Epoch  18	Train Loss: 1.037319	Train Acc: 82.8571%	Val Loss: 1.702474	Val Acc: 63.4000%
Epoch  19	Train Loss: 0.788434	Train Acc: 84.2857%	Val Loss: 1.641294	Val Acc: 64.8000%
Epoch  20	Train Loss: 0.692212	Train Acc: 86.4286%	Val Loss: 1.602820	Val Acc: 65.2000%
Epoch  21	Train Loss: 0.623656	Train Acc: 86.4286%	Val Loss: 1.591236	Val Acc: 66.4000%
Epoch  22	Train Loss: 0.664444	Train Acc: 85.7143%	Val Loss: 1.612672	Val Acc: 65.6000%
Epoch  23	Train Loss: 0.595455	Train Acc: 84.2857%	Val Loss: 1.626493	Val Acc: 66.0000%
Epoch  24	Train Loss: 0.655217	Train Acc: 85.7143%	Val Loss: 1.620098	Val Acc: 66.4000%
Epoch  25	Train Loss: 0.576780	Train Acc: 85.7143%	Val Loss: 1.594717	Val Acc: 67.0000%
Epoch  26	Train Loss: 0.775038	Train Acc: 88.5714%	Val Loss: 1.557361	Val Acc: 67.4000%
Epoch  27	Train Loss: 0.573833	Train Acc: 87.8571%	Val Loss: 1.526408	Val Acc: 68.4000%
Epoch  28	Train Loss: 0.492148	Train Acc: 88.5714%	Val Loss: 1.504853	Val Acc: 68.2000%
Epoch  29	Train Loss: 0.607460	Train Acc: 89.2857%	Val Loss: 1.495968	Val Acc: 68.2000%
Epoch  30	Train Loss: 0.487236	Train Acc: 89.2857%	Val Loss: 1.498273	Val Acc: 68.2000%
Early Stopping at Epoch 30
Test Loss: 1.335748	Test Acc: 69.0000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.