# Graph Convolutional Networks on Cora
This example is based on the GCN tutorial from MLX. While we build the graph convolutional network manually here, we recommend using GNNLux.jl directly for most use cases.
```julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
```
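The two devices are used throughout the script: `xdev` moves data onto the Reactant (XLA) backend (and, because of `force=true`, errors instead of silently falling back if Reactant is not functional), while `cdev` copies results back to the host. A minimal usage sketch, assuming a working Reactant installation:

```julia
# Hypothetical sketch: move an array to the accelerator and back.
x = rand(Float32, 4, 4)
x_dev = xdev(x)        # array living on the Reactant device
x_host = cdev(x_dev)   # back to a plain Array on the CPU
x_host ≈ x             # true
```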
## Loading Cora Dataset
```julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibilities with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use fixed train/val/test index ranges since Reactant doesn't yet
        # support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
```
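As a quick sanity check, the tuple returned by `loadcora` has the following shapes (Cora has 2708 nodes, 1433 bag-of-words features, and 7 classes). The snippet below is a hypothetical inspection on the CPU, assuming the dataset download has been accepted:

```julia
# Hypothetical shape check, based on the standard Cora statistics.
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
size(features)                                        # (1433, 2708)
size(targets)                                         # (7, 2708)
size(adj)                                             # (2708, 2708)
length(train_idx), length(val_idx), length(test_idx)  # (140, 500, 1000)
```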
## Model Definition
```julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        # A single graph convolution: transform the node features, then
        # aggregate them over neighbors via the (dense) adjacency matrix.
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        # Only return predictions for the nodes selected by `mask`
        @return last_layer((x, adj))[:, mask]
    end
end
```
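Before wiring up training, it can help to run the model once on dummy data to confirm that the shapes flow through as expected. The snippet below is a small smoke test with made-up sizes, not part of the original example:

```julia
# Hypothetical smoke test with random inputs (sizes chosen arbitrarily).
rng = Random.default_rng()
model = GCN(16, 8, 3; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)

x = rand(Float32, 16, 10)                # 16 features, 10 nodes
adj = Float32.(rand(rng, Bool, 10, 10))  # dense 0/1 adjacency matrix
mask = 1:4                               # predict on the first 4 nodes

y, _ = model((x, adj, mask), ps, st)
size(y)  # (3, 4): out_dim × number of masked nodes
```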
## Helper Functions
```julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
```
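`onecold` is the inverse of `onehotbatch`: it maps each column back to the index of its largest entry, so `accuracy` compares predicted and true class labels column by column. A tiny illustrative example with made-up values:

```julia
# Hypothetical example: two samples (columns), two classes (rows).
y_pred = [0.9 0.2; 0.1 0.8]
y_true = onehotbatch([1, 1], 1:2)

onecold(y_pred)          # [1, 2]
onecold(y_true)          # [1, 1]
accuracy(y_pred, y_true) # 50.0
```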
## Training the Model
```julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    # Compile the evaluation passes once; they are reused unchanged every epoch.
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )
        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\tVal Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        # Early stopping on the validation loss.
        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    # Final evaluation on the held-out test split.
    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )
        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
```
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1760386091.841336 2849446 service.cc:158] XLA service 0x17bece50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760386091.841483 2849446 service.cc:166] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
I0000 00:00:1760386091.842629 2849446 se_gpu_pjrt_client.cc:1339] Using BFC allocator.
I0000 00:00:1760386091.842725 2849446 gpu_helpers.cc:136] XLA backend allocating 12526534656 bytes on device 0 for BFCAllocator.
I0000 00:00:1760386091.842819 2849446 gpu_helpers.cc:177] XLA backend will use up to 4175511552 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1760386091.853805 2849446 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 13.836034 Train Acc: 20.7143% Val Loss: 7.001715 Val Acc: 26.4000%
Epoch 2 Train Loss: 8.745947 Train Acc: 25.0000% Val Loss: 4.235206 Val Acc: 24.4000%
Epoch 3 Train Loss: 3.893347 Train Acc: 41.4286% Val Loss: 2.219004 Val Acc: 35.2000%
Epoch 4 Train Loss: 2.026159 Train Acc: 52.8571% Val Loss: 2.200711 Val Acc: 37.8000%
Epoch 5 Train Loss: 1.997150 Train Acc: 59.2857% Val Loss: 2.033007 Val Acc: 43.2000%
Epoch 6 Train Loss: 1.647290 Train Acc: 68.5714% Val Loss: 1.790704 Val Acc: 50.4000%
Epoch 7 Train Loss: 1.207621 Train Acc: 74.2857% Val Loss: 1.608890 Val Acc: 55.4000%
Epoch 8 Train Loss: 1.194252 Train Acc: 77.1429% Val Loss: 1.534446 Val Acc: 59.6000%
Epoch 9 Train Loss: 1.011792 Train Acc: 79.2857% Val Loss: 1.521947 Val Acc: 60.8000%
Epoch 10 Train Loss: 1.098195 Train Acc: 75.7143% Val Loss: 1.592601 Val Acc: 59.4000%
Epoch 11 Train Loss: 1.391503 Train Acc: 80.7143% Val Loss: 1.580464 Val Acc: 60.8000%
Epoch 12 Train Loss: 1.397247 Train Acc: 82.8571% Val Loss: 1.622350 Val Acc: 61.8000%
Epoch 13 Train Loss: 1.048575 Train Acc: 80.0000% Val Loss: 1.683591 Val Acc: 61.8000%
Epoch 14 Train Loss: 0.903289 Train Acc: 79.2857% Val Loss: 1.743905 Val Acc: 62.0000%
Epoch 15 Train Loss: 0.946101 Train Acc: 82.8571% Val Loss: 1.792193 Val Acc: 61.2000%
Epoch 16 Train Loss: 0.895650 Train Acc: 82.1429% Val Loss: 1.837030 Val Acc: 62.0000%
Epoch 17 Train Loss: 1.044762 Train Acc: 83.5714% Val Loss: 1.843937 Val Acc: 62.6000%
Epoch 18 Train Loss: 0.993514 Train Acc: 83.5714% Val Loss: 1.811226 Val Acc: 63.2000%
Epoch 19 Train Loss: 0.756631 Train Acc: 85.0000% Val Loss: 1.768390 Val Acc: 64.0000%
Epoch 20 Train Loss: 0.781535 Train Acc: 85.0000% Val Loss: 1.713567 Val Acc: 65.4000%
Epoch 21 Train Loss: 0.572586 Train Acc: 86.4286% Val Loss: 1.671879 Val Acc: 66.4000%
Epoch 22 Train Loss: 0.701187 Train Acc: 87.1429% Val Loss: 1.630667 Val Acc: 67.0000%
Epoch 23 Train Loss: 0.525414 Train Acc: 88.5714% Val Loss: 1.599651 Val Acc: 67.4000%
Epoch 24 Train Loss: 0.714332 Train Acc: 87.8571% Val Loss: 1.588971 Val Acc: 67.0000%
Epoch 25 Train Loss: 0.615730 Train Acc: 87.1429% Val Loss: 1.580288 Val Acc: 67.6000%
Epoch 26 Train Loss: 0.506644 Train Acc: 87.1429% Val Loss: 1.576501 Val Acc: 67.4000%
Epoch 27 Train Loss: 0.527762 Train Acc: 87.1429% Val Loss: 1.574614 Val Acc: 67.0000%
Epoch 28 Train Loss: 0.713766 Train Acc: 87.8571% Val Loss: 1.592753 Val Acc: 68.2000%
Epoch 29 Train Loss: 0.521682 Train Acc: 90.0000% Val Loss: 1.606794 Val Acc: 68.2000%
Early Stopping at Epoch 29
Test Loss: 1.512579	Test Acc: 68.8000%
```
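All hyperparameters are exposed as keyword arguments of `main`, so variations can be explored without editing the script. For example (the values below are arbitrary, not tuned settings from this example):

```julia
# Hypothetical runs with different hyperparameters.
main(; hidden_dim=128, nb_layers=3, dropout=0.5)
main(; lr=5e-3, weight_decay=1e-4, epochs=300, patience=30)
```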
## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```
```
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
```
This page was generated using Literate.jl.