Graph Convolutional Networks on Cora
This example is based on the GCN tutorial from MLX. Here we implement the graph convolution manually; for real projects we recommend using GNNLux.jl directly.
julia
using Lux,
Reactant,
MLDatasets,
Random,
Statistics,
GNNGraphs,
ConcreteStructs,
Printf,
OneHotArrays,
Optimisers
# Device onto which the data, parameters and states are moved (Reactant).
# NOTE(review): `force=true` presumably makes this error instead of falling back to the
# CPU when Reactant is not functional — confirm against the MLDataDevices docs.
const xdev = reactant_device(; force=true)
# CPU device handle. NOTE(review): not referenced anywhere else in the code shown here.
const cdev = cpu_device()
Loading Cora Dataset
julia
"""
    loadcora()

Load the Cora citation dataset and return a 4-tuple `(features, targets, adj, splits)`:

- `features`: the node feature matrix of the dataset's first graph.
- `targets`: the node labels one-hot encoded over `metadata["classes"]`.
- `adj`: the graph's adjacency matrix, materialized as a dense `Matrix{Int32}`.
- `splits`: `(train, val, test)` node-index ranges.
"""
function loadcora()
    dataset = Cora()
    graph = first(dataset.graphs)
    gnngraph = GNNGraph(
        graph.edge_index;
        ndata=graph.node_data,
        edata=graph.edge_data,
        num_nodes=graph.num_nodes,
    )

    node_features = graph.node_data.features
    onehot_targets = onehotbatch(graph.node_data.targets, dataset.metadata["classes"])
    # Materialized densely to avoid an incompatibility with Reactant.
    dense_adj = Matrix{Int32}(adjacency_matrix(gnngraph))
    # Contiguous ranges are used for the splits because Reactant does not yet support
    # the gather adjoint.
    splits = (1:140, 141:640, 1709:2708)

    return (node_features, onehot_targets, dense_adj, splits)
end
Model Definition
julia
"""
    GCNLayer(args...; kwargs...)

Graph-convolution layer built with `@compact`. It takes a tuple `(x, adj)`, applies
`Dense(args...; kwargs...)` to `x`, and right-multiplies the result by `adj`. Because
of that product, the columns of `x` (nodes) must line up with the rows of `adj`.
"""
function GCNLayer(args...; kwargs...)
    projection = Dense(args...; kwargs...)
    return @compact(; dense=projection) do (x, adj)
        projected = dense(x)
        @return projected * adj
    end
end
"""
    GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)

Build a graph convolutional network with the following structure:

- `nb_layers` hidden `GCNLayer`s (`x_dim => h_dim`, then `h_dim => h_dim`, ...).
- Each hidden layer is followed by `relu` and a `Dropout(dropout)`.
- A final `GCNLayer(h_dim => out_dim)` produces the output.

Extra `kwargs` are forwarded to every `GCNLayer`. The model takes `(x, adj, mask)` and
returns only the output columns selected by `mask`.
"""
function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    widths = [x_dim; fill(h_dim, nb_layers)]
    hidden_layers = [
        GCNLayer(d_in => d_out; kwargs...) for
        (d_in, d_out) in zip(widths[1:(end - 1)], widths[2:end])
    ]
    head = GCNLayer(widths[end] => out_dim; kwargs...)
    drop = Dropout(dropout)

    # The keyword names below define the parameter/state tree, so they are kept stable.
    return @compact(;
        gcn_layers=hidden_layers, dropout=drop, last_layer=head
    ) do (x, adj, mask)
        h = x
        for layer in gcn_layers
            h = relu.(layer((h, adj)))
            h = dropout(h)
        end
        @return last_layer((h, adj))[:, mask]
    end
end
Helper Functions
julia
"""
    loss_function(model, ps, st, (x, y, adj, mask))

Run `model` on `(x, adj, mask)` and score its logits with mean logit cross-entropy
against the `mask` columns of `y`. Returns `(loss, updated_state, (; y_pred))`, the
shape expected by `Lux.Training.single_train_step!`.
"""
function loss_function(model, ps, st, (x, y, adj, mask))
    logits, new_st = model((x, adj, mask), ps, st)
    criterion = CrossEntropyLoss(; agg=mean, logits=Val(true))
    return criterion(logits, y[:, mask]), new_st, (; y_pred=logits)
end
"""
    accuracy(y_pred, y)

Percentage (0–100) of columns whose `onecold` class in `y_pred` matches the one in `y`.
"""
function accuracy(y_pred, y)
    hits = onecold(y_pred) .== onecold(y)
    return 100 * mean(hits)
