# Graph Convolutional Networks on Cora
This example is based on the GCN MLX tutorial. While we build the graph-convolution layers manually here, we recommend using GNNLux.jl directly for real applications.
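Each graph-convolution layer defined below applies a dense transform to the node-feature matrix and then aggregates over neighbours by multiplying with the adjacency matrix. A minimal sketch of that single step, with made-up toy sizes (not part of the tutorial code):

```julia
# Toy illustration of one graph-convolution step, roughly relu((W * H .+ b) * A).
# All sizes here are made up for illustration only.
W = randn(Float32, 4, 8)         # dense weights: 8 input features -> 4 hidden features
b = zeros(Float32, 4)
H = randn(Float32, 8, 6)         # node features, one column per node (6 nodes)
A = Float32.(rand(Bool, 6, 6))   # raw (unnormalized) toy adjacency matrix
H_next = max.((W * H .+ b) * A, 0.0f0)  # shape (4, 6): new features per node
```

The full example stacks several such layers, interleaved with dropout, and trains them end to end.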
```julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
```
## Loading Cora Dataset
```julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
```
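The tuple returned by `loadcora` contains the node features, one-hot targets, a dense adjacency matrix, and the hard-coded train/validation/test index ranges. For Cora these arrays have the following shapes (a quick sanity check, not part of the tutorial itself):

```julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()

size(features)   # (1433, 2708): 1433 bag-of-words features per node, 2708 nodes
size(targets)    # (7, 2708): one-hot encoding over the 7 Cora classes
size(adj)        # (2708, 2708)
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)
```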
## Model Definition
```julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]

    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
```
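Before training on Cora, it can be helpful to sanity-check the model on random data with small, made-up dimensions; the following smoke test is an addition, not part of the original tutorial:

```julia
# Smoke test with arbitrary toy sizes: 16 input features, 7 classes, 10 nodes.
rng = Random.default_rng()
model = GCN(16, 32, 7; nb_layers=2, dropout=0.5)
ps, st = Lux.setup(rng, model)

x = randn(Float32, 16, 10)          # toy node features
adj = Float32.(rand(Bool, 10, 10))  # toy adjacency matrix
mask = 1:5                          # predict only for the first 5 nodes

y, _ = model((x, adj, mask), ps, st)
size(y)  # (7, 5): class logits for the masked nodes
```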
## Helper Functions
```julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
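```

Here `CrossEntropyLoss(; logits=Val(true))` is applied to raw logits, and `onecold` picks the index of the largest entry in each column, so `accuracy` can be verified on a tiny hand-written example (illustrative only):

```julia
# Illustrative check of the accuracy helper on toy logits and one-hot labels.
y_pred = Float32[2.0 0.1 0.3;
                 0.5 1.9 0.2;
                 0.1 0.4 1.8]        # logits: 3 classes x 3 samples
y = onehotbatch([1, 2, 3], 1:3)      # true labels as one-hot columns
accuracy(y_pred, y)                  # 100.0
```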
## Training the Model
```julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
```

```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1757913805.102953 2638698 service.cc:163] XLA service 0x43d65de0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757913805.103028 2638698 service.cc:171] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
I0000 00:00:1757913805.103929 2638698 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757913805.103970 2638698 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757913805.104005 2638698 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757913805.115381 2638698 cuda_dnn.cc:463] Loaded cuDNN version 91200
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:334
Epoch 1 Train Loss: 16.541618 Train Acc: 22.1429% Val Loss: 7.779559 Val Acc: 25.0000%
Epoch 2 Train Loss: 8.098966 Train Acc: 21.4286% Val Loss: 4.113542 Val Acc: 28.6000%
Epoch 3 Train Loss: 4.958270 Train Acc: 42.1429% Val Loss: 1.929405 Val Acc: 41.4000%
Epoch 4 Train Loss: 1.731582 Train Acc: 53.5714% Val Loss: 2.011568 Val Acc: 41.8000%
Epoch 5 Train Loss: 2.024558 Train Acc: 59.2857% Val Loss: 1.961110 Val Acc: 44.0000%
Epoch 6 Train Loss: 1.851948 Train Acc: 67.1429% Val Loss: 1.736604 Val Acc: 50.2000%
Epoch 7 Train Loss: 1.185628 Train Acc: 73.5714% Val Loss: 1.604445 Val Acc: 56.2000%
Epoch 8 Train Loss: 1.406230 Train Acc: 75.0000% Val Loss: 1.536368 Val Acc: 61.4000%
Epoch 9 Train Loss: 1.216595 Train Acc: 76.4286% Val Loss: 1.512406 Val Acc: 63.0000%
Epoch 10 Train Loss: 1.242570 Train Acc: 77.8571% Val Loss: 1.486743 Val Acc: 62.8000%
Epoch 11 Train Loss: 1.076788 Train Acc: 80.7143% Val Loss: 1.468861 Val Acc: 62.2000%
Epoch 12 Train Loss: 0.981122 Train Acc: 81.4286% Val Loss: 1.454582 Val Acc: 64.0000%
Epoch 13 Train Loss: 0.919534 Train Acc: 81.4286% Val Loss: 1.448876 Val Acc: 66.0000%
Epoch 14 Train Loss: 0.804165 Train Acc: 81.4286% Val Loss: 1.454256 Val Acc: 65.0000%
Epoch 15 Train Loss: 0.942576 Train Acc: 82.8571% Val Loss: 1.465008 Val Acc: 65.6000%
Epoch 16 Train Loss: 0.744699 Train Acc: 85.0000% Val Loss: 1.512742 Val Acc: 65.2000%
Epoch 17 Train Loss: 0.901531 Train Acc: 85.7143% Val Loss: 1.555849 Val Acc: 64.8000%
Epoch 18 Train Loss: 0.682303 Train Acc: 85.7143% Val Loss: 1.594888 Val Acc: 64.8000%
Epoch 19 Train Loss: 0.716223 Train Acc: 86.4286% Val Loss: 1.581796 Val Acc: 65.0000%
Epoch 20 Train Loss: 0.622499 Train Acc: 86.4286% Val Loss: 1.550115 Val Acc: 65.6000%
Epoch 21 Train Loss: 0.668506 Train Acc: 87.1429% Val Loss: 1.517670 Val Acc: 66.4000%
Epoch 22 Train Loss: 0.599883 Train Acc: 87.1429% Val Loss: 1.483986 Val Acc: 67.6000%
Epoch 23 Train Loss: 0.642122 Train Acc: 87.1429% Val Loss: 1.458639 Val Acc: 68.4000%
Epoch 24 Train Loss: 0.516070 Train Acc: 87.1429% Val Loss: 1.436628 Val Acc: 68.8000%
Epoch 25 Train Loss: 0.632938 Train Acc: 86.4286% Val Loss: 1.429646 Val Acc: 68.0000%
Epoch 26 Train Loss: 0.490367 Train Acc: 87.1429% Val Loss: 1.429122 Val Acc: 68.6000%
Epoch 27 Train Loss: 0.557422 Train Acc: 87.8571% Val Loss: 1.438560 Val Acc: 67.2000%
Epoch 28 Train Loss: 0.489264 Train Acc: 87.8571% Val Loss: 1.455575 Val Acc: 66.8000%
Epoch 29 Train Loss: 0.464357 Train Acc: 87.8571% Val Loss: 1.475377 Val Acc: 66.6000%
Epoch 30 Train Loss: 0.522491 Train Acc: 86.4286% Val Loss: 1.503930 Val Acc: 65.8000%
Epoch 31 Train Loss: 0.449556 Train Acc: 86.4286% Val Loss: 1.538563 Val Acc: 65.4000%
Epoch 32 Train Loss: 0.504742 Train Acc: 87.8571% Val Loss: 1.551327 Val Acc: 65.2000%
Epoch 33 Train Loss: 0.574246 Train Acc: 92.1429% Val Loss: 1.607988 Val Acc: 65.8000%
Epoch 34 Train Loss: 0.449091 Train Acc: 92.1429% Val Loss: 1.711923 Val Acc: 65.4000%
Epoch 35 Train Loss: 0.363941 Train Acc: 90.7143% Val Loss: 1.833322 Val Acc: 66.0000%
Epoch 36 Train Loss: 0.400379 Train Acc: 90.7143% Val Loss: 1.938784 Val Acc: 64.6000%
Epoch 37 Train Loss: 0.529520 Train Acc: 90.0000% Val Loss: 1.995430 Val Acc: 65.0000%
Epoch 38 Train Loss: 0.600717 Train Acc: 90.0000% Val Loss: 1.997332 Val Acc: 65.6000%
Epoch 39 Train Loss: 0.599057 Train Acc: 92.1429% Val Loss: 1.945721 Val Acc: 65.8000%
Epoch 40 Train Loss: 0.509541 Train Acc: 92.1429% Val Loss: 1.866676 Val Acc: 66.4000%
Epoch 41 Train Loss: 0.447534 Train Acc: 92.8571% Val Loss: 1.780149 Val Acc: 67.4000%
Epoch 42 Train Loss: 0.347354 Train Acc: 94.2857% Val Loss: 1.694155 Val Acc: 68.4000%
Epoch 43 Train Loss: 0.399075 Train Acc: 95.0000% Val Loss: 1.628220 Val Acc: 67.4000%
Epoch 44 Train Loss: 0.368840 Train Acc: 95.0000% Val Loss: 1.579233 Val Acc: 67.2000%
Epoch 45 Train Loss: 0.365899 Train Acc: 95.0000% Val Loss: 1.555665 Val Acc: 67.0000%
Epoch 46 Train Loss: 0.308438 Train Acc: 95.0000% Val Loss: 1.548830 Val Acc: 67.8000%
Early Stopping at Epoch 46
Test Loss: 1.341541 Test Acc: 70.5000%
```

## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
```

This page was generated using Literate.jl.