Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we implement the graph convolution layers manually here, we recommend using GNNLux.jl directly for real applications.
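
For comparison, the GNNLux.jl route replaces the hand-written layers below with built-in graph convolutions. The snippet below is only a rough sketch, assuming GNNLux provides `GCNConv` and `GNNChain` with the same interface as GraphNeuralNetworks.jl; check the GNNLux documentation for the exact API.

julia
# Rough sketch, not executed in this tutorial. Assumes GNNLux exposes
# `GCNConv`/`GNNChain` mirroring GraphNeuralNetworks.jl.
using GNNLux, GNNGraphs, Lux, Random

g = rand_graph(10, 40)                  # toy graph: 10 nodes, 40 edges
x = randn(Float32, 3, 10)               # 3 features per node

model = GNNChain(GCNConv(3 => 8, relu), GCNConv(8 => 4))
ps, st = Lux.setup(Random.default_rng(), model)
y, st = model(g, x, ps, st)             # output has size (4, 10)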

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
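
`xdev` transfers arrays (and nested parameter/state structures) onto the Reactant device, while `cdev` brings results back to the CPU for plain-Julia post-processing. A minimal illustration of the pattern used throughout this example:

julia
x = rand(Float32, 4, 4)
x_ra = xdev(x)      # Reactant-backed array, ready for @compile / @jit
x_cpu = cdev(x_ra)  # back to a regular Array on the CPU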

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense adjacency matrix here to avoid sparse-array incompatibilities with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # Fixed train/val/test index ranges (the standard Cora split), since Reactant
        # doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
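
Cora contains 2708 publications with 1433 bag-of-words features and 7 classes, and the index ranges above correspond to the standard 140/500/1000 train/validation/test split. A quick check of what `loadcora` returns:

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)                            # (1433, 2708): feature_dim × num_nodes
size(targets)                             # (7, 2708): one-hot classes × num_nodes
size(adj)                                 # (2708, 2708)
length.((train_idx, val_idx, test_idx))   # (140, 500, 1000)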

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
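
Each `GCNLayer` applies the dense transform and then aggregates over neighbours by right-multiplying with the adjacency matrix, so a layer computes `dense(x) * adj`. The full `GCN` stacks `nb_layers` such layers with `relu` and dropout, and the final layer's output is restricted to the requested node `mask`. A small shape check, using toy dimensions rather than the Cora ones:

julia
# Toy dimensions for illustration only: 5 features, 8 nodes, 3 classes.
model = GCN(5, 4, 3; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(Random.default_rng(), model)

x = randn(Float32, 5, 8)                # feature_dim × num_nodes
adj = Matrix{Int32}(rand(Bool, 8, 8))   # dense adjacency matrix
mask = 1:3                              # keep only the first three nodes

y, st = model((x, adj, mask), ps, st)
size(y)                                 # (3, 3): num_classes × length(mask)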

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
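
`loss_function` follows the `(model, ps, st, data) -> (loss, st, stats)` contract expected by `Lux.Training`, and `accuracy` compares the arg-max class of the logits against the one-hot targets. A tiny worked example of the accuracy computation:

julia
y_true = onehotbatch([1, 2, 3], 1:3)   # 3 samples, classes 1:3
y_pred = Float32[0.8 0.1 0.3;          # logits, one column per sample
                 0.1 0.7 0.2;
                 0.1 0.2 0.5]
accuracy(y_pred, y_true)               # 100.0: the arg-max matches the label in every column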

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760231437.476673   79531 service.cc:158] XLA service 0x4d9986c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760231437.476762   79531 service.cc:166]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1760231437.476768   79531 service.cc:166]   StreamExecutor device (1): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1760231437.482079   79531 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760231437.482127   79531 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 0 for BFCAllocator.
I0000 00:00:1760231437.482173   79531 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 1 for BFCAllocator.
I0000 00:00:1760231437.482190   79531 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760231437.482207   79531 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1760231437.493627   79531 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch   1	Train Loss: 17.697763	Train Acc: 20.7143%	Val Loss: 6.744421	Val Acc: 26.0000%
Epoch   2	Train Loss: 8.571198	Train Acc: 26.4286%	Val Loss: 2.712524	Val Acc: 31.4000%
Epoch   3	Train Loss: 2.910695	Train Acc: 40.0000%	Val Loss: 2.135307	Val Acc: 38.2000%
Epoch   4	Train Loss: 2.795473	Train Acc: 52.8571%	Val Loss: 1.837171	Val Acc: 46.2000%
Epoch   5	Train Loss: 2.066922	Train Acc: 60.7143%	Val Loss: 1.910931	Val Acc: 46.6000%
Epoch   6	Train Loss: 1.762443	Train Acc: 63.5714%	Val Loss: 1.863445	Val Acc: 48.4000%
Epoch   7	Train Loss: 1.493714	Train Acc: 68.5714%	Val Loss: 1.749854	Val Acc: 52.8000%
Epoch   8	Train Loss: 1.435782	Train Acc: 72.8571%	Val Loss: 1.621837	Val Acc: 59.4000%
Epoch   9	Train Loss: 1.241855	Train Acc: 75.0000%	Val Loss: 1.551770	Val Acc: 62.4000%
Epoch  10	Train Loss: 1.570425	Train Acc: 76.4286%	Val Loss: 1.507813	Val Acc: 63.6000%
Epoch  11	Train Loss: 1.020015	Train Acc: 79.2857%	Val Loss: 1.505572	Val Acc: 65.0000%
Epoch  12	Train Loss: 1.030396	Train Acc: 78.5714%	Val Loss: 1.522349	Val Acc: 64.2000%
Epoch  13	Train Loss: 0.959203	Train Acc: 77.8571%	Val Loss: 1.550502	Val Acc: 63.2000%
Epoch  14	Train Loss: 0.945190	Train Acc: 77.1429%	Val Loss: 1.575250	Val Acc: 62.8000%
Epoch  15	Train Loss: 0.889157	Train Acc: 77.8571%	Val Loss: 1.596992	Val Acc: 63.4000%
Epoch  16	Train Loss: 1.066568	Train Acc: 80.0000%	Val Loss: 1.596375	Val Acc: 63.2000%
Epoch  17	Train Loss: 1.299356	Train Acc: 85.0000%	Val Loss: 1.636672	Val Acc: 64.0000%
Epoch  18	Train Loss: 0.800040	Train Acc: 84.2857%	Val Loss: 1.749711	Val Acc: 64.4000%
Epoch  19	Train Loss: 0.738295	Train Acc: 82.1429%	Val Loss: 1.902346	Val Acc: 62.8000%
Epoch  20	Train Loss: 0.846399	Train Acc: 80.7143%	Val Loss: 2.023381	Val Acc: 62.0000%
Epoch  21	Train Loss: 1.027444	Train Acc: 80.7143%	Val Loss: 2.090489	Val Acc: 61.8000%
Epoch  22	Train Loss: 1.034440	Train Acc: 81.4286%	Val Loss: 2.086383	Val Acc: 62.0000%
Epoch  23	Train Loss: 0.932883	Train Acc: 84.2857%	Val Loss: 2.020211	Val Acc: 62.2000%
Epoch  24	Train Loss: 0.854405	Train Acc: 87.1429%	Val Loss: 1.926812	Val Acc: 62.8000%
Epoch  25	Train Loss: 0.823050	Train Acc: 87.8571%	Val Loss: 1.839584	Val Acc: 65.6000%
Epoch  26	Train Loss: 0.824092	Train Acc: 87.8571%	Val Loss: 1.759949	Val Acc: 65.4000%
Epoch  27	Train Loss: 0.726058	Train Acc: 87.8571%	Val Loss: 1.696039	Val Acc: 65.6000%
Epoch  28	Train Loss: 0.691226	Train Acc: 87.8571%	Val Loss: 1.646164	Val Acc: 65.8000%
Epoch  29	Train Loss: 0.567212	Train Acc: 87.8571%	Val Loss: 1.609809	Val Acc: 67.4000%
Epoch  30	Train Loss: 0.929916	Train Acc: 88.5714%	Val Loss: 1.586185	Val Acc: 67.2000%
Epoch  31	Train Loss: 0.443747	Train Acc: 88.5714%	Val Loss: 1.573454	Val Acc: 67.2000%
Early Stopping at Epoch 31
Test Loss: 1.459343	Test Acc: 70.7000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.