end
Training the Model
julia
"""
    with_high_precision(f)

Call `f()` inside `Reactant.with_config`, with `PrecisionConfig.HIGH` for both
`dot_general` and convolution ops, and return whatever `f` returns.
"""
function with_high_precision(f)
    return Reactant.with_config(
        f;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    )
end

"""
    main(; hidden_dim=64, dropout=0.1, nb_layers=2, use_bias=true, lr=0.001,
         weight_decay=0.0, patience=20, epochs=200)

Train a `GCN` on Cora with Reactant and report progress:

- Prints per-epoch train/validation loss and accuracy.
- Stops early once the validation loss has failed to improve for `patience`
  consecutive epochs.
- Prints the test loss and accuracy, then returns `nothing`.

Uses `Adam(lr)` when `weight_decay` is zero, otherwise `AdamW`. Test metrics are
computed with the parameters from the last epoch run, not the best validation epoch.
"""
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))

    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    # Compile the evaluation functions once; they are reused every epoch.
    val_loss_compiled = with_high_precision() do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end
    train_model_compiled = with_high_precision() do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end
    val_model_compiled = with_high_precision() do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    # The targets never change, so copy them to the host once rather than twice per epoch.
    targets_host = Array(targets)
    train_targets = targets_host[:, train_idx]
    val_targets = targets_host[:, val_idx]

    best_loss_val = Inf
    cnt = 0
    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        # Evaluate with dropout disabled, using the freshly updated parameters.
        eval_st = Lux.testmode(train_state.states)

        train_logits = train_model_compiled(
            (features, adj, train_idx), train_state.parameters, eval_st
        )[1]
        train_acc = accuracy(Array(train_logits), train_targets)

        val_loss = first(
            val_loss_compiled(
                gcn, train_state.parameters, eval_st, (features, targets, adj, val_idx)
            ),
        )
        val_logits = val_model_compiled(
            (features, adj, val_idx), train_state.parameters, eval_st
        )[1]
        val_acc = accuracy(Array(val_logits), val_targets)

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        # Early stopping on validation loss.
        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    with_high_precision() do
        test_st = Lux.testmode(train_state.states)
        test_loss = @jit(
            loss_function(
                gcn, train_state.parameters, test_st, (features, targets, adj, test_idx)
            )
        )[1]
        test_logits = @jit(gcn((features, adj, test_idx), train_state.parameters, test_st))[1]
        test_acc = accuracy(Array(test_logits), targets_host[:, test_idx])
        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760849726.638538 73885 service.cc:158] XLA service 0x12cb8050 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760849726.638607 73885 service.cc:166] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1760849726.639504 73885 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760849726.639544 73885 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 0 for BFCAllocator.
