Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we build the model manually here, we recommend using GNNLux.jl directly.
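
For comparison, here is a minimal sketch of what the same model might look like in GNNLux.jl. This is an assumption-laden illustration, not code from this tutorial: it presumes GNNLux exports Lux-compatible GNNChain/GCNConv layers, so consult the GNNLux.jl documentation for the exact constructors and call convention.

julia
# Hypothetical GNNLux.jl equivalent (not executed in this tutorial).
using GNNLux, Lux

gnn = GNNChain(
    GCNConv(1433 => 64, relu),  # Cora has 1433 bag-of-words features
    Dropout(0.5),
    GCNConv(64 => 7),           # and 7 classes
)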

julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Loading Cora Dataset

julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We return the train/val/test splits as fixed index ranges, since Reactant doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
loadcora (generic function with 1 method)
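
For reference, Cora has 2708 nodes, 1433 bag-of-words features, and 7 classes, with the standard 140/500/1000 train/val/test split, so the returned arrays should have the following shapes:

julia
features, targets, adj, splits = loadcora()
size(features)  # (1433, 2708), feature-major layout
size(targets)   # (7, 2708) one-hot labels
size(adj)       # (2708, 2708) dense adjacency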

Model Definition

julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
GCN (generic function with 1 method)
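
Each layer computes dense(x) * adj, i.e. W * X * A in the feature-major layout used here. The original GCN of Kipf and Welling multiplies by the symmetrically normalized adjacency Â = D̃^(-1/2) (A + I) D̃^(-1/2) rather than the raw one; if you want that behaviour, a helper along the following lines (a sketch, not part of this example) can be applied to adj once before training:

julia
using LinearAlgebra

# Sketch: symmetric GCN normalization Â = D̃^(-1/2) (A + I) D̃^(-1/2),
# where D̃ is the degree matrix after adding self-loops.
function normalize_adjacency(adj::AbstractMatrix)
    ã = adj + I                                      # add self-loops
    d_inv_sqrt = Diagonal(inv.(sqrt.(vec(sum(ã; dims=2)))))
    return d_inv_sqrt * ã * d_inv_sqrt
end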

Helper Functions

julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
accuracy (generic function with 1 method)
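
As a quick sanity check, accuracy takes raw class scores and one-hot targets and reports the percentage of columns whose argmax matches. A toy example (hypothetical values):

julia
y = onehotbatch([1, 2, 3, 1], 1:3)   # 3×4 one-hot targets
y_pred = Float32[
    0.9  0.1  0.2  0.3
    0.05 0.8  0.1  0.6
    0.05 0.1  0.7  0.1
]
accuracy(y_pred, y)                  # 75.0 (the last column is wrong)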

Training the Model

julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

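    # Compile the validation loss and the train/val forward passes once with
    # Reactant; the compiled callables are reused on every epoch below.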
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

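    # Early stopping bookkeeping: track the best validation loss and stop once
    # it has not improved for `patience` consecutive epochs.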
    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end
    return nothing
end
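
# All hyperparameters are exposed as keyword arguments, so variants are a
# single call away, e.g. (illustrative, not executed here):
#   main(; hidden_dim=128, nb_layers=3, weight_decay=5e-4)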

main()
Precompiling BFloat16sExt...
   1259.3 ms  ✓ LLVM → BFloat16sExt
  1 dependency successfully precompiled in 1 seconds. 30 already precompiled.
Precompiling EnzymeBFloat16sExt...
   6487.8 ms  ✓ Enzyme → EnzymeBFloat16sExt
  1 dependency successfully precompiled in 7 seconds. 47 already precompiled.
2025-07-09 04:15:43.185369: I external/xla/xla/service/service.cc:153] XLA service 0x3d2a0530 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-09 04:15:43.185484: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752034543.186674 1143982 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752034543.186799 1143982 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752034543.186873 1143982 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752034543.201715 1143982 cuda_dnn.cc:471] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-4/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-4/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch   1	Train Loss: 16.066574	Train Acc: 22.1429%	Val Loss: 6.879176	Val Acc: 24.8000%
Epoch   2	Train Loss: 7.957355	Train Acc: 29.2857%	Val Loss: 2.597338	Val Acc: 33.2000%
Epoch   3	Train Loss: 4.583193	Train Acc: 50.0000%	Val Loss: 1.744905	Val Acc: 46.6000%
Epoch   4	Train Loss: 1.908341	Train Acc: 56.4286%	Val Loss: 1.791154	Val Acc: 46.4000%
Epoch   5	Train Loss: 1.673450	Train Acc: 64.2857%	Val Loss: 1.683114	Val Acc: 50.8000%
Epoch   6	Train Loss: 1.478799	Train Acc: 73.5714%	Val Loss: 1.502969	Val Acc: 57.6000%
Epoch   7	Train Loss: 1.157417	Train Acc: 75.0000%	Val Loss: 1.401800	Val Acc: 62.0000%
Epoch   8	Train Loss: 1.141568	Train Acc: 78.5714%	Val Loss: 1.397459	Val Acc: 61.4000%
Epoch   9	Train Loss: 1.708609	Train Acc: 79.2857%	Val Loss: 1.401590	Val Acc: 62.2000%
Epoch  10	Train Loss: 1.051392	Train Acc: 82.1429%	Val Loss: 1.415236	Val Acc: 64.8000%
Epoch  11	Train Loss: 0.994985	Train Acc: 83.5714%	Val Loss: 1.430862	Val Acc: 64.8000%
Epoch  12	Train Loss: 0.982618	Train Acc: 83.5714%	Val Loss: 1.442781	Val Acc: 65.8000%
Epoch  13	Train Loss: 0.822561	Train Acc: 85.0000%	Val Loss: 1.447394	Val Acc: 66.4000%
Epoch  14	Train Loss: 0.920512	Train Acc: 83.5714%	Val Loss: 1.458303	Val Acc: 67.6000%
Epoch  15	Train Loss: 0.812066	Train Acc: 84.2857%	Val Loss: 1.465748	Val Acc: 67.8000%
Epoch  16	Train Loss: 0.761003	Train Acc: 84.2857%	Val Loss: 1.458907	Val Acc: 67.6000%
Epoch  17	Train Loss: 0.758469	Train Acc: 87.1429%	Val Loss: 1.450104	Val Acc: 67.8000%
Epoch  18	Train Loss: 0.774662	Train Acc: 87.1429%	Val Loss: 1.424158	Val Acc: 67.8000%
Epoch  19	Train Loss: 0.718235	Train Acc: 87.8571%	Val Loss: 1.400826	Val Acc: 67.8000%
Epoch  20	Train Loss: 0.575362	Train Acc: 88.5714%	Val Loss: 1.392055	Val Acc: 68.4000%
Epoch  21	Train Loss: 0.553364	Train Acc: 89.2857%	Val Loss: 1.394837	Val Acc: 67.6000%
Epoch  22	Train Loss: 0.523534	Train Acc: 89.2857%	Val Loss: 1.409594	Val Acc: 68.4000%
Epoch  23	Train Loss: 0.547120	Train Acc: 89.2857%	Val Loss: 1.430590	Val Acc: 68.0000%
Epoch  24	Train Loss: 0.593029	Train Acc: 88.5714%	Val Loss: 1.458196	Val Acc: 68.6000%
Epoch  25	Train Loss: 0.564101	Train Acc: 90.7143%	Val Loss: 1.486715	Val Acc: 68.4000%
Epoch  26	Train Loss: 0.478272	Train Acc: 89.2857%	Val Loss: 1.523585	Val Acc: 67.8000%
Epoch  27	Train Loss: 0.456502	Train Acc: 91.4286%	Val Loss: 1.561760	Val Acc: 67.4000%
Epoch  28	Train Loss: 0.447176	Train Acc: 91.4286%	Val Loss: 1.590273	Val Acc: 67.2000%
Epoch  29	Train Loss: 0.418273	Train Acc: 92.1429%	Val Loss: 1.606814	Val Acc: 67.8000%
Epoch  30	Train Loss: 0.502453	Train Acc: 92.8571%	Val Loss: 1.615381	Val Acc: 67.8000%
Epoch  31	Train Loss: 0.443167	Train Acc: 93.5714%	Val Loss: 1.620447	Val Acc: 67.8000%
Epoch  32	Train Loss: 0.367190	Train Acc: 93.5714%	Val Loss: 1.625762	Val Acc: 67.2000%
Epoch  33	Train Loss: 0.507516	Train Acc: 93.5714%	Val Loss: 1.665179	Val Acc: 67.2000%
Epoch  34	Train Loss: 0.459586	Train Acc: 93.5714%	Val Loss: 1.682184	Val Acc: 67.2000%
Epoch  35	Train Loss: 0.396707	Train Acc: 93.5714%	Val Loss: 1.687292	Val Acc: 67.0000%
Epoch  36	Train Loss: 0.330848	Train Acc: 93.5714%	Val Loss: 1.685548	Val Acc: 67.2000%
Epoch  37	Train Loss: 0.364265	Train Acc: 94.2857%	Val Loss: 1.687667	Val Acc: 67.2000%
Epoch  38	Train Loss: 0.378305	Train Acc: 95.0000%	Val Loss: 1.683174	Val Acc: 67.6000%
Epoch  39	Train Loss: 0.454592	Train Acc: 95.0000%	Val Loss: 1.724287	Val Acc: 67.4000%
Epoch  40	Train Loss: 0.329054	Train Acc: 95.0000%	Val Loss: 1.763165	Val Acc: 67.0000%
Early Stopping at Epoch 40
Test Loss: 1.631037	Test Acc: 68.3000%

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.