# Graph Convolutional Networks on Cora
This example is based on the GCN tutorial from MLX. While we build the network manually here, we recommend using GNNLux.jl directly.
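For reference, a single GCN layer (Kipf & Welling, 2017) propagates node features as

$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right),$$

where $\hat{A}$ is the (typically symmetrically normalized) adjacency matrix. This example uses a feature × node layout, so each layer computes `dense(x) * adj`, i.e. $(W H) A$, with the `relu` applied between layers, and the adjacency matrix is used as loaded, without normalization.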
```julia
using Lux,
    Reactant,
    MLDatasets,
    Random,
    Statistics,
    GNNGraphs,
    ConcreteStructs,
    Printf,
    OneHotArrays,
    Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
```

```
(::MLDataDevices.CPUDevice) (generic function with 1 method)
```
## Loading Cora Dataset
```julia
function loadcora()
    data = Cora()
    gph = data.graphs[1]
    gnngraph = GNNGraph(
        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
    )
    return (
        gph.node_data.features,
        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
        # We use a dense matrix here to avoid incompatibility with Reactant
        Matrix{Int32}(adjacency_matrix(gnngraph)),
        # We use this since Reactant doesn't yet support gather adjoint
        (1:140, 141:640, 1709:2708),
    )
end
```

```
loadcora (generic function with 1 method)
```
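Note that `adjacency_matrix` returns the raw adjacency matrix $A$, whereas the original GCN formulation uses the symmetrically normalized $\hat{A} = D^{-1/2}(A + I)D^{-1/2}$. If you want to experiment with that variant, a minimal sketch could look like this (the `normalize_adjacency` helper is hypothetical, not part of this example):

```julia
using LinearAlgebra

# Hypothetical helper: Â = D^(-1/2) (A + I) D^(-1/2), with self-loops added.
# After adding I every degree is ≥ 1, so the inverse square root is safe.
function normalize_adjacency(adj::AbstractMatrix)
    a = Float32.(adj) + I
    dinv = Diagonal(inv.(sqrt.(vec(sum(a; dims=2)))))
    return dinv * a * dinv
end
```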
## Model Definition
```julia
function GCNLayer(args...; kwargs...)
    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
        @return dense(x) * adj
    end
end

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
    gcn_layers = [
        GCNLayer(in_dim => out_dim; kwargs...) for
        (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])
    ]

    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
    dropout = Dropout(dropout)

    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
        for layer in gcn_layers
            x = relu.(layer((x, adj)))
            x = dropout(x)
        end
        @return last_layer((x, adj))[:, mask]
    end
end
```

```
GCN (generic function with 1 method)
```
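`GCN` returns a model that takes a `(features, adjacency, mask)` tuple and only emits predictions for the masked nodes. A quick shape check on random data (all dimensions below are made up for illustration):

```julia
using Lux, Random

rng = Random.MersenneTwister(0)
model = GCN(10, 16, 3; nb_layers=2, dropout=0.5)  # 10 features → 3 classes
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 10, 6)         # feature × node matrix for 6 nodes
adj = Float32.(rand(rng, Bool, 6, 6))  # dummy (unnormalized) adjacency
y, _ = model((x, adj, 1:4), ps, st)    # keep predictions for 4 nodes only
@assert size(y) == (3, 4)
```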
## Helper Functions
```julia
function loss_function(model, ps, st, (x, y, adj, mask))
    y_pred, st = model((x, adj, mask), ps, st)
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
    return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
```

```
accuracy (generic function with 1 method)
```
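As a sanity check, `accuracy` compares the arg-max class of every column of `y_pred` against the one-hot targets. A toy example with made-up scores:

```julia
using OneHotArrays, Statistics

y_true = onehotbatch([1, 2, 3, 2], 1:3)  # 3×4 one-hot matrix
# Columns hold the per-node class scores; the arg-max of column 4 is wrong.
y_pred = Float32[
    0.9  0.1 0.2 0.6
    0.05 0.8 0.1 0.3
    0.05 0.1 0.7 0.1
]
accuracy(y_pred, y_true)  # 75.0: three of the four columns match
```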
## Training the Model
```julia
function main(;
    hidden_dim::Int=64,
    dropout::Float64=0.1,
    nb_layers::Int=2,
    use_bias::Bool=true,
    lr::Float64=0.001,
    weight_decay::Float64=0.0,
    patience::Int=20,
    epochs::Int=200,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    features, targets, adj, (train_idx, val_idx, test_idx) = xdev(loadcora())

    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
    ps, st = xdev(Lux.setup(rng, gcn))
    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

    train_state = Training.TrainState(gcn, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)

    val_loss_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile loss_function(gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
    end

    train_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
    end

    val_model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
    end

    best_loss_val = Inf
    cnt = 0

    for epoch in 1:epochs
        (_, loss, _, train_state) = Lux.Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (features, targets, adj, train_idx),
            train_state;
            return_gradients=Val(false),
        )

        train_acc = accuracy(
            Array(
                train_model_compiled(
                    (features, adj, train_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, train_idx],
        )

        val_loss = first(
            val_loss_compiled(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, val_idx),
            ),
        )
        val_acc = accuracy(
            Array(
                val_model_compiled(
                    (features, adj, val_idx),
                    train_state.parameters,
                    Lux.testmode(train_state.states),
                )[1],
            ),
            Array(targets)[:, val_idx],
        )

        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

        if val_loss < best_loss_val
            best_loss_val = val_loss
            cnt = 0
        else
            cnt += 1
            if cnt == patience
                @printf "Early Stopping at Epoch %d\n" epoch
                break
            end
        end
    end

    Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        test_loss = @jit(
            loss_function(
                gcn,
                train_state.parameters,
                Lux.testmode(train_state.states),
                (features, targets, adj, test_idx),
            )
        )[1]
        test_acc = accuracy(
            Array(
                @jit(
                    gcn(
                        (features, adj, test_idx),
                        train_state.parameters,
                        Lux.testmode(train_state.states),
                    )
                )[1],
            ),
            Array(targets)[:, test_idx],
        )

        @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
    end

    return nothing
end

main()
```
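Note that the forward passes used for the per-epoch metrics are compiled once with `@compile` and reused on every epoch, while the one-off test evaluation uses `@jit`, which compiles and executes in a single step.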
```
Total Trainable Parameters: 0.0964 M
Epoch 1 Train Loss: 16.066574 Train Acc: 22.1429% Val Loss: 6.879176 Val Acc: 24.8000%
Epoch 2 Train Loss: 7.957355 Train Acc: 29.2857% Val Loss: 2.597338 Val Acc: 33.2000%
Epoch 3 Train Loss: 4.583193 Train Acc: 50.0000% Val Loss: 1.744905 Val Acc: 46.6000%
Epoch 4 Train Loss: 1.908341 Train Acc: 56.4286% Val Loss: 1.791154 Val Acc: 46.4000%
Epoch 5 Train Loss: 1.673450 Train Acc: 64.2857% Val Loss: 1.683114 Val Acc: 50.8000%
Epoch 6 Train Loss: 1.478799 Train Acc: 73.5714% Val Loss: 1.502969 Val Acc: 57.6000%
Epoch 7 Train Loss: 1.157417 Train Acc: 75.0000% Val Loss: 1.401800 Val Acc: 62.0000%
Epoch 8 Train Loss: 1.141568 Train Acc: 78.5714% Val Loss: 1.397459 Val Acc: 61.4000%
Epoch 9 Train Loss: 1.708609 Train Acc: 79.2857% Val Loss: 1.401590 Val Acc: 62.2000%
Epoch 10 Train Loss: 1.051392 Train Acc: 82.1429% Val Loss: 1.415236 Val Acc: 64.8000%
Epoch 11 Train Loss: 0.994985 Train Acc: 83.5714% Val Loss: 1.430862 Val Acc: 64.8000%
Epoch 12 Train Loss: 0.982618 Train Acc: 83.5714% Val Loss: 1.442781 Val Acc: 65.8000%
Epoch 13 Train Loss: 0.822561 Train Acc: 85.0000% Val Loss: 1.447394 Val Acc: 66.4000%
Epoch 14 Train Loss: 0.920512 Train Acc: 83.5714% Val Loss: 1.458303 Val Acc: 67.6000%
Epoch 15 Train Loss: 0.812066 Train Acc: 84.2857% Val Loss: 1.465748 Val Acc: 67.8000%
Epoch 16 Train Loss: 0.761003 Train Acc: 84.2857% Val Loss: 1.458907 Val Acc: 67.6000%
Epoch 17 Train Loss: 0.758469 Train Acc: 87.1429% Val Loss: 1.450104 Val Acc: 67.8000%
Epoch 18 Train Loss: 0.774662 Train Acc: 87.1429% Val Loss: 1.424158 Val Acc: 67.8000%
Epoch 19 Train Loss: 0.718235 Train Acc: 87.8571% Val Loss: 1.400826 Val Acc: 67.8000%
Epoch 20 Train Loss: 0.575362 Train Acc: 88.5714% Val Loss: 1.392055 Val Acc: 68.4000%
Epoch 21 Train Loss: 0.553364 Train Acc: 89.2857% Val Loss: 1.394837 Val Acc: 67.6000%
Epoch 22 Train Loss: 0.523534 Train Acc: 89.2857% Val Loss: 1.409594 Val Acc: 68.4000%
Epoch 23 Train Loss: 0.547120 Train Acc: 89.2857% Val Loss: 1.430590 Val Acc: 68.0000%
Epoch 24 Train Loss: 0.593029 Train Acc: 88.5714% Val Loss: 1.458196 Val Acc: 68.6000%
Epoch 25 Train Loss: 0.564101 Train Acc: 90.7143% Val Loss: 1.486715 Val Acc: 68.4000%
Epoch 26 Train Loss: 0.478272 Train Acc: 89.2857% Val Loss: 1.523585 Val Acc: 67.8000%
Epoch 27 Train Loss: 0.456502 Train Acc: 91.4286% Val Loss: 1.561760 Val Acc: 67.4000%
Epoch 28 Train Loss: 0.447176 Train Acc: 91.4286% Val Loss: 1.590273 Val Acc: 67.2000%
Epoch 29 Train Loss: 0.418273 Train Acc: 92.1429% Val Loss: 1.606814 Val Acc: 67.8000%
Epoch 30 Train Loss: 0.502453 Train Acc: 92.8571% Val Loss: 1.615381 Val Acc: 67.8000%
Epoch 31 Train Loss: 0.443167 Train Acc: 93.5714% Val Loss: 1.620447 Val Acc: 67.8000%
Epoch 32 Train Loss: 0.367190 Train Acc: 93.5714% Val Loss: 1.625762 Val Acc: 67.2000%
Epoch 33 Train Loss: 0.507516 Train Acc: 93.5714% Val Loss: 1.665179 Val Acc: 67.2000%
Epoch 34 Train Loss: 0.459586 Train Acc: 93.5714% Val Loss: 1.682184 Val Acc: 67.2000%
Epoch 35 Train Loss: 0.396707 Train Acc: 93.5714% Val Loss: 1.687292 Val Acc: 67.0000%
Epoch 36 Train Loss: 0.330848 Train Acc: 93.5714% Val Loss: 1.685548 Val Acc: 67.2000%
Epoch 37 Train Loss: 0.364265 Train Acc: 94.2857% Val Loss: 1.687667 Val Acc: 67.2000%
Epoch 38 Train Loss: 0.378305 Train Acc: 95.0000% Val Loss: 1.683174 Val Acc: 67.6000%
Epoch 39 Train Loss: 0.454592 Train Acc: 95.0000% Val Loss: 1.724287 Val Acc: 67.4000%
Epoch 40 Train Loss: 0.329054 Train Acc: 95.0000% Val Loss: 1.763165 Val Acc: 67.0000%
Early Stopping at Epoch 40
Test Loss: 1.631037 Test Acc: 68.3000%
```

## Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
```

This page was generated using Literate.jl.