
Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. Cora is a citation network of 2,708 scientific papers, each described by a 1,433-dimensional bag-of-words feature vector and assigned to one of 7 classes. While we build the GCN by hand here, for real workloads we recommend using GNNLux.jl directly.

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We hardcode the train/val/test split indices since Reactant doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
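
Cora has 2,708 nodes with 1,433-dimensional features and 7 classes, and the hardcoded ranges correspond to the standard 140/500/1000 train/val/test split, so the loader can be sanity-checked as follows (a hypothetical check, not part of the training script):

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)  # (1433, 2708) -- feature dimension × number of nodes
size(targets)   # (7, 2708)    -- one-hot classes × number of nodes
size(adj)       # (2708, 2708) -- dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)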

Model Definition
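
For reference, the graph-convolution rule from Kipf & Welling is

$$H^{(l+1)} = \sigma\left(\hat{A}\, H^{(l)} W^{(l)}\right), \qquad \hat{A} = \tilde{D}^{-1/2} (A + I)\, \tilde{D}^{-1/2},$$

where $A$ is the adjacency matrix, $\tilde{D}$ is the degree matrix of $A + I$, and $\sigma$ is a nonlinearity (relu between the hidden layers below). The layers below implement a simplified version: with the column-major layout used here (features × nodes), dense(x) * adj is the transposed form of this product (up to the bias), and since loadcora returns the raw adjacency matrix, the self-loop and symmetric-normalization steps are skipped.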

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
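
Before wiring in the real dataset, a quick smoke test on random inputs can confirm the shapes line up (the sizes and values below are made up for illustration and are not part of the tutorial):

julia
rng = Random.default_rng()
x = rand(rng, Float32, 16, 10)                 # 16 features × 10 nodes
adj = Matrix{Int32}(rand(rng, Bool, 10, 10))   # dense 10×10 adjacency
mask = 1:4

model = GCN(16, 8, 3; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)
y, _ = model((x, adj, mask), ps, st)
size(y)  # (3, 4): out_dim × number of masked nodes

The output has one column per masked node, which is exactly what the loss and accuracy helpers below expect.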

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
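
As a toy illustration of the accuracy helper (the logits and targets below are made up), onecold maps both the raw logits and the one-hot targets back to class indices before comparing:

julia
y = onehotbatch([1, 2, 3, 1], 1:3)            # 3×4 one-hot targets
y_pred = Float32[3 0 0 0.5; 0 1 0 1; 0 0 2 0] # raw logits, 3 classes × 4 samples
accuracy(y_pred, y)                            # 75.0 -- three of four arg-maxes match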

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1761832949.954871 1835273 service.cc:158] XLA service 0x47501910 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1761832949.954919 1835273 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1761832949.956858 1835273 se_gpu_pjrt_client.cc:770] Using BFC allocator.
I0000 00:00:1761832949.956998 1835273 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1761832949.957069 1835273 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1761832949.979179 1835273 cuda_dnn.cc:463] Loaded cuDNN version 91400
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch   1	Train Loss: 15.485048	Train Acc: 22.1429%	Val Loss: 7.573927	Val Acc: 25.8000%
Epoch   2	Train Loss: 10.128677	Train Acc: 22.1429%	Val Loss: 3.801739	Val Acc: 29.4000%
Epoch   3	Train Loss: 4.470746	Train Acc: 37.8571%	Val Loss: 2.432641	Val Acc: 32.0000%
Epoch   4	Train Loss: 2.426969	Train Acc: 51.4286%	Val Loss: 2.113595	Val Acc: 37.8000%
Epoch   5	Train Loss: 1.762580	Train Acc: 59.2857%	Val Loss: 1.886148	Val Acc: 45.0000%
Epoch   6	Train Loss: 1.483075	Train Acc: 67.8571%	Val Loss: 1.608979	Val Acc: 51.8000%
Epoch   7	Train Loss: 1.267287	Train Acc: 71.4286%	Val Loss: 1.503945	Val Acc: 58.8000%
Epoch   8	Train Loss: 1.319186	Train Acc: 72.1429%	Val Loss: 1.505008	Val Acc: 60.2000%
Epoch   9	Train Loss: 1.627862	Train Acc: 72.8571%	Val Loss: 1.520744	Val Acc: 61.4000%
Epoch  10	Train Loss: 1.249889	Train Acc: 74.2857%	Val Loss: 1.519149	Val Acc: 62.0000%
Epoch  11	Train Loss: 1.186345	Train Acc: 78.5714%	Val Loss: 1.504691	Val Acc: 62.4000%
Epoch  12	Train Loss: 1.178998	Train Acc: 78.5714%	Val Loss: 1.548285	Val Acc: 61.4000%
Epoch  13	Train Loss: 0.900217	Train Acc: 79.2857%	Val Loss: 1.609031	Val Acc: 62.2000%
Epoch  14	Train Loss: 0.947796	Train Acc: 80.0000%	Val Loss: 1.651340	Val Acc: 62.0000%
Epoch  15	Train Loss: 1.408850	Train Acc: 80.7143%	Val Loss: 1.636881	Val Acc: 64.0000%
Epoch  16	Train Loss: 0.877241	Train Acc: 82.1429%	Val Loss: 1.619912	Val Acc: 66.2000%
Epoch  17	Train Loss: 0.810141	Train Acc: 81.4286%	Val Loss: 1.595838	Val Acc: 66.6000%
Epoch  18	Train Loss: 0.763295	Train Acc: 80.7143%	Val Loss: 1.572466	Val Acc: 67.8000%
Epoch  19	Train Loss: 0.878369	Train Acc: 82.1429%	Val Loss: 1.545329	Val Acc: 67.2000%
Epoch  20	Train Loss: 0.748767	Train Acc: 82.8571%	Val Loss: 1.521424	Val Acc: 66.6000%
Epoch  21	Train Loss: 0.683971	Train Acc: 82.8571%	Val Loss: 1.503487	Val Acc: 66.6000%
Epoch  22	Train Loss: 0.610947	Train Acc: 85.0000%	Val Loss: 1.498878	Val Acc: 66.0000%
Epoch  23	Train Loss: 0.603293	Train Acc: 85.0000%	Val Loss: 1.508867	Val Acc: 66.0000%
Epoch  24	Train Loss: 1.566445	Train Acc: 85.0000%	Val Loss: 1.545763	Val Acc: 66.2000%
Epoch  25	Train Loss: 0.565211	Train Acc: 87.1429%	Val Loss: 1.609411	Val Acc: 65.0000%
Epoch  26	Train Loss: 0.524949	Train Acc: 87.1429%	Val Loss: 1.686930	Val Acc: 64.2000%
Epoch  27	Train Loss: 0.506573	Train Acc: 88.5714%	Val Loss: 1.779495	Val Acc: 64.2000%
Epoch  28	Train Loss: 0.619784	Train Acc: 87.8571%	Val Loss: 1.844049	Val Acc: 63.2000%
Epoch  29	Train Loss: 0.578553	Train Acc: 87.8571%	Val Loss: 1.864435	Val Acc: 63.4000%
Epoch  30	Train Loss: 0.487108	Train Acc: 88.5714%	Val Loss: 1.866593	Val Acc: 64.0000%
Epoch  31	Train Loss: 0.490343	Train Acc: 89.2857%	Val Loss: 1.841034	Val Acc: 64.4000%
Epoch  32	Train Loss: 0.561964	Train Acc: 90.0000%	Val Loss: 1.792846	Val Acc: 66.2000%
Epoch  33	Train Loss: 0.489364	Train Acc: 91.4286%	Val Loss: 1.735448	Val Acc: 66.6000%
Epoch  34	Train Loss: 0.614714	Train Acc: 91.4286%	Val Loss: 1.695194	Val Acc: 66.4000%
Epoch  35	Train Loss: 0.439960	Train Acc: 92.8571%	Val Loss: 1.663030	Val Acc: 66.6000%
Epoch  36	Train Loss: 0.414342	Train Acc: 92.8571%	Val Loss: 1.645411	Val Acc: 67.4000%
Epoch  37	Train Loss: 0.395993	Train Acc: 93.5714%	Val Loss: 1.639418	Val Acc: 68.2000%
Epoch  38	Train Loss: 0.370129	Train Acc: 93.5714%	Val Loss: 1.643817	Val Acc: 68.2000%
Epoch  39	Train Loss: 0.405889	Train Acc: 93.5714%	Val Loss: 1.657985	Val Acc: 67.8000%
Epoch  40	Train Loss: 0.799954	Train Acc: 95.7143%	Val Loss: 1.680702	Val Acc: 67.6000%
Epoch  41	Train Loss: 0.378267	Train Acc: 95.7143%	Val Loss: 1.712448	Val Acc: 67.8000%
Epoch  42	Train Loss: 0.367016	Train Acc: 95.0000%	Val Loss: 1.747182	Val Acc: 68.4000%
Early Stopping at Epoch 42
Test Loss: 1.522911	Test Acc: 68.6000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.