Graph Convolutional Networks on Cora
This example is based on the GCN tutorial from MLX. While we implement everything manually here, we recommend using GNNLux.jl directly instead.
julia
using Lux,
Reactant,
MLDatasets,
Random,
Statistics,
GNNGraphs,
ConcreteStructs,
Printf,
OneHotArrays,
Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
Loading Cora Dataset
julia
function loadcora()
data = Cora()
gph = data.graphs[1]
gnngraph = GNNGraph(
gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
)
return (
gph.node_data.features,
onehotbatch(gph.node_data.targets, data.metadata["classes"]),
# We use a dense matrix here to avoid incompatibility with Reactant
Matrix{Int32}(adjacency_matrix(gnngraph)),
# We return the train/val/test indices as plain ranges since Reactant doesn't yet support the gather adjoint
(1:140, 141:640, 1709:2708),
)
end
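Cora has 2708 nodes, 1433 bag-of-words features per node, and 7 classes, and the index ranges above correspond to the standard Planetoid train/validation/test split. As a quick sanity check (a hypothetical REPL session, not part of the tutorial), the loader should return arrays with the following shapes:
julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)  # (1433, 2708)  one feature column per node
size(targets)   # (7, 2708)     one-hot class labels
size(adj)       # (2708, 2708)  dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)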
Model Definition
julia
function GCNLayer(args...; kwargs...)
return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
@return dense(x) * adj
end
end
function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
gcn_layers = [
GCNLayer(in_dim => out_dim; kwargs...) for
(in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
]
last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
dropout = Dropout(dropout)
return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
for layer in gcn_layers
x = relu.(layer((x, adj)))
x = dropout(x)
end
@return last_layer((x, adj))[:, mask]
end
end
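With node features stored column-wise, each GCNLayer applies a dense transform to every node's features and then right-multiplies by the adjacency matrix, so `dense(x) * adj` sums the transformed features of each node's neighbours; the GCN wrapper interleaves relu and dropout between these layers and slices the final output with the requested node mask. The canonical GCN of Kipf & Welling multiplies by a symmetrically normalised adjacency with self-loops rather than the raw adjacency returned by loadcora. A minimal sketch of that optional preprocessing (not used in this tutorial; requires LinearAlgebra) would be:
julia
using LinearAlgebra

function normalized_adjacency(adj::AbstractMatrix)
    a = adj + I                           # add self-loops
    d = vec(sum(a; dims=1))               # node degrees
    dinv_sqrt = Diagonal(inv.(sqrt.(d)))  # D^(-1/2)
    return dinv_sqrt * a * dinv_sqrt      # D^(-1/2) (A + I) D^(-1/2)
end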
Helper Functions
julia
function loss_function(model, ps, st, (x, y, adj, mask))
y_pred, st = model((x, adj, mask), ps, st)
loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
return loss, st, (; y_pred)
end
accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
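loss_function works on logits (`logits=Val(true)`), so the model output is passed to the cross-entropy loss without an explicit softmax, and accuracy simply compares the arg-max of predictions and one-hot targets via onecold. A tiny illustration on dummy data (hypothetical values, not part of the tutorial):
julia
y = onehotbatch([1, 2, 3, 1], 1:3)  # 3×4 one-hot targets
ŷ = Float32[
    0.8 0.1 0.1 0.2
    0.1 0.8 0.1 0.3
    0.1 0.1 0.8 0.5
]                                   # per-class scores for 4 samples
accuracy(ŷ, y)                      # 75.0 (3 of 4 columns match)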
Training the Model
julia
function main(;
hidden_dim::Int=64,
dropout::Float64=0.1,
nb_layers::Int=2,
use_bias::Bool=true,
lr::Float64=0.001,
weight_decay::Float64=0.0,
patience::Int=20,
epochs::Int=200,
)
rng = Random.default_rng()
Random.seed!(rng, 0)
features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())
gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
ps, st = xdev(Lux.setup(rng, gcn))
opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
train_state = Training.TrainState(gcn, ps, st, opt)
@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)
val_loss_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
end
train_model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
end
val_model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
end
best_loss_val = Inf
cnt = 0
for epoch in 1:epochs
(_, loss, _, train_state) = Lux.Training.single_train_step!(
AutoEnzyme(),
loss_function,
(features, targets, adj, train_idx),
train_state;
return_gradients=Val(false),
)
train_acc = accuracy(
Array(
train_model_compiled(
(features, adj, train_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)[1],
),
Array(targets)[:, train_idx],
)
val_loss = first(
val_loss_compiled(
gcn,
train_state.parameters,
Lux.testmode(train_state.states),
(features, targets, adj, val_idx),
),
)
val_acc = accuracy(
Array(
val_model_compiled(
(features, adj, val_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)[1],
),
Array(targets)[:, val_idx],
)
@printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc
if val_loss < best_loss_val
best_loss_val = val_loss
cnt = 0
else
cnt += 1
if cnt == patience
@printf "Early Stopping at Epoch %d\n" epoch
break
end
end
end
Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
test_loss = @jit(
loss_function(
gcn,
train_state.parameters,
Lux.testmode(train_state.states),
(features, targets, adj, test_idx),
)
)[1]
test_acc = accuracy(
Array(
@jit(
gcn(
(features, adj, test_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)
)[1],
),
Array(targets)[:, test_idx],
)
@printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
end
return nothing
end
main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1758301474.460729 1195006 service.cc:158] XLA service 0x43f3d290 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758301474.460797 1195006 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1758301474.461718 1195006 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1758301474.461756 1195006 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1758301474.461800 1195006 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1758301474.472947 1195006 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-11/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-11/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 14.279928 Train Acc: 20.0000% Val Loss: 7.631112 Val Acc: 22.4000%
Epoch 2 Train Loss: 9.303165 Train Acc: 23.5714% Val Loss: 3.934845 Val Acc: 29.6000%
Epoch 3 Train Loss: 4.114150 Train Acc: 39.2857% Val Loss: 2.177299 Val Acc: 35.4000%
Epoch 4 Train Loss: 2.119196 Train Acc: 53.5714% Val Loss: 1.952470 Val Acc: 40.4000%
Epoch 5 Train Loss: 2.029362 Train Acc: 60.7143% Val Loss: 1.825472 Val Acc: 45.8000%
Epoch 6 Train Loss: 1.618327 Train Acc: 63.5714% Val Loss: 1.822781 Val Acc: 49.4000%
Epoch 7 Train Loss: 1.492998 Train Acc: 67.8571% Val Loss: 1.757256 Val Acc: 54.4000%
Epoch 8 Train Loss: 1.486185 Train Acc: 72.8571% Val Loss: 1.699004 Val Acc: 57.0000%
Epoch 9 Train Loss: 1.312258 Train Acc: 75.7143% Val Loss: 1.622224 Val Acc: 59.8000%
Epoch 10 Train Loss: 1.150110 Train Acc: 76.4286% Val Loss: 1.570092 Val Acc: 63.8000%
Epoch 11 Train Loss: 1.230214 Train Acc: 76.4286% Val Loss: 1.536865 Val Acc: 65.0000%
Epoch 12 Train Loss: 1.196295 Train Acc: 77.8571% Val Loss: 1.510300 Val Acc: 65.2000%
Epoch 13 Train Loss: 1.054720 Train Acc: 80.0000% Val Loss: 1.496862 Val Acc: 65.2000%
Epoch 14 Train Loss: 0.952108 Train Acc: 81.4286% Val Loss: 1.489314 Val Acc: 65.0000%
Epoch 15 Train Loss: 1.367762 Train Acc: 81.4286% Val Loss: 1.479288 Val Acc: 65.6000%
Epoch 16 Train Loss: 1.090928 Train Acc: 83.5714% Val Loss: 1.514651 Val Acc: 66.2000%
Epoch 17 Train Loss: 0.854440 Train Acc: 83.5714% Val Loss: 1.593374 Val Acc: 65.8000%
Epoch 18 Train Loss: 0.711868 Train Acc: 81.4286% Val Loss: 1.675784 Val Acc: 64.4000%
Epoch 19 Train Loss: 0.863184 Train Acc: 81.4286% Val Loss: 1.719823 Val Acc: 64.2000%
Epoch 20 Train Loss: 0.930976 Train Acc: 82.1429% Val Loss: 1.704442 Val Acc: 65.2000%
Epoch 21 Train Loss: 0.861661 Train Acc: 84.2857% Val Loss: 1.646146 Val Acc: 65.6000%
Epoch 22 Train Loss: 0.722063 Train Acc: 86.4286% Val Loss: 1.581254 Val Acc: 66.4000%
Epoch 23 Train Loss: 0.777586 Train Acc: 88.5714% Val Loss: 1.521695 Val Acc: 67.0000%
Epoch 24 Train Loss: 0.665991 Train Acc: 89.2857% Val Loss: 1.485319 Val Acc: 69.4000%
Epoch 25 Train Loss: 0.526629 Train Acc: 89.2857% Val Loss: 1.469275 Val Acc: 69.8000%
Epoch 26 Train Loss: 0.627344 Train Acc: 88.5714% Val Loss: 1.468050 Val Acc: 69.4000%
Epoch 27 Train Loss: 0.550234 Train Acc: 88.5714% Val Loss: 1.478206 Val Acc: 69.0000%
Epoch 28 Train Loss: 0.468360 Train Acc: 88.5714% Val Loss: 1.492301 Val Acc: 69.0000%
Epoch 29 Train Loss: 0.723738 Train Acc: 88.5714% Val Loss: 1.506597 Val Acc: 68.6000%
Epoch 30 Train Loss: 0.563018 Train Acc: 90.0000% Val Loss: 1.520076 Val Acc: 68.8000%
Epoch 31 Train Loss: 0.460821 Train Acc: 90.0000% Val Loss: 1.539181 Val Acc: 68.4000%
Epoch 32 Train Loss: 0.493969 Train Acc: 90.7143% Val Loss: 1.563661 Val Acc: 68.8000%
Epoch 33 Train Loss: 0.434047 Train Acc: 91.4286% Val Loss: 1.592195 Val Acc: 68.0000%
Epoch 34 Train Loss: 0.489974 Train Acc: 92.1429% Val Loss: 1.622337 Val Acc: 67.0000%
Epoch 35 Train Loss: 0.387495 Train Acc: 92.8571% Val Loss: 1.655083 Val Acc: 65.8000%
Epoch 36 Train Loss: 0.411174 Train Acc: 92.8571% Val Loss: 1.688826 Val Acc: 65.0000%
Epoch 37 Train Loss: 0.416674 Train Acc: 92.8571% Val Loss: 1.724461 Val Acc: 65.0000%
Epoch 38 Train Loss: 0.401107 Train Acc: 92.8571% Val Loss: 1.752678 Val Acc: 64.8000%
Epoch 39 Train Loss: 0.364621 Train Acc: 93.5714% Val Loss: 1.772014 Val Acc: 64.8000%
Epoch 40 Train Loss: 0.363665 Train Acc: 94.2857% Val Loss: 1.782372 Val Acc: 64.8000%
Epoch 41 Train Loss: 0.399835 Train Acc: 94.2857% Val Loss: 1.781208 Val Acc: 64.8000%
Epoch 42 Train Loss: 0.360145 Train Acc: 94.2857% Val Loss: 1.773606 Val Acc: 65.0000%
Epoch 43 Train Loss: 0.379509 Train Acc: 95.0000% Val Loss: 1.764072 Val Acc: 65.6000%
Epoch 44 Train Loss: 0.329097 Train Acc: 95.0000% Val Loss: 1.756897 Val Acc: 66.2000%
Epoch 45 Train Loss: 0.358217 Train Acc: 94.2857% Val Loss: 1.754727 Val Acc: 66.6000%
Epoch 46 Train Loss: 0.352214 Train Acc: 94.2857% Val Loss: 1.746168 Val Acc: 67.2000%
Early Stopping at Epoch 46
Test Loss: 1.545399 Test Acc: 68.3000%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.