I0000 00:00:1760849726.639590 73885 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760849726.647744 73885 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-17/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-17/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 15.483308 Train Acc: 22.1429% Val Loss: 7.571783 Val Acc: 25.8000%
Epoch 2 Train Loss: 10.125030 Train Acc: 22.1429% Val Loss: 3.797886 Val Acc: 29.4000%
Epoch 3 Train Loss: 4.467242 Train Acc: 37.8571% Val Loss: 2.431701 Val Acc: 32.0000%
Epoch 4 Train Loss: 2.424877 Train Acc: 51.4286% Val Loss: 2.113642 Val Acc: 37.8000%
Epoch 5 Train Loss: 1.761382 Train Acc: 58.5714% Val Loss: 1.889251 Val Acc: 45.0000%
Epoch 6 Train Loss: 1.484980 Train Acc: 67.8571% Val Loss: 1.611183 Val Acc: 51.6000%
Epoch 7 Train Loss: 1.267712 Train Acc: 71.4286% Val Loss: 1.504884 Val Acc: 58.4000%
Epoch 8 Train Loss: 1.319321 Train Acc: 72.1429% Val Loss: 1.505576 Val Acc: 59.8000%
Epoch 9 Train Loss: 1.617085 Train Acc: 73.5714% Val Loss: 1.520861 Val Acc: 61.2000%
Epoch 10 Train Loss: 1.249781 Train Acc: 74.2857% Val Loss: 1.519172 Val Acc: 62.0000%
Epoch 11 Train Loss: 1.187690 Train Acc: 78.5714% Val Loss: 1.504537 Val Acc: 62.0000%
Epoch 12 Train Loss: 1.179360 Train Acc: 78.5714% Val Loss: 1.547555 Val Acc: 61.8000%
Epoch 13 Train Loss: 0.898748 Train Acc: 80.0000% Val Loss: 1.608348 Val Acc: 62.0000%
Epoch 14 Train Loss: 0.946830 Train Acc: 80.0000% Val Loss: 1.649864 Val Acc: 61.8000%
Epoch 15 Train Loss: 1.425960 Train Acc: 80.7143% Val Loss: 1.633293 Val Acc: 64.4000%
Epoch 16 Train Loss: 0.875585 Train Acc: 82.1429% Val Loss: 1.616586 Val Acc: 66.6000%
Epoch 17 Train Loss: 0.810615 Train Acc: 81.4286% Val Loss: 1.592887 Val Acc: 67.0000%
Epoch 18 Train Loss: 0.763062 Train Acc: 80.7143% Val Loss: 1.569996 Val Acc: 67.4000%
Epoch 19 Train Loss: 0.881348 Train Acc: 82.1429% Val Loss: 1.543069 Val Acc: 67.2000%
Epoch 20 Train Loss: 0.750949 Train Acc: 82.8571% Val Loss: 1.520200 Val Acc: 66.8000%
Epoch 21 Train Loss: 0.685395 Train Acc: 83.5714% Val Loss: 1.504100 Val Acc: 66.6000%
Epoch 22 Train Loss: 0.611383 Train Acc: 85.0000% Val Loss: 1.500499 Val Acc: 66.0000%
Epoch 23 Train Loss: 0.603166 Train Acc: 84.2857% Val Loss: 1.511355 Val Acc: 66.2000%
Epoch 24 Train Loss: 1.565990 Train Acc: 85.7143% Val Loss: 1.550029 Val Acc: 66.0000%
Epoch 25 Train Loss: 0.564261 Train Acc: 88.5714% Val Loss: 1.616223 Val Acc: 64.6000%
Epoch 26 Train Loss: 0.524013 Train Acc: 87.8571% Val Loss: 1.695767 Val Acc: 64.0000%
Epoch 27 Train Loss: 0.508034 Train Acc: 88.5714% Val Loss: 1.788847 Val Acc: 64.0000%
Epoch 28 Train Loss: 0.621814 Train Acc: 87.8571% Val Loss: 1.853112 Val Acc: 63.0000%
Epoch 29 Train Loss: 0.579144 Train Acc: 88.5714% Val Loss: 1.872776 Val Acc: 63.2000%
Epoch 30 Train Loss: 0.491464 Train Acc: 88.5714% Val Loss: 1.874164 Val Acc: 63.8000%
Epoch 31 Train Loss: 0.493937 Train Acc: 89.2857% Val Loss: 1.847677 Val Acc: 64.6000%
Epoch 32 Train Loss: 0.562605 Train Acc: 90.0000% Val Loss: 1.800509 Val Acc: 66.0000%
Epoch 33 Train Loss: 0.490372 Train Acc: 91.4286% Val Loss: 1.742707 Val Acc: 66.0000%
Epoch 34 Train Loss: 0.623589 Train Acc: 91.4286% Val Loss: 1.702445 Val Acc: 65.8000%
Epoch 35 Train Loss: 0.441532 Train Acc: 92.8571% Val Loss: 1.669238 Val Acc: 66.2000%
Epoch 36 Train Loss: 0.414883 Train Acc: 92.1429% Val Loss: 1.649799 Val Acc: 67.4000%
Epoch 37 Train Loss: 0.396852 Train Acc: 93.5714% Val Loss: 1.642260 Val Acc: 68.0000%
Epoch 38 Train Loss: 0.370066 Train Acc: 93.5714% Val Loss: 1.644972 Val Acc: 68.2000%
Epoch 39 Train Loss: 0.402366 Train Acc: 93.5714% Val Loss: 1.657053 Val Acc: 68.6000%
Epoch 40 Train Loss: 0.802922 Train Acc: 95.7143% Val Loss: 1.677369 Val Acc: 67.8000%
Epoch 41 Train Loss: 0.378652 Train Acc: 95.7143% Val Loss: 1.707681 Val Acc: 68.0000%
Epoch 42 Train Loss: 0.366849 Train Acc: 95.0000% Val Loss: 1.735516 Val Acc: 68.2000%
Early Stopping at Epoch 42
Test Loss: 1.518862 Test Acc: 68.8000%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU backend details only for backends that are loaded and functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.