Graph Convolutional Networks on Cora
This example is based on the GCN MLX tutorial. While we build the network manually here, we recommend using GNNLux.jl directly.
julia
using Lux,
Reactant,
MLDatasets,
Random,
Statistics,
GNNGraphs,
ConcreteStructs,
Printf,
OneHotArrays,
Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
Loading Cora Dataset
julia
function loadcora()
data = Cora()
gph = data.graphs[1]
gnngraph = GNNGraph(
gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
)
return (
gph.node_data.features,
onehotbatch(gph.node_data.targets, data.metadata["classes"]),
# We use a dense matrix here to avoid incompatibility with Reactant
Matrix{Int32}(adjacency_matrix(gnngraph)),
# We hardcode the train/val/test split ranges since Reactant doesn't yet support the gather adjoint
(1:140, 141:640, 1709:2708),
)
end
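For orientation, `loadcora` returns the node features, one-hot targets, a dense adjacency matrix, and the standard 140/500/1000 train/validation/test split. The quick check below is only a sketch; the `1433 × 2708` feature and `7 × 2708` target shapes are assumptions based on the standard Cora dataset shipped with MLDatasets.
julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)  # expected (1433, 2708): one column of features per node
size(targets)   # expected (7, 2708): one-hot class labels per node
size(adj)       # expected (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)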
Model Definition
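Each `GCNLayer` implements the linear part of a graph-convolution update. With the node features stored as columns of $H^{(l)}$, the layer computes

$$H^{(l+1)} = \sigma\!\left(\left(W^{(l)} H^{(l)} + b^{(l)}\right) A\right),$$

where $A$ is the adjacency matrix and the nonlinearity $\sigma$ (`relu` here) is applied between layers inside `GCN`. The classic GCN formulation uses a symmetrically normalized adjacency $\hat{A} = \tilde{D}^{-1/2} (A + I) \tilde{D}^{-1/2}$; this example keeps things simple and multiplies by the raw dense adjacency returned by `loadcora`.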
julia
function GCNLayer(args...; kwargs...)
return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
@return dense(x) * adj
end
end
function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
gcn_layers = [
GCNLayer(in_dim => out_dim; kwargs...) for
(in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
]
last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
dropout = Dropout(dropout)
return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
for layer in gcn_layers
x = relu.(layer((x, adj)))
x = dropout(x)
end
@return last_layer((x, adj))[:, mask]
end
end
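To sanity-check the model definition, here is a minimal forward-pass sketch on a random graph; all sizes below are made up and unrelated to Cora.
julia
# Hypothetical smoke test: 4 input features, 8 hidden units, 3 classes, 10 nodes.
rng = Random.default_rng()
model = GCN(4, 8, 3; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 4, 10)                 # features × nodes
adj = Float32.(rand(rng, Bool, 10, 10))  # dense (unnormalized) adjacency
mask = 1:5                               # predict only for the first five nodes
y, _ = model((x, adj, mask), ps, st)
size(y)                                  # (3, 5): classes × masked nodes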
Helper Functions
julia
function loss_function(model, ps, st, (x, y, adj, mask))
y_pred, st = model((x, adj, mask), ps, st)
loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
return loss, st, (; y_pred)
end
accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
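`loss_function` returns the `(loss, state, stats)` triple expected by `Training.single_train_step!`, and `accuracy` compares the arg-max class of the logits against the one-hot targets. A tiny standalone check (the logits and labels below are made up):
julia
# Three classes, four nodes: every column's arg-max matches its one-hot label.
y_pred = Float32[2 0 0 1; 0 3 0 0; 0 0 4 0]
y = onehotbatch([1, 2, 3, 1], 1:3)
accuracy(y_pred, y)  # 100.0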
Training the Model
julia
function main(;
hidden_dim::Int=64,
dropout::Float64=0.1,
nb_layers::Int=2,
use_bias::Bool=true,
lr::Float64=0.001,
weight_decay::Float64=0.0,
patience::Int=20,
epochs::Int=200,
)
rng = Random.default_rng()
Random.seed!(rng, 0)
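    # Load Cora and move the features, targets, adjacency, and split indices to the Reactant device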
features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())
gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
ps, st = xdev(Lux.setup(rng, gcn))
opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
train_state = Training.TrainState(gcn, ps, st, opt)
@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)
val_loss_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
end
train_model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
end
val_model_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
end
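    # Training loop with early stopping once the validation loss fails to improve for `patience` consecutive epochs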
best_loss_val = Inf
cnt = 0
for epoch in 1:epochs
(_, loss, _, train_state) = Lux.Training.single_train_step!(
AutoEnzyme(),
loss_function,
(features, targets, adj, train_idx),
train_state;
return_gradients=Val(false),
)
train_acc = accuracy(
Array(
train_model_compiled(
(features, adj, train_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)[1],
),
Array(targets)[:, train_idx],
)
val_loss = first(
val_loss_compiled(
gcn,
train_state.parameters,
Lux.testmode(train_state.states),
(features, targets, adj, val_idx),
),
)
val_acc = accuracy(
Array(
val_model_compiled(
(features, adj, val_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)[1],
),
Array(targets)[:, val_idx],
)
@printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc
if val_loss < best_loss_val
best_loss_val = val_loss
cnt = 0
else
cnt += 1
if cnt == patience
@printf "Early Stopping at Epoch %d\n" epoch
break
end
end
end
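    # Final evaluation on the held-out test split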
Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
test_loss = @jit(
loss_function(
gcn,
train_state.parameters,
Lux.testmode(train_state.states),
(features, targets, adj, test_idx),
)
)[1]
test_acc = accuracy(
Array(
@jit(
gcn(
(features, adj, test_idx),
train_state.parameters,
Lux.testmode(train_state.states),
)
)[1],
),
Array(targets)[:, test_idx],
)
@printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
end
return nothing
end
main()
2025-08-05 23:25:53.732636: I external/xla/xla/service/service.cc:163] XLA service 0x13ad4c50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-05 23:25:53.732751: I external/xla/xla/service/service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754436353.733934 324692 se_gpu_pjrt_client.cc:1373] Using BFC allocator.
I0000 00:00:1754436353.734227 324692 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1754436353.734352 324692 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
2025-08-05 23:25:53.752855: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch 1 Train Loss: 16.434761 Train Acc: 20.0000% Val Loss: 8.514094 Val Acc: 21.2000%
Epoch 2 Train Loss: 9.712398 Train Acc: 23.5714% Val Loss: 4.395040 Val Acc: 29.4000%
Epoch 3 Train Loss: 4.724273 Train Acc: 37.8571% Val Loss: 2.064650 Val Acc: 36.4000%
Epoch 4 Train Loss: 2.234418 Train Acc: 44.2857% Val Loss: 2.274232 Val Acc: 39.2000%
Epoch 5 Train Loss: 2.302605 Train Acc: 47.1429% Val Loss: 2.332959 Val Acc: 37.4000%
Epoch 6 Train Loss: 2.403902 Train Acc: 62.8571% Val Loss: 1.965535 Val Acc: 45.4000%
Epoch 7 Train Loss: 1.610807 Train Acc: 71.4286% Val Loss: 1.653510 Val Acc: 54.2000%
Epoch 8 Train Loss: 1.336507 Train Acc: 72.1429% Val Loss: 1.524905 Val Acc: 61.4000%
Epoch 9 Train Loss: 1.174518 Train Acc: 74.2857% Val Loss: 1.492613 Val Acc: 62.6000%
Epoch 10 Train Loss: 1.125156 Train Acc: 76.4286% Val Loss: 1.480971 Val Acc: 64.4000%
Epoch 11 Train Loss: 1.222813 Train Acc: 77.8571% Val Loss: 1.459094 Val Acc: 66.0000%
Epoch 12 Train Loss: 1.021737 Train Acc: 79.2857% Val Loss: 1.445725 Val Acc: 65.8000%
Epoch 13 Train Loss: 0.973203 Train Acc: 79.2857% Val Loss: 1.447371 Val Acc: 65.8000%
Epoch 14 Train Loss: 0.825098 Train Acc: 79.2857% Val Loss: 1.460149 Val Acc: 66.4000%
Epoch 15 Train Loss: 0.740183 Train Acc: 85.0000% Val Loss: 1.488243 Val Acc: 66.4000%
Epoch 16 Train Loss: 1.266764 Train Acc: 83.5714% Val Loss: 1.495189 Val Acc: 66.4000%
Epoch 17 Train Loss: 2.369039 Train Acc: 85.0000% Val Loss: 1.481928 Val Acc: 67.4000%
Epoch 18 Train Loss: 0.696506 Train Acc: 84.2857% Val Loss: 1.483977 Val Acc: 66.0000%
Epoch 19 Train Loss: 0.660127 Train Acc: 85.0000% Val Loss: 1.489409 Val Acc: 66.0000%
Epoch 20 Train Loss: 1.067086 Train Acc: 85.0000% Val Loss: 1.526611 Val Acc: 66.6000%
Epoch 21 Train Loss: 0.606778 Train Acc: 85.0000% Val Loss: 1.585117 Val Acc: 66.0000%
Epoch 22 Train Loss: 0.648352 Train Acc: 85.7143% Val Loss: 1.648506 Val Acc: 65.6000%
Epoch 23 Train Loss: 0.743714 Train Acc: 86.4286% Val Loss: 1.711921 Val Acc: 65.8000%
Epoch 24 Train Loss: 0.716126 Train Acc: 87.1429% Val Loss: 1.748055 Val Acc: 65.8000%
Epoch 25 Train Loss: 0.749043 Train Acc: 87.1429% Val Loss: 1.765788 Val Acc: 65.6000%
Epoch 26 Train Loss: 0.678848 Train Acc: 88.5714% Val Loss: 1.757170 Val Acc: 67.0000%
Epoch 27 Train Loss: 0.847598 Train Acc: 89.2857% Val Loss: 1.742121 Val Acc: 67.6000%
Epoch 28 Train Loss: 0.673782 Train Acc: 88.5714% Val Loss: 1.710881 Val Acc: 67.4000%
Epoch 29 Train Loss: 0.633077 Train Acc: 90.0000% Val Loss: 1.686995 Val Acc: 68.0000%
Epoch 30 Train Loss: 0.517133 Train Acc: 91.4286% Val Loss: 1.662811 Val Acc: 67.4000%
Epoch 31 Train Loss: 0.574781 Train Acc: 90.0000% Val Loss: 1.641438 Val Acc: 67.4000%
Epoch 32 Train Loss: 0.551590 Train Acc: 90.0000% Val Loss: 1.624909 Val Acc: 67.8000%
Early Stopping at Epoch 32
Test Loss: 1.385764 Test Acc: 71.1000%
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.