Graph Convolutional Networks on Cora
This example is based on the GCN tutorial from MLX. While we build the GCN manually here, we recommend using GNNLux.jl directly.
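For reference, the same two-layer architecture could be assembled from GNNLux.jl's built-in layers. The sketch below is not part of the tutorial and assumes GNNLux.jl exposes `GNNChain` and `GCNConv` with the usual LuxCore setup/apply conventions; consult the GNNLux.jl documentation for the exact API.
julia
# Hedged sketch only: assumes GNNLux.jl provides GNNChain and GCNConv layers
# that follow the standard Lux setup/apply conventions.
using GNNLux, Lux, Random

model = GNNChain(
    GCNConv(1433 => 64, relu),  # Cora bag-of-words features -> hidden
    GCNConv(64 => 7),           # hidden -> 7 document classes
)
ps, st = Lux.setup(Random.default_rng(), model)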
julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Loading Cora Dataset
julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
loadcora (generic function with 1 method)
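As a quick sanity check (not part of the original tutorial), the loader's output can be inspected on the CPU. For the standard Planetoid split of Cora we expect 2708 nodes, 1433-dimensional bag-of-words features, 7 classes, and 140/500/1000 train/validation/test nodes:
julia
# Illustrative inspection of what `loadcora` returns; the expected shapes
# follow the standard Cora split.
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
@show size(features)                           # (1433, 2708)
@show size(targets)                            # (7, 2708)
@show size(adj)                                # (2708, 2708)
@show length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)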
Model Definition
julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]

    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
GCN (generic function with 1 method)
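Note that `GCNLayer` multiplies the transformed features by the raw adjacency matrix, i.e. each layer computes `dense(x) * adj` and `GCN` applies `relu` on top. The original Kipf & Welling formulation instead uses the symmetrically normalized adjacency with self-loops, `Â = D^(-1/2) (A + I) D^(-1/2)`. A minimal sketch of that normalization, which could be applied to `adj` before training, is shown below; the helper name is ours, not part of the tutorial.
julia
using LinearAlgebra

# Hypothetical helper: symmetric normalization A -> D^(-1/2) (A + I) D^(-1/2)
# as in Kipf & Welling (2017). Not used by the tutorial code above.
function normalize_adjacency(adj::AbstractMatrix)
    A = Float32.(adj) + I                # add self-loops
    d = vec(sum(A; dims=2))              # node degrees
    Dinvsqrt = Diagonal(inv.(sqrt.(d)))  # D^(-1/2)
    return Dinvsqrt * A * Dinvsqrt
end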
Helper Functions
julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end
accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
accuracy (generic function with 1 method)
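To see how the `accuracy` helper behaves, here is a tiny hand-made example (3 classes, 4 samples, not part of the original tutorial): predictions are column-wise class scores, labels are a one-hot matrix, and since 3 of the 4 columns agree the result is 75.0.
julia
# Tiny illustration of `accuracy`: column-wise argmax of the scores is
# compared against the one-hot labels.
y_true = onehotbatch([1, 2, 3, 1], 1:3)
y_score = Float32[
    0.9  0.1  0.2  0.3
    0.05 0.8  0.1  0.6
    0.05 0.1  0.7  0.1
]
accuracy(y_score, y_true)  # 75.0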
Training the Model
julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end
    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\tVal Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end
    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end
main()
Precompiling EnzymeBFloat16sExt...
6518.8 ms ✓ Enzyme → EnzymeBFloat16sExt
1 dependency successfully precompiled in 7 seconds. 47 already precompiled.
2025-07-14 00:06:50.243037: I external/xla/xla/service/service.cc:153] XLA service 0x4934d7a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-14 00:06:50.243065: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752451610.243902 2725205 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752451610.243982 2725205 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752451610.244031 2725205 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752451610.255571 2725205 cuda_dnn.cc:471] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-12/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch 1 Train Loss: 14.796565 Train Acc: 20.0000% Val Loss: 7.816453 Val Acc: 22.0000%
Epoch 2 Train Loss: 9.206785 Train Acc: 24.2857% Val Loss: 3.503649 Val Acc: 27.8000%
Epoch 3 Train Loss: 4.558276 Train Acc: 45.0000% Val Loss: 1.994527 Val Acc: 42.2000%
Epoch 4 Train Loss: 2.158406 Train Acc: 50.7143% Val Loss: 2.134480 Val Acc: 43.2000%
Epoch 5 Train Loss: 1.957148 Train Acc: 56.4286% Val Loss: 2.066251 Val Acc: 44.8000%
Epoch 6 Train Loss: 2.038639 Train Acc: 63.5714% Val Loss: 1.778163 Val Acc: 52.2000%
Epoch 7 Train Loss: 1.541075 Train Acc: 71.4286% Val Loss: 1.519329 Val Acc: 59.4000%
Epoch 8 Train Loss: 1.301353 Train Acc: 72.8571% Val Loss: 1.470120 Val Acc: 62.2000%
Epoch 9 Train Loss: 1.274294 Train Acc: 75.0000% Val Loss: 1.504793 Val Acc: 63.0000%
Epoch 10 Train Loss: 1.211376 Train Acc: 75.0000% Val Loss: 1.541491 Val Acc: 64.4000%
Epoch 11 Train Loss: 1.034225 Train Acc: 77.8571% Val Loss: 1.557657 Val Acc: 64.4000%
Epoch 12 Train Loss: 1.256145 Train Acc: 79.2857% Val Loss: 1.554844 Val Acc: 63.8000%
Epoch 13 Train Loss: 1.013252 Train Acc: 80.0000% Val Loss: 1.549081 Val Acc: 64.4000%
Epoch 14 Train Loss: 0.931371 Train Acc: 81.4286% Val Loss: 1.545484 Val Acc: 64.4000%
Epoch 15 Train Loss: 0.912301 Train Acc: 82.8571% Val Loss: 1.550539 Val Acc: 64.6000%
Epoch 16 Train Loss: 0.822234 Train Acc: 84.2857% Val Loss: 1.574267 Val Acc: 64.8000%
Epoch 17 Train Loss: 1.230084 Train Acc: 85.0000% Val Loss: 1.617842 Val Acc: 64.6000%
Epoch 18 Train Loss: 0.802377 Train Acc: 85.7143% Val Loss: 1.697485 Val Acc: 63.0000%
Epoch 19 Train Loss: 0.604222 Train Acc: 84.2857% Val Loss: 1.782103 Val Acc: 63.4000%
Epoch 20 Train Loss: 0.863587 Train Acc: 84.2857% Val Loss: 1.810278 Val Acc: 63.8000%
Epoch 21 Train Loss: 0.775717 Train Acc: 85.0000% Val Loss: 1.755403 Val Acc: 64.4000%
Epoch 22 Train Loss: 0.668179 Train Acc: 87.1429% Val Loss: 1.671531 Val Acc: 65.6000%
Epoch 23 Train Loss: 0.579810 Train Acc: 88.5714% Val Loss: 1.615436 Val Acc: 66.0000%
Epoch 24 Train Loss: 0.529690 Train Acc: 88.5714% Val Loss: 1.590291 Val Acc: 66.8000%
Epoch 25 Train Loss: 0.519709 Train Acc: 88.5714% Val Loss: 1.583071 Val Acc: 66.6000%
Epoch 26 Train Loss: 0.504758 Train Acc: 88.5714% Val Loss: 1.582281 Val Acc: 66.4000%
Epoch 27 Train Loss: 0.529384 Train Acc: 89.2857% Val Loss: 1.590776 Val Acc: 67.2000%
Epoch 28 Train Loss: 0.492651 Train Acc: 89.2857% Val Loss: 1.604312 Val Acc: 67.4000%
Early Stopping at Epoch 28
Test Loss: 1.365980 Test Acc: 72.2000%
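Since `main` exposes its hyperparameters as keyword arguments, rerunning with a different configuration is a one-liner. The values below are illustrative rather than tuned:
julia
# Illustrative only: a wider hidden layer, stronger dropout, and weight decay
# (nonzero weight decay switches the optimiser from Adam to AdamW in `main`).
main(; hidden_dim=128, dropout=0.5, lr=5e-3, weight_decay=5e-4, epochs=300)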
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.