# Graph Convolutional Networks on Cora

This example is based on the GCN MLX tutorial. While we build the model manually here, we recommend directly using GNNLux.jl.
```julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
```
## Loading Cora Dataset
```julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
```
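As an optional sanity check (assuming the Cora dataset has already been downloaded via MLDatasets), the shapes returned by `loadcora` should look roughly like the following; the sizes in the comments are the standard Cora statistics and may differ slightly across dataset releases.

```julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()

size(features)  # (1433, 2708) -- bag-of-words feature vector per node
size(targets)   # (7, 2708)    -- one-hot encoded class labels
size(adj)       # (2708, 2708) -- dense adjacency matrix
length(train_idx), length(val_idx), length(test_idx)  # (140, 500, 1000)
```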
## Model Definition
```julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
```
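As a quick smoke test of the definitions above, we can instantiate a small model on random data and check the output shape. This is only a sketch: the dimensions and the random adjacency matrix are placeholders, not the Cora quantities used later.

```julia
using Lux, Random

rng = Random.default_rng()
model = GCN(16, 8, 3; nb_layers=2, dropout=0.5)   # 16 input features, 8 hidden, 3 classes
ps, st = Lux.setup(rng, model)

x = rand(Float32, 16, 10)                 # 10 nodes with 16 features each
adj = Matrix{Int32}(rand(Bool, 10, 10))   # placeholder adjacency matrix
mask = 1:4                                # predict only on the first 4 nodes

y, _ = model((x, adj, mask), ps, Lux.testmode(st))
size(y)  # (3, 4): class logits for the masked nodes
```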
## Helper Functions
```julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
```
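For illustration, `accuracy` simply compares the arg-max class of the predictions against the arg-max of the one-hot targets. A toy example with made-up values:

```julia
y_pred = Float32[0.9 0.1; 0.1 0.9]  # logits for 2 samples, 2 classes (columns are samples)
y      = Float32[1.0 1.0; 0.0 0.0]  # one-hot targets: both samples belong to class 1

accuracy(y_pred, y)  # 50.0 -- only the first prediction matches its target
```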
## Training the Model
```julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
```
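Running `main()` on a GPU-backed Reactant device produces output along the following lines; the exact numbers and log messages will vary with hardware and library versions.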
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1759714263.771802 1329998 service.cc:158] XLA service 0x4aa89970 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1759714263.772026 1329998 service.cc:166]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1759714263.774609 1329998 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1759714263.774723 1329998 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1759714263.775002 1329998 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1759714263.795361 1329998 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch 1 Train Loss: 16.121910 Train Acc: 22.1429% Val Loss: 6.888684 Val Acc: 25.4000%
Epoch 2 Train Loss: 9.096497 Train Acc: 25.7143% Val Loss: 2.945881 Val Acc: 30.4000%
Epoch 3 Train Loss: 3.630679 Train Acc: 43.5714% Val Loss: 1.899566 Val Acc: 44.0000%
Epoch 4 Train Loss: 2.384638 Train Acc: 57.8571% Val Loss: 1.668179 Val Acc: 47.8000%
Epoch 5 Train Loss: 1.543190 Train Acc: 62.8571% Val Loss: 1.793227 Val Acc: 46.8000%
Epoch 6 Train Loss: 1.546407 Train Acc: 68.5714% Val Loss: 1.688540 Val Acc: 51.4000%
Epoch 7 Train Loss: 1.457959 Train Acc: 75.0000% Val Loss: 1.544119 Val Acc: 58.4000%
Epoch 8 Train Loss: 1.166498 Train Acc: 75.7143% Val Loss: 1.479228 Val Acc: 62.4000%
Epoch 9 Train Loss: 1.066184 Train Acc: 79.2857% Val Loss: 1.436001 Val Acc: 64.4000%
Epoch 10 Train Loss: 1.243302 Train Acc: 77.8571% Val Loss: 1.426260 Val Acc: 64.6000%
Epoch 11 Train Loss: 1.082464 Train Acc: 79.2857% Val Loss: 1.426940 Val Acc: 65.4000%
Epoch 12 Train Loss: 1.000281 Train Acc: 80.0000% Val Loss: 1.435458 Val Acc: 66.8000%
Epoch 13 Train Loss: 1.358913 Train Acc: 78.5714% Val Loss: 1.467632 Val Acc: 65.8000%
Epoch 14 Train Loss: 0.729061 Train Acc: 82.1429% Val Loss: 1.528814 Val Acc: 64.6000%
Epoch 15 Train Loss: 0.828481 Train Acc: 81.4286% Val Loss: 1.617502 Val Acc: 63.2000%
Epoch 16 Train Loss: 0.845278 Train Acc: 80.7143% Val Loss: 1.706484 Val Acc: 62.2000%
Epoch 17 Train Loss: 0.812649 Train Acc: 81.4286% Val Loss: 1.754158 Val Acc: 62.2000%
Epoch 18 Train Loss: 1.037319 Train Acc: 82.8571% Val Loss: 1.702474 Val Acc: 63.4000%
Epoch 19 Train Loss: 0.788434 Train Acc: 84.2857% Val Loss: 1.641294 Val Acc: 64.8000%
Epoch 20 Train Loss: 0.692212 Train Acc: 86.4286% Val Loss: 1.602820 Val Acc: 65.2000%
Epoch 21 Train Loss: 0.623656 Train Acc: 86.4286% Val Loss: 1.591236 Val Acc: 66.4000%
Epoch 22 Train Loss: 0.664444 Train Acc: 85.7143% Val Loss: 1.612672 Val Acc: 65.6000%
Epoch 23 Train Loss: 0.595455 Train Acc: 84.2857% Val Loss: 1.626493 Val Acc: 66.0000%
Epoch 24 Train Loss: 0.655217 Train Acc: 85.7143% Val Loss: 1.620098 Val Acc: 66.4000%
Epoch 25 Train Loss: 0.576780 Train Acc: 85.7143% Val Loss: 1.594717 Val Acc: 67.0000%
Epoch 26 Train Loss: 0.775038 Train Acc: 88.5714% Val Loss: 1.557361 Val Acc: 67.4000%
Epoch 27 Train Loss: 0.573833 Train Acc: 87.8571% Val Loss: 1.526408 Val Acc: 68.4000%
Epoch 28 Train Loss: 0.492148 Train Acc: 88.5714% Val Loss: 1.504853 Val Acc: 68.2000%
Epoch 29 Train Loss: 0.607460 Train Acc: 89.2857% Val Loss: 1.495968 Val Acc: 68.2000%
Epoch 30 Train Loss: 0.487236 Train Acc: 89.2857% Val Loss: 1.498273 Val Acc: 68.2000%
Early Stopping at Epoch 30
Test Loss: 1.335748	Test Acc: 69.0000%
```
## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```
```
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
```
This page was generated using Literate.jl.