Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
- Add ComplexWebQuestions (CWQ) dataset ([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950))
- Add GNAN model and dedicated dataloader plus an example (https://github.com/pyg-team/pytorch_geometric/pull/10371)

### Changed

Expand Down
212 changes: 212 additions & 0 deletions examples/gnan_graph_mutagenicity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""Training GNAN (graph‐level) on the Mutagenicity dataset.

This reproduces, in simplified form, the experiment from the GNAN paper
(Bechler-Speicher *et al.*, 2024) on the Mutagenicity molecule dataset.

The script demonstrates how to:
1. Load the Mutagenicity dataset from TUDataset.
2. Pre-compute all-pairs shortest-path distances **per graph** and the
corresponding normalisation matrices required by GNAN.
3. Train the *TensorGNAN* model for graph classification.

Run with:

python examples/gnan_graph_mutagenicity.py

Graphs in Mutagenicity are small (≈30 nodes), therefore the dense distance
matrix fits comfortably in memory and can be coqmputed on the fly.
"""

from __future__ import annotations

import random
from pathlib import Path

import networkx as nx
import torch
from torch import nn
from tqdm import tqdm

from torch_geometric.datasets import TUDataset
from torch_geometric.loader.gnan_dataloader import GNANDataLoader
from torch_geometric.nn.models import TensorGNAN
from torch_geometric.utils import to_networkx


def compute_dist_and_norm(data) -> tuple[torch.Tensor, torch.Tensor]:
"""Return (distance_matrix, normalisation_matrix) for a PyG Data graph."""
def norm_from_dist(dist: torch.Tensor) -> torch.Tensor:
N = dist.size(0)
norm = torch.zeros_like(dist)
for i in range(N):
row = dist[i]
# Consider only *finite* distances when counting
finite_mask = torch.isfinite(row)
counts = torch.bincount(row[finite_mask].long(
)) if finite_mask.any() else torch.tensor([], dtype=torch.long)

for j in range(N):
if not torch.isfinite(row[j]):
# No path ⇒ normalisation of 1 to avoid division by zero
norm[i, j] = 1.0
else:
d = int(row[j].item())
norm[i, j] = counts[d] if d < len(counts) else 1.0
# Safety: ensure no zeros
norm[norm == 0] = 1.0
return norm

g = to_networkx(data, to_undirected=True)
sp = dict(nx.all_pairs_shortest_path_length(g))

N = data.num_nodes
# Initialise with +inf to mark "no path" entries explicitly
dist = torch.full((N, N), float('inf'), dtype=torch.float)

# Distance from each node to itself is 0 by definition
dist.fill_diagonal_(0.0)

# Fill finite shortest-path lengths returned by NetworkX
for i, lengths in sp.items():
for j, d in lengths.items():
dist[i, j] = float(d)

# Compute the normalisation matrix; unreachable pairs (inf) get count 1
norm = norm_from_dist(dist)
return dist, norm


class PreprocessDistances:
"""PyG Transform that adds GNAN distance attributes to each graph."""
def __call__(self, data): # noqa: D401
dist, norm = compute_dist_and_norm(data)
data.node_distances = dist
data.normalization_matrix = norm
return data


# -----------------------------------------------------------------------------


def seed_everything(seed: int = 42):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# -----------------------------------------------------------------------------


def main():
seed_everything()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
root = Path("data/Mutagenicity")

# Save/load preprocessed dataset to avoid recomputation
dataset = TUDataset(root=str(root), name="Mutagenicity",
transform=PreprocessDistances())

num_classes = dataset.num_classes
in_channels = dataset.num_features

print(f"Dataset info: {num_classes} classes, {in_channels} features")
print(f"Dataset size: {len(dataset)} graphs")

# Simple 80/10/10 split
indices = list(range(len(dataset)))
random.shuffle(indices)
n_train = int(0.8 * len(indices))
n_val = int(0.1 * len(indices))

# print("=" * 60)
train_dataset = dataset[indices[:n_train]]
val_dataset = dataset[indices[n_train:n_train + n_val]]
test_dataset = dataset[indices[n_train + n_val:]]

# standard PyTorch DataLoader with custom collate function
train_loader = GNANDataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = GNANDataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = GNANDataLoader(test_dataset, batch_size=32, shuffle=False)

# Pick a sample graph from the *test* split to track during training.
sample_graph = test_dataset[0]
sample_graph = sample_graph.to(device)

model = TensorGNAN(
in_channels=in_channels,
out_channels=1 if num_classes == 2 else num_classes,
n_layers=3,
hidden_channels=64,
dropout=0.3,
normalize_rho=False,
# Uncomment the following line to group features together:
# feature_groups=[list(range(in_channels))]
# feature_groups=[[0,1,2,3,4,5,6,7,8,9,10,11],[12,13]]
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4,
weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='max', factor=0.5, patience=3)
criterion = nn.BCEWithLogitsLoss()

def evaluate(loader):
model.eval()
correct = 0
for data in tqdm(loader, desc="Evaluating"):
data = data.to(device)
out = model(data)
pred = out.squeeze() > 0
correct += int((pred == data.y.to(device)).sum())
return correct / len(loader.dataset)

best_val_acc = 0.0
for epoch in range(1, 11):
model.train()
total_loss = 0.0
num_batches = 0

for data in tqdm(train_loader, desc="Training"):
data = data.to(device)
optimizer.zero_grad()
out = model(data)
loss = criterion(out.squeeze(-1), data.y.to(device).float())
loss.backward()
optimizer.step()
total_loss += loss.item()
num_batches += 1

avg_loss = total_loss / num_batches
val_acc = evaluate(val_loader)
test_acc = evaluate(test_loader)
best_val_acc = max(best_val_acc, val_acc)

print(f"Epoch {epoch:03d} Loss {avg_loss:.4f} "
f"ValAcc {val_acc:.4f} TestAcc {test_acc:.4f}")

# --------------------------------------------------------------
# Print prediction & node importance for the tracked sample graph
# --------------------------------------------------------------
with torch.no_grad():
sample_pred = model(sample_graph)
sample_imp = model.node_importance(sample_graph)

# Flatten tensors for nicer printing (binary classification → 1 logit)
pred_value = sample_pred.squeeze().item()
node_imp_values = sample_imp.squeeze(-1).cpu().tolist()

print("Sample graph prediction:", f"{pred_value:.4f}")
print("Node importance contributions:")
print(node_imp_values)

# Step the scheduler based on validation accuracy
scheduler.step(val_acc)

print("Best validation accuracy:", best_val_acc)
print("Final test accuracy:", evaluate(test_loader))


if __name__ == "__main__":
main()
94 changes: 94 additions & 0 deletions test/loader/test_gnan_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch

from torch_geometric.data import Batch, Data
from torch_geometric.loader.gnan_dataloader import GNANCollater, GNANDataLoader
from torch_geometric.nn.models import TensorGNAN


def _dummy_data(num_nodes: int = 5, num_feats: int = 4):
x = torch.randn(num_nodes, num_feats)
edge_index = torch.combinations(torch.arange(num_nodes), r=2).t()
# full distance matrix:
rand_dist = torch.rand(num_nodes, num_nodes)
rand_dist = (rand_dist + rand_dist.t()) / 2 # symmetric
# Use a simple normalisation matrix of ones to avoid division issues
norm = torch.ones_like(rand_dist)
data = Data(x=x, edge_index=edge_index)
data.node_distances = rand_dist
data.normalization_matrix = norm
return data


def test_gnan_collater_block_diag_and_restore():
g1 = _dummy_data(num_nodes=3, num_feats=4)
g2 = _dummy_data(num_nodes=4, num_feats=4)
g3 = _dummy_data(num_nodes=2, num_feats=4)

# Keep copies to verify restoration on originals
d1 = g1.node_distances.clone()
n1 = g1.normalization_matrix.clone()
d2 = g2.node_distances.clone()
n2 = g2.normalization_matrix.clone()
d3 = g3.node_distances.clone()
n3 = g3.normalization_matrix.clone()

collate = GNANCollater()
batch = collate([g1, g2, g3])

assert isinstance(batch, Batch)
N = g1.num_nodes + g2.num_nodes + g3.num_nodes
assert batch.node_distances.shape == (N, N)
assert batch.normalization_matrix.shape == (N, N)

expected_dist = torch.block_diag(d1, d2, d3)
expected_norm = torch.block_diag(n1, n2, n3)
assert torch.allclose(batch.node_distances, expected_dist)
assert torch.allclose(batch.normalization_matrix, expected_norm)

# Original Data objects should have attributes restored and unchanged
assert torch.allclose(g1.node_distances, d1)
assert torch.allclose(g1.normalization_matrix, n1)
assert torch.allclose(g2.node_distances, d2)
assert torch.allclose(g2.normalization_matrix, n2)
assert torch.allclose(g3.node_distances, d3)
assert torch.allclose(g3.normalization_matrix, n3)


def test_gnan_dataloader_batch_content():
g1 = _dummy_data(num_nodes=3, num_feats=3)
g2 = _dummy_data(num_nodes=5, num_feats=3)
loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False)
batch = next(iter(loader))

assert isinstance(batch, Batch)
assert batch.x.size(0) == g1.num_nodes + g2.num_nodes
assert hasattr(batch, 'node_distances') and hasattr(
batch, 'normalization_matrix')

expected_dist = torch.block_diag(g1.node_distances, g2.node_distances)
expected_norm = torch.block_diag(g1.normalization_matrix,
g2.normalization_matrix)
assert torch.allclose(batch.node_distances, expected_dist)
assert torch.allclose(batch.normalization_matrix, expected_norm)


def test_gnan_dataloader_with_tensor_gnan():
g1 = _dummy_data(num_nodes=3, num_feats=4)
g2 = _dummy_data(num_nodes=4, num_feats=4)
loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False)
batch = next(iter(loader))

model = TensorGNAN(
in_channels=4,
out_channels=3,
n_layers=2,
hidden_channels=8,
graph_level=True,
)
model.eval()

# Batched forward vs. separate forwards
out_batched = model(batch) # [2, 3]
out_sep = torch.cat([model(g1), model(g2)], dim=0)
assert out_batched.shape == (2, 3)
assert torch.allclose(out_batched, out_sep, atol=1e-5)
Loading
Loading