Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we build the graph convolution layers manually here, we recommend using GNNLux.jl directly for real workloads.
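
For comparison, the snippet below is only an illustrative sketch of what a single graph convolution looks like with GNNLux.jl. It assumes GNNLux exports a GCNConv layer following the usual explicit-parameter Lux calling convention (graph, features, parameters, states); consult the GNNLux.jl documentation for the actual API.

julia
using GNNLux, GNNGraphs, Lux, Random

rng = Random.default_rng()
g = rand_graph(10, 40)            # toy graph: 10 nodes, 40 edges
x = randn(rng, Float32, 3, 10)    # 3 input features per node
layer = GCNConv(3 => 5)           # graph convolution from GNNLux (assumed constructor)
ps, st = Lux.setup(rng, layer)
y, st = layer(g, x, ps, st)       # assumed call signature: (graph, features, ps, st)

The rest of this example implements the layers by hand. First, load the required packages: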

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # Hard-coded train/val/test split indices, used since Reactant doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
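
Cora contains 2,708 publications described by 1,433 bag-of-words features and split into 7 classes, so the returned arrays have the following shapes (a quick check, assuming the dataset is available locally via MLDatasets):

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)                            # (1433, 2708): one column of features per node
size(targets)                             # (7, 2708): one-hot encoded class labels
size(adj)                                 # (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))   # (140, 500, 1000)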

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
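
In this feature-major layout (features × nodes), dense(x) * adj multiplies the transformed node features by the adjacency matrix, so each column of the output aggregates the features of that node's neighbours. A minimal sanity check on a toy graph (the random adjacency matrix is purely illustrative):

julia
rng = Random.default_rng()
layer = GCNLayer(3 => 4)
ps, st = Lux.setup(rng, layer)
x = rand(Float32, 3, 5)                  # 3 features for each of 5 nodes
adj = Float32.(rand(rng, Bool, 5, 5))    # toy adjacency matrix
y, _ = layer((x, adj), ps, st)
size(y)                                  # (4, 5)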

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
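
For instance, with three classes and four hypothetical predictions, accuracy compares the arg-max of each prediction column against the one-hot targets:

julia
y_true = onehotbatch([1, 2, 3, 1], 1:3)
y_pred = Float32[0.9 0.1 0.2 0.3; 0.05 0.8 0.1 0.6; 0.05 0.1 0.7 0.1]
accuracy(y_pred, y_true)  # 75.0 (the last prediction is wrong)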

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    # Compile the validation loss once so it can be reused every epoch
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    # Compile inference-mode forward passes for the train and validation splits
    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    # Training loop with early stopping on the validation loss
    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    # Final evaluation on the held-out test split
    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757913805.102953 2638698 service.cc:163] XLA service 0x43d65de0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757913805.103028 2638698 service.cc:171]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1757913805.103929 2638698 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757913805.103970 2638698 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757913805.104005 2638698 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757913805.115381 2638698 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch   1	Train Loss: 16.541618	Train Acc: 22.1429%	Val Loss: 7.779559	Val Acc: 25.0000%
Epoch   2	Train Loss: 8.098966	Train Acc: 21.4286%	Val Loss: 4.113542	Val Acc: 28.6000%
Epoch   3	Train Loss: 4.958270	Train Acc: 42.1429%	Val Loss: 1.929405	Val Acc: 41.4000%
Epoch   4	Train Loss: 1.731582	Train Acc: 53.5714%	Val Loss: 2.011568	Val Acc: 41.8000%
Epoch   5	Train Loss: 2.024558	Train Acc: 59.2857%	Val Loss: 1.961110	Val Acc: 44.0000%
Epoch   6	Train Loss: 1.851948	Train Acc: 67.1429%	Val Loss: 1.736604	Val Acc: 50.2000%
Epoch   7	Train Loss: 1.185628	Train Acc: 73.5714%	Val Loss: 1.604445	Val Acc: 56.2000%
Epoch   8	Train Loss: 1.406230	Train Acc: 75.0000%	Val Loss: 1.536368	Val Acc: 61.4000%
Epoch   9	Train Loss: 1.216595	Train Acc: 76.4286%	Val Loss: 1.512406	Val Acc: 63.0000%
Epoch  10	Train Loss: 1.242570	Train Acc: 77.8571%	Val Loss: 1.486743	Val Acc: 62.8000%
Epoch  11	Train Loss: 1.076788	Train Acc: 80.7143%	Val Loss: 1.468861	Val Acc: 62.2000%
Epoch  12	Train Loss: 0.981122	Train Acc: 81.4286%	Val Loss: 1.454582	Val Acc: 64.0000%
Epoch  13	Train Loss: 0.919534	Train Acc: 81.4286%	Val Loss: 1.448876	Val Acc: 66.0000%
Epoch  14	Train Loss: 0.804165	Train Acc: 81.4286%	Val Loss: 1.454256	Val Acc: 65.0000%
Epoch  15	Train Loss: 0.942576	Train Acc: 82.8571%	Val Loss: 1.465008	Val Acc: 65.6000%
Epoch  16	Train Loss: 0.744699	Train Acc: 85.0000%	Val Loss: 1.512742	Val Acc: 65.2000%
Epoch  17	Train Loss: 0.901531	Train Acc: 85.7143%	Val Loss: 1.555849	Val Acc: 64.8000%
Epoch  18	Train Loss: 0.682303	Train Acc: 85.7143%	Val Loss: 1.594888	Val Acc: 64.8000%
Epoch  19	Train Loss: 0.716223	Train Acc: 86.4286%	Val Loss: 1.581796	Val Acc: 65.0000%
Epoch  20	Train Loss: 0.622499	Train Acc: 86.4286%	Val Loss: 1.550115	Val Acc: 65.6000%
Epoch  21	Train Loss: 0.668506	Train Acc: 87.1429%	Val Loss: 1.517670	Val Acc: 66.4000%
Epoch  22	Train Loss: 0.599883	Train Acc: 87.1429%	Val Loss: 1.483986	Val Acc: 67.6000%
Epoch  23	Train Loss: 0.642122	Train Acc: 87.1429%	Val Loss: 1.458639	Val Acc: 68.4000%
Epoch  24	Train Loss: 0.516070	Train Acc: 87.1429%	Val Loss: 1.436628	Val Acc: 68.8000%
Epoch  25	Train Loss: 0.632938	Train Acc: 86.4286%	Val Loss: 1.429646	Val Acc: 68.0000%
Epoch  26	Train Loss: 0.490367	Train Acc: 87.1429%	Val Loss: 1.429122	Val Acc: 68.6000%
Epoch  27	Train Loss: 0.557422	Train Acc: 87.8571%	Val Loss: 1.438560	Val Acc: 67.2000%
Epoch  28	Train Loss: 0.489264	Train Acc: 87.8571%	Val Loss: 1.455575	Val Acc: 66.8000%
Epoch  29	Train Loss: 0.464357	Train Acc: 87.8571%	Val Loss: 1.475377	Val Acc: 66.6000%
Epoch  30	Train Loss: 0.522491	Train Acc: 86.4286%	Val Loss: 1.503930	Val Acc: 65.8000%
Epoch  31	Train Loss: 0.449556	Train Acc: 86.4286%	Val Loss: 1.538563	Val Acc: 65.4000%
Epoch  32	Train Loss: 0.504742	Train Acc: 87.8571%	Val Loss: 1.551327	Val Acc: 65.2000%
Epoch  33	Train Loss: 0.574246	Train Acc: 92.1429%	Val Loss: 1.607988	Val Acc: 65.8000%
Epoch  34	Train Loss: 0.449091	Train Acc: 92.1429%	Val Loss: 1.711923	Val Acc: 65.4000%
Epoch  35	Train Loss: 0.363941	Train Acc: 90.7143%	Val Loss: 1.833322	Val Acc: 66.0000%
Epoch  36	Train Loss: 0.400379	Train Acc: 90.7143%	Val Loss: 1.938784	Val Acc: 64.6000%
Epoch  37	Train Loss: 0.529520	Train Acc: 90.0000%	Val Loss: 1.995430	Val Acc: 65.0000%
Epoch  38	Train Loss: 0.600717	Train Acc: 90.0000%	Val Loss: 1.997332	Val Acc: 65.6000%
Epoch  39	Train Loss: 0.599057	Train Acc: 92.1429%	Val Loss: 1.945721	Val Acc: 65.8000%
Epoch  40	Train Loss: 0.509541	Train Acc: 92.1429%	Val Loss: 1.866676	Val Acc: 66.4000%
Epoch  41	Train Loss: 0.447534	Train Acc: 92.8571%	Val Loss: 1.780149	Val Acc: 67.4000%
Epoch  42	Train Loss: 0.347354	Train Acc: 94.2857%	Val Loss: 1.694155	Val Acc: 68.4000%
Epoch  43	Train Loss: 0.399075	Train Acc: 95.0000%	Val Loss: 1.628220	Val Acc: 67.4000%
Epoch  44	Train Loss: 0.368840	Train Acc: 95.0000%	Val Loss: 1.579233	Val Acc: 67.2000%
Epoch  45	Train Loss: 0.365899	Train Acc: 95.0000%	Val Loss: 1.555665	Val Acc: 67.0000%
Epoch  46	Train Loss: 0.308438	Train Acc: 95.0000%	Val Loss: 1.548830	Val Acc: 67.8000%
Early Stopping at Epoch 46
Test Loss: 1.341541	Test Acc: 70.5000%
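
The defaults in main can be overridden through keyword arguments; for example, a hypothetical run with a smaller hidden dimension, an extra layer, and weight decay (not executed here):

julia
main(; hidden_dim=32, nb_layers=3, lr=5.0e-3, weight_decay=5.0e-4)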

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.