Graph Convolutional Networks on Cora
This example is based on GCN MLX tutorial. While we are doing this manually, we recommend directly using GNNLux.jl.
julia
using Lux,
Reactant,
MLDatasets,
Random,
Statistics,
GNNGraphs,
ConcreteStructs,
Printf,
OneHotArrays,
Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
Loading Cora Dataset
julia
function loadcora()
data = Cora()
gph = data.graphs[1]
gnngraph = GNNGraph(
gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
)
return (
gph.node_data.features,
onehotbatch(gph.node_data.targets, data.metadata["classes"]),
# We use a dense matrix here to avoid incompatibility with Reactant
Matrix{Int32}(adjacency_matrix(gnngraph)),
# We use this since Reactant doesn't yet support gather adjoint
(1:140, 141:640, 1709:2708),
)
end
Model Definition
julia
function GCNLayer(args...; kwargs...)
return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
@return dense(x) * adj
end
end
function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
gcn_layers = [
GCNLayer(in_dim => out_dim; kwargs...) for
(in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
]
last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
dropout = Dropout(dropout)
return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
for layer in gcn_layers
x = relu.(layer((x, adj)))
x = dropout(x)
end
@return last_layer((x, adj))[:, mask]
end
end
Helper Functions
julia
function loss_function(model, ps, st, (x, y, adj, mask))
y_pred, st = model((x, adj, mask), ps, st)
loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
return loss, st, (; y_pred)
end
accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
Training the Model
julia
function main(;
hidden_dim::Int=64,
dropout::Float64=0.1,
nb_layers::Int=2,
use_bias::Bool=true,
lr::Float64=0.001,
weight_decay::Float64=0.0,
patience::Int=20,
epochs::Int=200,
)
rng = Random.default_rng()
Random.seed!(rng, 0)
features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())
gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
ps, st = xdev(Lux.setup(rng, gcn))
opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
train_state = Training.TrainState(gcn, ps, st, opt)
@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)
val_loss_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
end
train_model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
end
val_model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
end
best_loss_val = Inf
cnt = 0
for epoch in 1:epochs
(_, loss, _, train_state) = Lux.Training.single_train_step!(
AutoEnzyme(),
loss_function,
(features, targets, adj, train_idx),
train_state;
return_gradients=Val(false),
)
train_acc = accuracy(
Array(
train_model_compiled(
(features, adj, train_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)[1],
),
Array(targets)[:, train_idx],
)
val_loss = first(
val_loss_compiled(
gcn,
train_state.parameters,
Lux.testmode(train_state.states),
(features, targets, adj, val_idx),
),
)
val_acc = accuracy(
Array(
val_model_compiled(
(features, adj, val_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)[1],
),
Array(targets)[:, val_idx],
)
@printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc
if val_loss < best_loss_val
best_loss_val = val_loss
cnt = 0
else
cnt += 1
if cnt == patience
@printf "Early Stopping at Epoch %d\n" epoch
break
end
end
end
Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
test_loss = @jit(
loss_function(
gcn,
train_state.parameters,
Lux.testmode(train_state.states),
(features, targets, adj, test_idx),
)
)[1]
test_acc = accuracy(
Array(
@jit(
gcn(
(features, adj, test_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)
)[1],
),
Array(targets)[:, test_idx],
)
@printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
end
return nothing
end
main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757734448.307756 570007 service.cc:163] XLA service 0x2e27e540 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757734448.307937 570007 service.cc:171] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1757734448.308014 570007 service.cc:171] StreamExecutor device (1): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1757734448.314613 570007 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757734448.314681 570007 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 0 for BFCAllocator.
I0000 00:00:1757734448.314744 570007 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 1 for BFCAllocator.
I0000 00:00:1757734448.314767 570007 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757734448.314792 570007 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1757734448.327068 570007 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 16.336132 Train Acc: 22.1429% Val Loss: 7.009547 Val Acc: 22.8000%
Epoch 2 Train Loss: 8.029594 Train Acc: 20.7143% Val Loss: 3.027081 Val Acc: 29.2000%
Epoch 3 Train Loss: 4.390297 Train Acc: 42.1429% Val Loss: 1.811692 Val Acc: 40.4000%
Epoch 4 Train Loss: 1.927010 Train Acc: 55.7143% Val Loss: 1.848529 Val Acc: 43.4000%
Epoch 5 Train Loss: 1.747656 Train Acc: 63.5714% Val Loss: 1.859883 Val Acc: 43.8000%
Epoch 6 Train Loss: 1.636206 Train Acc: 70.0000% Val Loss: 1.754520 Val Acc: 51.0000%
Epoch 7 Train Loss: 1.610943 Train Acc: 72.8571% Val Loss: 1.633135 Val Acc: 56.4000%
Epoch 8 Train Loss: 1.470763 Train Acc: 77.1429% Val Loss: 1.545484 Val Acc: 60.2000%
Epoch 9 Train Loss: 1.298347 Train Acc: 77.1429% Val Loss: 1.483948 Val Acc: 62.8000%
Epoch 10 Train Loss: 1.224091 Train Acc: 80.7143% Val Loss: 1.435608 Val Acc: 64.6000%
Epoch 11 Train Loss: 1.041614 Train Acc: 80.0000% Val Loss: 1.408752 Val Acc: 65.8000%
Epoch 12 Train Loss: 1.044359 Train Acc: 80.7143% Val Loss: 1.401051 Val Acc: 66.2000%
Epoch 13 Train Loss: 1.014527 Train Acc: 80.7143% Val Loss: 1.405285 Val Acc: 66.2000%
Epoch 14 Train Loss: 0.819150 Train Acc: 81.4286% Val Loss: 1.417203 Val Acc: 66.0000%
Epoch 15 Train Loss: 0.876931 Train Acc: 84.2857% Val Loss: 1.417138 Val Acc: 66.8000%
Epoch 16 Train Loss: 1.664872 Train Acc: 85.0000% Val Loss: 1.410902 Val Acc: 67.6000%
Epoch 17 Train Loss: 0.667342 Train Acc: 85.7143% Val Loss: 1.458553 Val Acc: 67.2000%
Epoch 18 Train Loss: 0.756876 Train Acc: 84.2857% Val Loss: 1.518630 Val Acc: 67.0000%
Epoch 19 Train Loss: 0.867558 Train Acc: 85.0000% Val Loss: 1.572213 Val Acc: 66.8000%
Epoch 20 Train Loss: 0.751877 Train Acc: 85.7143% Val Loss: 1.611869 Val Acc: 66.2000%
Epoch 21 Train Loss: 0.624684 Train Acc: 85.7143% Val Loss: 1.631795 Val Acc: 66.6000%
Epoch 22 Train Loss: 0.981921 Train Acc: 87.1429% Val Loss: 1.626106 Val Acc: 66.6000%
Epoch 23 Train Loss: 0.795467 Train Acc: 88.5714% Val Loss: 1.606040 Val Acc: 66.0000%
Epoch 24 Train Loss: 0.898357 Train Acc: 88.5714% Val Loss: 1.573473 Val Acc: 66.8000%
Epoch 25 Train Loss: 0.698883 Train Acc: 88.5714% Val Loss: 1.541680 Val Acc: 68.0000%
Epoch 26 Train Loss: 0.645220 Train Acc: 88.5714% Val Loss: 1.514167 Val Acc: 68.6000%
Epoch 27 Train Loss: 0.613002 Train Acc: 89.2857% Val Loss: 1.493463 Val Acc: 68.4000%
Epoch 28 Train Loss: 0.565419 Train Acc: 88.5714% Val Loss: 1.479156 Val Acc: 68.8000%
Epoch 29 Train Loss: 0.491974 Train Acc: 89.2857% Val Loss: 1.470922 Val Acc: 67.8000%
Epoch 30 Train Loss: 0.447017 Train Acc: 90.7143% Val Loss: 1.466520 Val Acc: 67.8000%
Epoch 31 Train Loss: 0.481589 Train Acc: 90.0000% Val Loss: 1.467708 Val Acc: 68.2000%
Epoch 32 Train Loss: 0.466854 Train Acc: 90.7143% Val Loss: 1.474037 Val Acc: 67.6000%
Early Stopping at Epoch 32
Test Loss: 1.318466 Test Acc: 70.3000%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.