# Graph Convolutional Networks on Cora
This example is based on the GCN MLX tutorial. While we build the GCN manually here, we recommend using GNNLux.jl directly for real workloads.
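For reference, a single graph convolutional layer (Kipf & Welling, 2017) updates the node features as

```math
H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right)
```

where $\hat{A}$ is a (typically normalized) adjacency matrix, $H^{(l)}$ holds the node features, and $W^{(l)}$ is the layer's weight matrix. In the column-major layout used below this product appears as `dense(x) * adj`; note that this example multiplies by the adjacency matrix returned by `adjacency_matrix` as-is, without the usual symmetric normalization.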
```julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
```
## Loading Cora Dataset
```julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # Fixed train/val/test index ranges, since Reactant doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
```
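As a quick sanity check (a sketch, not part of the tutorial; run before anything is moved to the Reactant device), the returned arrays have the standard Cora shapes:

```julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()

size(features)                           # (1433, 2708): bag-of-words features per node
size(targets)                            # (7, 2708): one-hot labels over the 7 classes
size(adj)                                # (2708, 2708): dense adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000) train/val/test nodes
```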
## Model Definition
```julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]

    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
```
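To poke at the model before training it, a minimal smoke test on random data (hypothetical shapes, not part of the tutorial) could look like this:

```julia
rng = Random.default_rng()

model = GCN(1433, 64, 7; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)

x = rand(Float32, 1433, 8)        # 8 fake nodes with Cora-sized features
adj = Float32.(rand(Bool, 8, 8))  # toy adjacency matrix
y, _ = model((x, adj, 1:4), ps, st)
size(y)                           # (7, 4): logits for the 4 masked nodes
```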
## Helper Functions
```julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
```
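A tiny illustration of the `accuracy` helper on made-up logits (not part of the tutorial):

```julia
y_pred = [0.1 2.0; 3.0 0.5; 0.2 0.1]  # logits for 2 samples over 3 classes
y = onehotbatch([2, 1], 1:3)          # ground-truth classes
accuracy(y_pred, y)                   # 100.0
```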
## Training the Model
```julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)

    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\tVal Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
```
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760231437.476673 79531 service.cc:158] XLA service 0x4d9986c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760231437.476762 79531 service.cc:166] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1760231437.476768 79531 service.cc:166] StreamExecutor device (1): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1760231437.482079 79531 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760231437.482127 79531 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 0 for BFCAllocator.
I0000 00:00:1760231437.482173 79531 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 1 for BFCAllocator.
I0000 00:00:1760231437.482190 79531 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760231437.482207 79531 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1760231437.493627 79531 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-16/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 17.697763 Train Acc: 20.7143% Val Loss: 6.744421 Val Acc: 26.0000%
Epoch 2 Train Loss: 8.571198 Train Acc: 26.4286% Val Loss: 2.712524 Val Acc: 31.4000%
Epoch 3 Train Loss: 2.910695 Train Acc: 40.0000% Val Loss: 2.135307 Val Acc: 38.2000%
Epoch 4 Train Loss: 2.795473 Train Acc: 52.8571% Val Loss: 1.837171 Val Acc: 46.2000%
Epoch 5 Train Loss: 2.066922 Train Acc: 60.7143% Val Loss: 1.910931 Val Acc: 46.6000%
Epoch 6 Train Loss: 1.762443 Train Acc: 63.5714% Val Loss: 1.863445 Val Acc: 48.4000%
Epoch 7 Train Loss: 1.493714 Train Acc: 68.5714% Val Loss: 1.749854 Val Acc: 52.8000%
Epoch 8 Train Loss: 1.435782 Train Acc: 72.8571% Val Loss: 1.621837 Val Acc: 59.4000%
Epoch 9 Train Loss: 1.241855 Train Acc: 75.0000% Val Loss: 1.551770 Val Acc: 62.4000%
Epoch 10 Train Loss: 1.570425 Train Acc: 76.4286% Val Loss: 1.507813 Val Acc: 63.6000%
Epoch 11 Train Loss: 1.020015 Train Acc: 79.2857% Val Loss: 1.505572 Val Acc: 65.0000%
Epoch 12 Train Loss: 1.030396 Train Acc: 78.5714% Val Loss: 1.522349 Val Acc: 64.2000%
Epoch 13 Train Loss: 0.959203 Train Acc: 77.8571% Val Loss: 1.550502 Val Acc: 63.2000%
Epoch 14 Train Loss: 0.945190 Train Acc: 77.1429% Val Loss: 1.575250 Val Acc: 62.8000%
Epoch 15 Train Loss: 0.889157 Train Acc: 77.8571% Val Loss: 1.596992 Val Acc: 63.4000%
Epoch 16 Train Loss: 1.066568 Train Acc: 80.0000% Val Loss: 1.596375 Val Acc: 63.2000%
Epoch 17 Train Loss: 1.299356 Train Acc: 85.0000% Val Loss: 1.636672 Val Acc: 64.0000%
Epoch 18 Train Loss: 0.800040 Train Acc: 84.2857% Val Loss: 1.749711 Val Acc: 64.4000%
Epoch 19 Train Loss: 0.738295 Train Acc: 82.1429% Val Loss: 1.902346 Val Acc: 62.8000%
Epoch 20 Train Loss: 0.846399 Train Acc: 80.7143% Val Loss: 2.023381 Val Acc: 62.0000%
Epoch 21 Train Loss: 1.027444 Train Acc: 80.7143% Val Loss: 2.090489 Val Acc: 61.8000%
Epoch 22 Train Loss: 1.034440 Train Acc: 81.4286% Val Loss: 2.086383 Val Acc: 62.0000%
Epoch 23 Train Loss: 0.932883 Train Acc: 84.2857% Val Loss: 2.020211 Val Acc: 62.2000%
Epoch 24 Train Loss: 0.854405 Train Acc: 87.1429% Val Loss: 1.926812 Val Acc: 62.8000%
Epoch 25 Train Loss: 0.823050 Train Acc: 87.8571% Val Loss: 1.839584 Val Acc: 65.6000%
Epoch 26 Train Loss: 0.824092 Train Acc: 87.8571% Val Loss: 1.759949 Val Acc: 65.4000%
Epoch 27 Train Loss: 0.726058 Train Acc: 87.8571% Val Loss: 1.696039 Val Acc: 65.6000%
Epoch 28 Train Loss: 0.691226 Train Acc: 87.8571% Val Loss: 1.646164 Val Acc: 65.8000%
Epoch 29 Train Loss: 0.567212 Train Acc: 87.8571% Val Loss: 1.609809 Val Acc: 67.4000%
Epoch 30 Train Loss: 0.929916 Train Acc: 88.5714% Val Loss: 1.586185 Val Acc: 67.2000%
Epoch 31 Train Loss: 0.443747 Train Acc: 88.5714% Val Loss: 1.573454 Val Acc: 67.2000%
Early Stopping at Epoch 31
Test Loss: 1.459343 Test Acc: 70.7000%
```
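Since `main` exposes the hyperparameters as keyword arguments, variations are easy to try. A hypothetical example (not run here):

```julia
# A deeper model with stronger dropout and a small weight decay (switches the optimiser to AdamW).
main(; hidden_dim=128, nb_layers=3, dropout=0.5, weight_decay=5.0e-4, epochs=300)
```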
## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```
```
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
```
This page was generated using Literate.jl.