
Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we build everything manually here, we recommend using GNNLux.jl directly.

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
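
As a side note (not part of the original example), the device objects returned by reactant_device and cpu_device are callable and move arrays between the host and the Reactant device:

julia
x = rand(Float32, 3, 3)
x_ra = xdev(x)     # move to the Reactant (XLA) device
x_cpu = cdev(x_ra) # move back to a plain CPU Array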

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
loadcora (generic function with 1 method)
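
As a quick sanity check (the shapes below are for the standard Cora dataset; this snippet is not part of the original example), the loader returns column-major node features, one-hot labels, and a dense adjacency matrix:

julia
features, targets, adj, idxs = loadcora()
size(features)  # (1433, 2708): feature dimension × number of nodes
size(targets)   # (7, 2708): one-hot labels for the 7 classes
size(adj)       # (2708, 2708): dense adjacency matrix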

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
GCN (generic function with 1 method)
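
Each GCNLayer applies a dense transform to the node features and then right-multiplies by the adjacency matrix; since features are stored column-wise, every node's output column is the sum of its neighbours' transformed features. Below is a minimal shape check on a tiny random graph (5 nodes, 1433-dimensional features, 7 classes); it is illustrative only and not part of the original example:

julia
rng = Random.default_rng()
model = GCN(1433, 64, 7; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 1433, 5)
adj = Matrix{Int32}(rand(Bool, 5, 5))  # arbitrary dense adjacency, for shape checking only
y, _ = model((x, adj, 1:2), ps, Lux.testmode(st))
size(y)  # (7, 2): logits for the two masked nodes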

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
accuracy (generic function with 1 method)
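
A toy check of the accuracy helper (values chosen purely for illustration): onecold picks the index of the largest entry in each column, so matching predictions and one-hot labels give 100%.

julia
y_pred = Float32[2.0 0.1; 0.5 1.5]  # logits, 2 classes × 2 samples
y = onehotbatch([1, 2], 1:2)        # true labels
accuracy(y_pred, y)                 # 100.0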

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end

main()
Precompiling EnzymeBFloat16sExt...
   7151.9 ms  ✓ Enzyme → EnzymeBFloat16sExt
  1 dependency successfully precompiled in 7 seconds. 47 already precompiled.
2025-07-09 04:32:38.704370: I external/xla/xla/service/service.cc:153] XLA service 0x2ccedef0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-09 04:32:38.704478: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752035558.705421 1227177 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752035558.705563 1227177 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752035558.705641 1227177 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752035558.718254 1227177 cuda_dnn.cc:471] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch   1	Train Loss: 14.544163	Train Acc: 20.0000%	Val Loss: 6.816375	Val Acc: 24.8000%
Epoch   2	Train Loss: 8.176273	Train Acc: 26.4286%	Val Loss: 2.794208	Val Acc: 28.6000%
Epoch   3	Train Loss: 3.054783	Train Acc: 38.5714%	Val Loss: 2.363533	Val Acc: 35.8000%
Epoch   4	Train Loss: 2.530659	Train Acc: 57.1429%	Val Loss: 1.836085	Val Acc: 42.8000%
Epoch   5	Train Loss: 1.810864	Train Acc: 64.2857%	Val Loss: 1.648743	Val Acc: 49.0000%
Epoch   6	Train Loss: 1.389204	Train Acc: 67.8571%	Val Loss: 1.555167	Val Acc: 56.6000%
Epoch   7	Train Loss: 1.305906	Train Acc: 69.2857%	Val Loss: 1.499859	Val Acc: 59.8000%
Epoch   8	Train Loss: 1.251253	Train Acc: 70.7143%	Val Loss: 1.492746	Val Acc: 60.0000%
Epoch   9	Train Loss: 1.804210	Train Acc: 76.4286%	Val Loss: 1.428393	Val Acc: 63.4000%
Epoch  10	Train Loss: 1.025554	Train Acc: 77.8571%	Val Loss: 1.423040	Val Acc: 63.4000%
Epoch  11	Train Loss: 2.099176	Train Acc: 77.8571%	Val Loss: 1.487065	Val Acc: 63.8000%
Epoch  12	Train Loss: 0.933550	Train Acc: 80.0000%	Val Loss: 1.584491	Val Acc: 62.0000%
Epoch  13	Train Loss: 0.949730	Train Acc: 80.0000%	Val Loss: 1.666309	Val Acc: 61.0000%
Epoch  14	Train Loss: 1.165631	Train Acc: 80.7143%	Val Loss: 1.731187	Val Acc: 61.2000%
Epoch  15	Train Loss: 0.927006	Train Acc: 80.7143%	Val Loss: 1.757825	Val Acc: 61.4000%
Epoch  16	Train Loss: 0.953845	Train Acc: 82.8571%	Val Loss: 1.756999	Val Acc: 62.2000%
Epoch  17	Train Loss: 0.983249	Train Acc: 82.8571%	Val Loss: 1.739863	Val Acc: 63.4000%
Epoch  18	Train Loss: 0.905488	Train Acc: 85.7143%	Val Loss: 1.707660	Val Acc: 64.2000%
Epoch  19	Train Loss: 0.984189	Train Acc: 85.7143%	Val Loss: 1.669651	Val Acc: 65.0000%
Epoch  20	Train Loss: 0.875392	Train Acc: 82.1429%	Val Loss: 1.634181	Val Acc: 66.8000%
Epoch  21	Train Loss: 0.798192	Train Acc: 83.5714%	Val Loss: 1.595893	Val Acc: 67.0000%
Epoch  22	Train Loss: 0.641376	Train Acc: 84.2857%	Val Loss: 1.561626	Val Acc: 67.0000%
Epoch  23	Train Loss: 0.560968	Train Acc: 85.0000%	Val Loss: 1.542006	Val Acc: 66.4000%
Epoch  24	Train Loss: 0.679021	Train Acc: 87.1429%	Val Loss: 1.536416	Val Acc: 66.0000%
Epoch  25	Train Loss: 0.524915	Train Acc: 87.8571%	Val Loss: 1.547432	Val Acc: 66.4000%
Epoch  26	Train Loss: 0.753763	Train Acc: 87.8571%	Val Loss: 1.569046	Val Acc: 65.4000%
Epoch  27	Train Loss: 0.476359	Train Acc: 87.8571%	Val Loss: 1.598529	Val Acc: 65.0000%
Epoch  28	Train Loss: 0.478846	Train Acc: 87.8571%	Val Loss: 1.621054	Val Acc: 64.4000%
Epoch  29	Train Loss: 0.527370	Train Acc: 90.0000%	Val Loss: 1.658599	Val Acc: 65.4000%
Epoch  30	Train Loss: 0.469310	Train Acc: 90.7143%	Val Loss: 1.683416	Val Acc: 66.2000%
Early Stopping at Epoch 30
Test Loss: 1.474484	Test Acc: 68.8000%
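
All hyperparameters above have defaults, so other configurations can be tried by passing keyword arguments to main. For instance (illustrative values, not from the run above), a wider model with weight decay enabled, which switches the optimiser to AdamW:

julia
main(; hidden_dim=128, dropout=0.5, lr=0.005, weight_decay=5.0e-4, epochs=300)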

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.