# Graph Convolutional Networks on Cora
This example is based on the GCN tutorial from MLX. While we build everything manually here, we recommend using GNNLux.jl directly for real applications.
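For reference, a single GCN layer (Kipf & Welling, 2017) computes

```math
H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right), \qquad
\hat{A} = \tilde{D}^{-1/2} (A + I) \tilde{D}^{-1/2},
```

where ``A`` is the adjacency matrix and ``\tilde{D}`` is the degree matrix of ``A + I``. To keep this example simple, the layers below aggregate with the raw adjacency matrix instead of the normalized ``\hat{A}``, and since features are stored as a `features × nodes` matrix the multiplication order is reversed relative to the row-major formula above.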
```julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
```

## Loading Cora Dataset
```julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use fixed train/val/test index ranges since Reactant doesn't yet
        # support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
```
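As a quick sanity check, Cora has 2708 nodes, 1433 bag-of-words features, and 7 classes, so the returned arrays should have the following shapes (a hypothetical snippet, run on CPU):

```julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()
@assert size(features) == (1433, 2708)   # features × nodes
@assert size(targets) == (7, 2708)       # one-hot classes × nodes
@assert size(adj) == (2708, 2708)        # dense adjacency matrix
```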
## Model Definition

```julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]
    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
```
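Before wiring up training, it can help to smoke-test the model on random data. A minimal sketch (the dimensions here are made up, not Cora's):

```julia
rng = Random.default_rng()
model = GCN(16, 8, 3; nb_layers=2)        # 16 input features → 3 classes
ps, st = Lux.setup(rng, model)

x = randn(Float32, 16, 10)                # feature matrix: 16 features × 10 nodes
adj = Float32.(rand(rng, Bool, 10, 10))   # dense random "adjacency" matrix
y, _ = model((x, adj, 1:4), ps, st)       # logits for the first 4 nodes
@assert size(y) == (3, 4)
```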
## Helper Functions

```julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
```
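To make the accuracy convention concrete: both predictions and targets hold one class per column, and `onecold` inverts the one-hot encoding. An illustrative example with made-up values:

```julia
y = onehotbatch([1, 2, 3, 2], 1:3)   # 3×4 one-hot targets
y_pred = Float32[
    0.9  0.1  0.2  0.5               # logits; column-wise argmax picks the class
    0.05 0.8  0.2  0.3
    0.05 0.1  0.6  0.2
]
accuracy(y_pred, y)                  # 75.0 — the last column is misclassified
```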
## Training the Model

```julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    # Compile the validation loss and the train/val forward passes once; the
    # compiled functions are reused every epoch.
    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\tVal Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        # Early stopping on the validation loss.
        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )
        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
```
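Note the pattern used throughout `main`: each inference path is compiled exactly once with `@compile` against concrete arguments, and the compiled callable is then reused every epoch with arguments of the same shapes and types. A self-contained sketch of the same pattern with a toy `Dense` model (hypothetical, not part of the tutorial pipeline):

```julia
toy_model = Dense(4 => 2)
toy_ps, toy_st = xdev(Lux.setup(Random.default_rng(), toy_model))
x_toy = xdev(randn(Float32, 4, 8))

# Compile once...
toy_fwd = @compile toy_model(x_toy, toy_ps, Lux.testmode(toy_st))
# ...then call the compiled executable as many times as needed.
y_toy, _ = toy_fwd(x_toy, toy_ps, Lux.testmode(toy_st))
```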
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1761832949.954871 1835273 service.cc:158] XLA service 0x47501910 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1761832949.954919 1835273 service.cc:166] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1761832949.956858 1835273 se_gpu_pjrt_client.cc:770] Using BFC allocator.
I0000 00:00:1761832949.956998 1835273 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1761832949.957069 1835273 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1761832949.979179 1835273 cuda_dnn.cc:463] Loaded cuDNN version 91400
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-6/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 15.485048 Train Acc: 22.1429% Val Loss: 7.573927 Val Acc: 25.8000%
Epoch 2 Train Loss: 10.128677 Train Acc: 22.1429% Val Loss: 3.801739 Val Acc: 29.4000%
Epoch 3 Train Loss: 4.470746 Train Acc: 37.8571% Val Loss: 2.432641 Val Acc: 32.0000%
Epoch 4 Train Loss: 2.426969 Train Acc: 51.4286% Val Loss: 2.113595 Val Acc: 37.8000%
Epoch 5 Train Loss: 1.762580 Train Acc: 59.2857% Val Loss: 1.886148 Val Acc: 45.0000%
Epoch 6 Train Loss: 1.483075 Train Acc: 67.8571% Val Loss: 1.608979 Val Acc: 51.8000%
Epoch 7 Train Loss: 1.267287 Train Acc: 71.4286% Val Loss: 1.503945 Val Acc: 58.8000%
Epoch 8 Train Loss: 1.319186 Train Acc: 72.1429% Val Loss: 1.505008 Val Acc: 60.2000%
Epoch 9 Train Loss: 1.627862 Train Acc: 72.8571% Val Loss: 1.520744 Val Acc: 61.4000%
Epoch 10 Train Loss: 1.249889 Train Acc: 74.2857% Val Loss: 1.519149 Val Acc: 62.0000%
Epoch 11 Train Loss: 1.186345 Train Acc: 78.5714% Val Loss: 1.504691 Val Acc: 62.4000%
Epoch 12 Train Loss: 1.178998 Train Acc: 78.5714% Val Loss: 1.548285 Val Acc: 61.4000%
Epoch 13 Train Loss: 0.900217 Train Acc: 79.2857% Val Loss: 1.609031 Val Acc: 62.2000%
Epoch 14 Train Loss: 0.947796 Train Acc: 80.0000% Val Loss: 1.651340 Val Acc: 62.0000%
Epoch 15 Train Loss: 1.408850 Train Acc: 80.7143% Val Loss: 1.636881 Val Acc: 64.0000%
Epoch 16 Train Loss: 0.877241 Train Acc: 82.1429% Val Loss: 1.619912 Val Acc: 66.2000%
Epoch 17 Train Loss: 0.810141 Train Acc: 81.4286% Val Loss: 1.595838 Val Acc: 66.6000%
Epoch 18 Train Loss: 0.763295 Train Acc: 80.7143% Val Loss: 1.572466 Val Acc: 67.8000%
Epoch 19 Train Loss: 0.878369 Train Acc: 82.1429% Val Loss: 1.545329 Val Acc: 67.2000%
Epoch 20 Train Loss: 0.748767 Train Acc: 82.8571% Val Loss: 1.521424 Val Acc: 66.6000%
Epoch 21 Train Loss: 0.683971 Train Acc: 82.8571% Val Loss: 1.503487 Val Acc: 66.6000%
Epoch 22 Train Loss: 0.610947 Train Acc: 85.0000% Val Loss: 1.498878 Val Acc: 66.0000%
Epoch 23 Train Loss: 0.603293 Train Acc: 85.0000% Val Loss: 1.508867 Val Acc: 66.0000%
Epoch 24 Train Loss: 1.566445 Train Acc: 85.0000% Val Loss: 1.545763 Val Acc: 66.2000%
Epoch 25 Train Loss: 0.565211 Train Acc: 87.1429% Val Loss: 1.609411 Val Acc: 65.0000%
Epoch 26 Train Loss: 0.524949 Train Acc: 87.1429% Val Loss: 1.686930 Val Acc: 64.2000%
Epoch 27 Train Loss: 0.506573 Train Acc: 88.5714% Val Loss: 1.779495 Val Acc: 64.2000%
Epoch 28 Train Loss: 0.619784 Train Acc: 87.8571% Val Loss: 1.844049 Val Acc: 63.2000%
Epoch 29 Train Loss: 0.578553 Train Acc: 87.8571% Val Loss: 1.864435 Val Acc: 63.4000%
Epoch 30 Train Loss: 0.487108 Train Acc: 88.5714% Val Loss: 1.866593 Val Acc: 64.0000%
Epoch 31 Train Loss: 0.490343 Train Acc: 89.2857% Val Loss: 1.841034 Val Acc: 64.4000%
Epoch 32 Train Loss: 0.561964 Train Acc: 90.0000% Val Loss: 1.792846 Val Acc: 66.2000%
Epoch 33 Train Loss: 0.489364 Train Acc: 91.4286% Val Loss: 1.735448 Val Acc: 66.6000%
Epoch 34 Train Loss: 0.614714 Train Acc: 91.4286% Val Loss: 1.695194 Val Acc: 66.4000%
Epoch 35 Train Loss: 0.439960 Train Acc: 92.8571% Val Loss: 1.663030 Val Acc: 66.6000%
Epoch 36 Train Loss: 0.414342 Train Acc: 92.8571% Val Loss: 1.645411 Val Acc: 67.4000%
Epoch 37 Train Loss: 0.395993 Train Acc: 93.5714% Val Loss: 1.639418 Val Acc: 68.2000%
Epoch 38 Train Loss: 0.370129 Train Acc: 93.5714% Val Loss: 1.643817 Val Acc: 68.2000%
Epoch 39 Train Loss: 0.405889 Train Acc: 93.5714% Val Loss: 1.657985 Val Acc: 67.8000%
Epoch 40 Train Loss: 0.799954 Train Acc: 95.7143% Val Loss: 1.680702 Val Acc: 67.6000%
Epoch 41 Train Loss: 0.378267 Train Acc: 95.7143% Val Loss: 1.712448 Val Acc: 67.8000%
Epoch 42 Train Loss: 0.367016 Train Acc: 95.0000% Val Loss: 1.747182 Val Acc: 68.4000%
Early Stopping at Epoch 42
Test Loss: 1.522911 Test Acc: 68.6000%
```
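Early stopping triggered at epoch 42, and the model reaches about 68.6% test accuracy. Since every hyperparameter of `main` is a keyword argument, variations are easy to try, e.g. (a hypothetical invocation):

```julia
main(; hidden_dim=128, dropout=0.5, weight_decay=5.0e-4, epochs=300)
```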
## Appendix

```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*