Graph Convolutional Networks on Cora
This example is based on the GCN MLX tutorial. While we build everything manually here, we recommend using GNNLux.jl directly instead.
julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Loading Cora Dataset
julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use integer index ranges here since Reactant doesn't yet support the gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
loadcora (generic function with 1 method)
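The two Reactant-specific choices are spelled out in the comments above: a dense `Matrix{Int32}` adjacency rather than a sparse one, and plain index ranges rather than boolean masks for the train/validation/test split. As a rough sanity check, the returned shapes should look like this (a sketch assuming the standard Cora split of 2708 nodes, 1433 bag-of-words features, and 7 classes; not part of the original example):

julia
features, targets, adj, (train_idx, val_idx, test_idx) = loadcora()

size(features)                           # (1433, 2708) -- one column of features per node
size(targets)                            # (7, 2708)    -- one-hot class labels
size(adj)                                # (2708, 2708) -- dense Int32 adjacency matrix
length.((train_idx, val_idx, test_idx))  # (140, 500, 1000)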
Model Definition
julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]

    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
GCN (generic function with 1 method)
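Each `GCNLayer` computes `dense(x) * adj`: a shared affine transform of every node's feature vector, followed by aggregation over neighbours through the (un-normalized) adjacency matrix. `GCN` stacks `nb_layers` such layers with `relu` activations and dropout in between, then a final layer whose output is restricted to the columns selected by `mask`. A minimal CPU-only sketch of how the model is called, with made-up sizes (illustrative, not part of the original example):

julia
rng = Random.default_rng()

model = GCN(8, 16, 3; nb_layers=2, dropout=0.5)  # 8 input features, 16 hidden units, 3 classes
ps, st = Lux.setup(rng, model)

x    = rand(Float32, 8, 10)              # features × nodes
adj  = Matrix{Int32}(rand(0:1, 10, 10))  # dense adjacency for 10 nodes
mask = 1:4                               # nodes we want predictions for

y, _ = model((x, adj, mask), ps, st)     # 3×4 matrix of logits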
Helper Functions
julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
accuracy (generic function with 1 method)
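`loss_function` runs the model on the masked nodes and evaluates the softmax cross-entropy against the corresponding targets (the model returns raw logits, hence `logits=Val(true)`), while `accuracy` compares the column-wise arg-max of the predictions with the one-hot labels. A tiny illustrative check of the `accuracy` helper (made-up values, not part of the original example):

julia
y_pred = Float32[2.0 0.5; 0.1 1.5]  # logits: 2 classes (rows) × 2 samples (columns)
y = onehotbatch([1, 2], 1:2)        # ground-truth labels as a one-hot matrix

accuracy(y_pred, y)                 # 100.0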
Training the Model
julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]

        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end
main()
Precompiling EnzymeBFloat16sExt...
7151.9 ms ✓ Enzyme → EnzymeBFloat16sExt
1 dependency successfully precompiled in 7 seconds. 47 already precompiled.
2025-07-09 04:32:38.704370: I external/xla/xla/service/service.cc:153] XLA service 0x2ccedef0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-09 04:32:38.704478: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752035558.705421 1227177 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752035558.705563 1227177 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752035558.705641 1227177 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752035558.718254 1227177 cuda_dnn.cc:471] Loaded cuDNN version 90800
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Total Trainable Parameters: 0.0964 M
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Epoch 1 Train Loss: 14.544163 Train Acc: 20.0000% Val Loss: 6.816375 Val Acc: 24.8000%
Epoch 2 Train Loss: 8.176273 Train Acc: 26.4286% Val Loss: 2.794208 Val Acc: 28.6000%
Epoch 3 Train Loss: 3.054783 Train Acc: 38.5714% Val Loss: 2.363533 Val Acc: 35.8000%
Epoch 4 Train Loss: 2.530659 Train Acc: 57.1429% Val Loss: 1.836085 Val Acc: 42.8000%
Epoch 5 Train Loss: 1.810864 Train Acc: 64.2857% Val Loss: 1.648743 Val Acc: 49.0000%
Epoch 6 Train Loss: 1.389204 Train Acc: 67.8571% Val Loss: 1.555167 Val Acc: 56.6000%
Epoch 7 Train Loss: 1.305906 Train Acc: 69.2857% Val Loss: 1.499859 Val Acc: 59.8000%
Epoch 8 Train Loss: 1.251253 Train Acc: 70.7143% Val Loss: 1.492746 Val Acc: 60.0000%
Epoch 9 Train Loss: 1.804210 Train Acc: 76.4286% Val Loss: 1.428393 Val Acc: 63.4000%
Epoch 10 Train Loss: 1.025554 Train Acc: 77.8571% Val Loss: 1.423040 Val Acc: 63.4000%
Epoch 11 Train Loss: 2.099176 Train Acc: 77.8571% Val Loss: 1.487065 Val Acc: 63.8000%
Epoch 12 Train Loss: 0.933550 Train Acc: 80.0000% Val Loss: 1.584491 Val Acc: 62.0000%
Epoch 13 Train Loss: 0.949730 Train Acc: 80.0000% Val Loss: 1.666309 Val Acc: 61.0000%
Epoch 14 Train Loss: 1.165631 Train Acc: 80.7143% Val Loss: 1.731187 Val Acc: 61.2000%
Epoch 15 Train Loss: 0.927006 Train Acc: 80.7143% Val Loss: 1.757825 Val Acc: 61.4000%
Epoch 16 Train Loss: 0.953845 Train Acc: 82.8571% Val Loss: 1.756999 Val Acc: 62.2000%
Epoch 17 Train Loss: 0.983249 Train Acc: 82.8571% Val Loss: 1.739863 Val Acc: 63.4000%
Epoch 18 Train Loss: 0.905488 Train Acc: 85.7143% Val Loss: 1.707660 Val Acc: 64.2000%
Epoch 19 Train Loss: 0.984189 Train Acc: 85.7143% Val Loss: 1.669651 Val Acc: 65.0000%
Epoch 20 Train Loss: 0.875392 Train Acc: 82.1429% Val Loss: 1.634181 Val Acc: 66.8000%
Epoch 21 Train Loss: 0.798192 Train Acc: 83.5714% Val Loss: 1.595893 Val Acc: 67.0000%
Epoch 22 Train Loss: 0.641376 Train Acc: 84.2857% Val Loss: 1.561626 Val Acc: 67.0000%
Epoch 23 Train Loss: 0.560968 Train Acc: 85.0000% Val Loss: 1.542006 Val Acc: 66.4000%
Epoch 24 Train Loss: 0.679021 Train Acc: 87.1429% Val Loss: 1.536416 Val Acc: 66.0000%
Epoch 25 Train Loss: 0.524915 Train Acc: 87.8571% Val Loss: 1.547432 Val Acc: 66.4000%
Epoch 26 Train Loss: 0.753763 Train Acc: 87.8571% Val Loss: 1.569046 Val Acc: 65.4000%
Epoch 27 Train Loss: 0.476359 Train Acc: 87.8571% Val Loss: 1.598529 Val Acc: 65.0000%
Epoch 28 Train Loss: 0.478846 Train Acc: 87.8571% Val Loss: 1.621054 Val Acc: 64.4000%
Epoch 29 Train Loss: 0.527370 Train Acc: 90.0000% Val Loss: 1.658599 Val Acc: 65.4000%
Epoch 30 Train Loss: 0.469310 Train Acc: 90.7143% Val Loss: 1.683416 Val Acc: 66.2000%
Early Stopping at Epoch 30
Test Loss: 1.474484 Test Acc: 68.8000%
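The run above uses the default hyperparameters. Since `main` is entirely keyword-driven, variations are one-liners; note that any nonzero `weight_decay` switches the optimiser from `Adam` to `AdamW`. For example (an untested sketch with arbitrarily chosen values):

julia
main(; hidden_dim=128, dropout=0.2, weight_decay=5e-4, patience=50, epochs=300)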
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.