Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we implement the model manually here, we recommend using GNNLux.jl directly.

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
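
The function returns the node features, one-hot targets, a dense adjacency matrix, and the standard train/validation/test index ranges. As a quick sanity check, the expected shapes below follow the standard Cora statistics (2708 nodes, 1433 bag-of-words features, 7 classes); those counts are assumptions for illustration, not values computed above.

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)  # expected (1433, 2708): one feature column per node
size(targets)   # expected (7, 2708): one-hot class labels
size(adj)       # expected (2708, 2708): dense adjacency matrix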

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
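        # Transform each node's feature column with the dense layer, then aggregate
        # over neighbours by right-multiplying with the adjacency matrix: (W * x) * A.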
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
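
As a minimal smoke test of this API (the sizes below are made-up toy values, not part of the original example), the model can be constructed and run on random data:

julia
rng = Random.default_rng()
model = GCN(16, 8, 3; nb_layers=2, dropout=0.5)  # hypothetical toy dimensions
ps, st = Lux.setup(rng, model)

x = randn(Float32, 16, 5)        # 5 nodes, 16 features each (features are columns)
adj = Int32.(rand(Bool, 5, 5))   # dense adjacency, matching the format from `loadcora`
mask = 1:3                       # keep predictions only for the first 3 nodes

y, _ = model((x, adj, mask), ps, Lux.testmode(st))
size(y)  # (3, 3): one logit column per masked node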

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
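
`onecold` maps each column of logits (or one-hot targets) back to a class index, so `accuracy` is simply the percentage of columns whose argmax matches. A tiny made-up illustration:

julia
y_true = onehotbatch([1, 2, 2], 1:3)
y_logits = [2.0 0.1 0.3;
            0.5 1.2 0.9;
            0.1 0.4 1.5]
accuracy(y_logits, y_true)  # 2 of the 3 columns agree, so ≈ 66.67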

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

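    # Ahead-of-time compile the loss and forward passes with Reactant.
    # `PrecisionConfig.HIGH` asks XLA to keep matrix multiplications and convolutions
    # at full precision rather than a faster, lower-precision default (e.g. TF32).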
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

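    # Standard early stopping: reset the counter whenever the validation loss improves,
    # otherwise stop after `patience` epochs without improvement.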
    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757734448.307756  570007 service.cc:163] XLA service 0x2e27e540 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757734448.307937  570007 service.cc:171]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1757734448.308014  570007 service.cc:171]   StreamExecutor device (1): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1757734448.314613  570007 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757734448.314681  570007 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 0 for BFCAllocator.
I0000 00:00:1757734448.314744  570007 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 1 for BFCAllocator.
I0000 00:00:1757734448.314767  570007 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757734448.314792  570007 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1757734448.327068  570007 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch   1	Train Loss: 16.336132	Train Acc: 22.1429%	Val Loss: 7.009547	Val Acc: 22.8000%
Epoch   2	Train Loss: 8.029594	Train Acc: 20.7143%	Val Loss: 3.027081	Val Acc: 29.2000%
Epoch   3	Train Loss: 4.390297	Train Acc: 42.1429%	Val Loss: 1.811692	Val Acc: 40.4000%
Epoch   4	Train Loss: 1.927010	Train Acc: 55.7143%	Val Loss: 1.848529	Val Acc: 43.4000%
Epoch   5	Train Loss: 1.747656	Train Acc: 63.5714%	Val Loss: 1.859883	Val Acc: 43.8000%
Epoch   6	Train Loss: 1.636206	Train Acc: 70.0000%	Val Loss: 1.754520	Val Acc: 51.0000%
Epoch   7	Train Loss: 1.610943	Train Acc: 72.8571%	Val Loss: 1.633135	Val Acc: 56.4000%
Epoch   8	Train Loss: 1.470763	Train Acc: 77.1429%	Val Loss: 1.545484	Val Acc: 60.2000%
Epoch   9	Train Loss: 1.298347	Train Acc: 77.1429%	Val Loss: 1.483948	Val Acc: 62.8000%
Epoch  10	Train Loss: 1.224091	Train Acc: 80.7143%	Val Loss: 1.435608	Val Acc: 64.6000%
Epoch  11	Train Loss: 1.041614	Train Acc: 80.0000%	Val Loss: 1.408752	Val Acc: 65.8000%
Epoch  12	Train Loss: 1.044359	Train Acc: 80.7143%	Val Loss: 1.401051	Val Acc: 66.2000%
Epoch  13	Train Loss: 1.014527	Train Acc: 80.7143%	Val Loss: 1.405285	Val Acc: 66.2000%
Epoch  14	Train Loss: 0.819150	Train Acc: 81.4286%	Val Loss: 1.417203	Val Acc: 66.0000%
Epoch  15	Train Loss: 0.876931	Train Acc: 84.2857%	Val Loss: 1.417138	Val Acc: 66.8000%
Epoch  16	Train Loss: 1.664872	Train Acc: 85.0000%	Val Loss: 1.410902	Val Acc: 67.6000%
Epoch  17	Train Loss: 0.667342	Train Acc: 85.7143%	Val Loss: 1.458553	Val Acc: 67.2000%
Epoch  18	Train Loss: 0.756876	Train Acc: 84.2857%	Val Loss: 1.518630	Val Acc: 67.0000%
Epoch  19	Train Loss: 0.867558	Train Acc: 85.0000%	Val Loss: 1.572213	Val Acc: 66.8000%
Epoch  20	Train Loss: 0.751877	Train Acc: 85.7143%	Val Loss: 1.611869	Val Acc: 66.2000%
Epoch  21	Train Loss: 0.624684	Train Acc: 85.7143%	Val Loss: 1.631795	Val Acc: 66.6000%
Epoch  22	Train Loss: 0.981921	Train Acc: 87.1429%	Val Loss: 1.626106	Val Acc: 66.6000%
Epoch  23	Train Loss: 0.795467	Train Acc: 88.5714%	Val Loss: 1.606040	Val Acc: 66.0000%
Epoch  24	Train Loss: 0.898357	Train Acc: 88.5714%	Val Loss: 1.573473	Val Acc: 66.8000%
Epoch  25	Train Loss: 0.698883	Train Acc: 88.5714%	Val Loss: 1.541680	Val Acc: 68.0000%
Epoch  26	Train Loss: 0.645220	Train Acc: 88.5714%	Val Loss: 1.514167	Val Acc: 68.6000%
Epoch  27	Train Loss: 0.613002	Train Acc: 89.2857%	Val Loss: 1.493463	Val Acc: 68.4000%
Epoch  28	Train Loss: 0.565419	Train Acc: 88.5714%	Val Loss: 1.479156	Val Acc: 68.8000%
Epoch  29	Train Loss: 0.491974	Train Acc: 89.2857%	Val Loss: 1.470922	Val Acc: 67.8000%
Epoch  30	Train Loss: 0.447017	Train Acc: 90.7143%	Val Loss: 1.466520	Val Acc: 67.8000%
Epoch  31	Train Loss: 0.481589	Train Acc: 90.0000%	Val Loss: 1.467708	Val Acc: 68.2000%
Epoch  32	Train Loss: 0.466854	Train Acc: 90.7143%	Val Loss: 1.474037	Val Acc: 67.6000%
Early Stopping at Epoch 32
Test Loss: 1.318466	Test Acc: 70.3000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.