From ae1f4a8ae704cd2dfbb71716bb9ddf5f5ee35b94 Mon Sep 17 00:00:00 2001
From: yoavk
Date: Sat, 19 Jul 2025 17:48:59 +0000
Subject: [PATCH 01/13] Add GNAN model with tests, dataloader and mutagenicity example

---
 examples/gnan_graph_mutagenicity.py       | 206 ++++++++++++++++++++++
 test/nn/models/test_gnan.py               |  24 +++
 torch_geometric/loader/gnan_dataloader.py | 129 ++++++++++++++
 torch_geometric/nn/models/__init__.py     |   2 +
 torch_geometric/nn/models/gnan.py         | 183 +++++++++++++++++++
 5 files changed, 544 insertions(+)
 create mode 100644 examples/gnan_graph_mutagenicity.py
 create mode 100644 test/nn/models/test_gnan.py
 create mode 100644 torch_geometric/loader/gnan_dataloader.py
 create mode 100644 torch_geometric/nn/models/gnan.py

diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py
new file mode 100644
index 000000000000..15ed3f4aed33
--- /dev/null
+++ b/examples/gnan_graph_mutagenicity.py
@@ -0,0 +1,206 @@
+"""Training GNAN (graph‐level) on the Mutagenicity dataset.
+
+This reproduces, in simplified form, the experiment from the GNAN paper
+(Bechler-Speicher *et al.*, 2024) on the Mutagenicity molecule dataset.
+
+The script demonstrates how to:
+1. Load the Mutagenicity dataset from TUDataset.
+2. Pre-compute all-pairs shortest-path distances **per graph** and the
+   corresponding normalisation matrices required by GNAN.
+3. Train the *TensorGNAN* model for graph classification.
+
+Run with:
+
+    python examples/gnan_graph_mutagenicity.py
+
+Graphs in Mutagenicity are small (≈30 nodes), therefore the dense distance
+matrix fits comfortably in memory and can be computed on the fly.
+"""
+
+from __future__ import annotations
+
+import random
+from pathlib import Path
+from typing import Tuple
+
+import networkx as nx
+import torch
+from torch import nn
+from torch_geometric.datasets import TUDataset
+from torch_geometric.nn.models import TensorGNAN
+from torch_geometric.loader.gnan_dataloader import GNANDataLoader
+from torch_geometric.utils import to_networkx
+from tqdm import tqdm
+import pickle
+
+
+def compute_dist_and_norm(data) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Return (distance_matrix, normalisation_matrix) for a PyG Data graph."""
+    def norm_from_dist(dist: torch.Tensor) -> torch.Tensor:
+        N = dist.size(0)
+        norm = torch.zeros_like(dist)
+        for i in range(N):
+            row = dist[i]
+            # Consider only *finite* distances when counting
+            finite_mask = torch.isfinite(row)
+            counts = torch.bincount(row[finite_mask].long(
+            )) if finite_mask.any() else torch.tensor([], dtype=torch.long)
+
+            for j in range(N):
+                if not torch.isfinite(row[j]):
+                    # No path ⇒ normalisation of 1 to avoid division by zero
+                    norm[i, j] = 1.0
+                else:
+                    d = int(row[j].item())
+                    norm[i, j] = counts[d] if d < len(counts) else 1.0
+        # Safety: ensure no zeros
+        norm[norm == 0] = 1.0
+        return norm
+
+    g = to_networkx(data, to_undirected=True)
+    sp = dict(nx.all_pairs_shortest_path_length(g))
+
+    N = data.num_nodes
+    # Initialise with +inf to mark "no path" entries explicitly
+    dist = torch.full((N, N), float('inf'), dtype=torch.float)
+
+    # Distance from each node to itself is 0 by definition
+    dist.fill_diagonal_(0.0)
+
+    # Fill finite shortest-path lengths returned by NetworkX
+    for i, lengths in sp.items():
+        for j, d in lengths.items():
+            dist[i, j] = float(d)
+
+    # Compute the normalisation matrix; unreachable pairs (inf) get count 1
+    norm = norm_from_dist(dist)
+    return dist, norm
+
+
+class PreprocessDistances:
+    """PyG Transform that adds GNAN distance attributes to each graph."""
+    def 
__call__(self, data): # noqa: D401 + dist, norm = compute_dist_and_norm(data) + data.node_distances = dist + data.normalization_matrix = norm + return data + + +# ----------------------------------------------------------------------------- + + +def seed_everything(seed: int = 42): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +# ----------------------------------------------------------------------------- + + +def main(): + seed_everything() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + root = Path("data/Mutagenicity") + + # Save/load preprocessed dataset to avoid recomputation + import os + processed_path = root / "mutagenicity_gnan_preprocessed.pt" + if processed_path.exists(): + print(f"Loading preprocessed dataset from {processed_path}") + try: + dataset = torch.load(processed_path) + except (pickle.UnpicklingError, RuntimeError): + print("Could not load preprocessed file, re-creating...") + os.remove(processed_path) + dataset = TUDataset(root=str(root), name="Mutagenicity", + transform=PreprocessDistances()) + torch.save(dataset, processed_path) + else: + print("Preprocessing dataset and saving to disk...") + dataset = TUDataset(root=str(root), name="Mutagenicity", + transform=PreprocessDistances()) + torch.save(dataset, processed_path) + + num_classes = dataset.num_classes + in_channels = dataset.num_features + + print(f"Dataset info: {num_classes} classes, {in_channels} features") + print(f"Dataset size: {len(dataset)} graphs") + + # Simple 80/10/10 split + indices = list(range(len(dataset))) + random.shuffle(indices) + n_train = int(0.8 * len(indices)) + n_val = int(0.1 * len(indices)) + + # print("=" * 60) + train_dataset = dataset[indices[:n_train]] + val_dataset = dataset[indices[n_train:n_train + n_val]] + test_dataset = dataset[indices[n_train + n_val:]] + + # standard PyTorch DataLoader with custom collate function + train_loader = GNANDataLoader(train_dataset, batch_size=1, shuffle=True) + val_loader = GNANDataLoader(val_dataset, batch_size=32, shuffle=False) + test_loader = GNANDataLoader(test_dataset, batch_size=32, shuffle=False) + + model = TensorGNAN( + in_channels=in_channels, + out_channels=1 if num_classes == 2 else num_classes, + n_layers=3, + hidden_channels=64, + dropout=0.3, + normalize_rho=False, + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, + weight_decay=5e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode='max', factor=0.5, patience=3) + criterion = nn.BCEWithLogitsLoss() + + def evaluate(loader): + model.eval() + correct = 0 + for data in tqdm(loader, desc="Evaluating"): + data = data.to(device) + out = model(data) + pred = out.squeeze() > 0 + correct += int((pred == data.y.to(device)).sum()) + return correct / len(loader.dataset) + + best_val_acc = 0.0 + for epoch in range(1, 11): + model.train() + total_loss = 0.0 + num_batches = 0 + + for data in tqdm(train_loader, desc="Training"): + data = data.to(device) + optimizer.zero_grad() + out = model(data) + loss = criterion(out.squeeze(-1), data.y.to(device).float()) + loss.backward() + optimizer.step() + total_loss += loss.item() + num_batches += 1 + + avg_loss = total_loss / num_batches + val_acc = evaluate(val_loader) + test_acc = evaluate(test_loader) + best_val_acc = max(best_val_acc, val_acc) + + print(f"Epoch {epoch:03d} Loss {avg_loss:.4f} " + f"ValAcc {val_acc:.4f} TestAcc {test_acc:.4f}") + + # Step the scheduler based on validation 
accuracy + scheduler.step(val_acc) + + print("Best validation accuracy:", best_val_acc) + print("Final test accuracy:", evaluate(test_loader)) + + +if __name__ == "__main__": + main() diff --git a/test/nn/models/test_gnan.py b/test/nn/models/test_gnan.py new file mode 100644 index 000000000000..79086ae9fdbc --- /dev/null +++ b/test/nn/models/test_gnan.py @@ -0,0 +1,24 @@ +import torch +from torch_geometric.data import Data +from torch_geometric.nn.models import TensorGNAN + + +def _dummy_data(num_nodes: int = 5, num_feats: int = 4): + x = torch.randn(num_nodes, num_feats) + edge_index = torch.combinations(torch.arange(num_nodes), r=2).t() + # full distance matrix: + rand_dist = torch.rand(num_nodes, num_nodes) + rand_dist = (rand_dist + rand_dist.t()) / 2 # symmetric + norm = rand_dist.sum(dim=1) + data = Data(x=x, edge_index=edge_index) + data.node_distances = rand_dist + data.normalization_matrix = norm + return data + + +def test_tensor_gnan_graph_level(): + data = _dummy_data() + model = TensorGNAN(in_channels=data.num_features, out_channels=2, + n_layers=2, hidden_channels=8) + out = model(data) # [1, 2] + assert out.shape == (1, 2) diff --git a/torch_geometric/loader/gnan_dataloader.py b/torch_geometric/loader/gnan_dataloader.py new file mode 100644 index 000000000000..8fb1e5637d0a --- /dev/null +++ b/torch_geometric/loader/gnan_dataloader.py @@ -0,0 +1,129 @@ +from collections.abc import Sequence +from typing import List, Optional, Union + +import torch +from torch.utils.data import DataLoader as PyTorchDataLoader + +from torch_geometric.data import Batch, Dataset +from torch_geometric.data.data import BaseData +from torch_geometric.data.datapipes import DatasetAdapter + + +def _create_block_diagonal_matrix( + matrix_list: List[torch.Tensor]) -> torch.Tensor: + r"""Create a block diagonal matrix from a list of matrices.""" + if not matrix_list: + return torch.empty(0, 0) + + total_size = sum(matrix.size(0) for matrix in matrix_list) + result = torch.zeros( + total_size, + total_size, + dtype=matrix_list[0].dtype, + device=matrix_list[0].device, + ) + + offset = 0 + for matrix in matrix_list: + size = matrix.size(0) + result[offset:offset + size, offset:offset + size] = matrix + offset += size + + return result + + +class GNANCollater: + def __init__( + self, + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + ): + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + def __call__(self, data_list: List[BaseData]) -> Batch: + node_distances_list = [] + normalization_matrix_list = [] + + has_node_distances = hasattr(data_list[0], 'node_distances') + has_normalization_matrix = hasattr(data_list[0], + 'normalization_matrix') + + if has_node_distances: + for data in data_list: + node_distances_list.append(data.node_distances) + delattr(data, 'node_distances') + + if has_normalization_matrix: + for data in data_list: + normalization_matrix_list.append(data.normalization_matrix) + delattr(data, 'normalization_matrix') + + batch = Batch.from_data_list( + data_list, + follow_batch=self.follow_batch, + exclude_keys=self.exclude_keys, + ) + + for i, data in enumerate(data_list): + if has_node_distances: + setattr(data, 'node_distances', node_distances_list[i]) + if has_normalization_matrix: + setattr(data, 'normalization_matrix', + normalization_matrix_list[i]) + + if node_distances_list: + batch.node_distances = _create_block_diagonal_matrix( + node_distances_list) + if normalization_matrix_list: + batch.normalization_matrix = 
_create_block_diagonal_matrix( + normalization_matrix_list) + + return batch + + +class GNANDataLoader(PyTorchDataLoader): + r"""A data loader which merges data objects from a + :class:`torch_geometric.data.Dataset` to a mini-batch, specifically for + use with the :class:`torch_geometric.nn.models.TensorGNAN` model. + + This loader will batch the :obj:`node_distances` and + :obj:`normalization_matrix` attributes of + :class:`~torch_geometric.data.Data` objects by creating large block- + diagonal matrices. + + For this to work, every data object in the dataset needs to have the + attributes :obj:`node_distances` and :obj:`normalization_matrix`. + + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (List[str], optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`None`) + exclude_keys (List[str], optional): Will exclude each key in the + list. (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`torch.utils.data.DataLoader`. + """ + def __init__( + self, + dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], + batch_size: int = 1, + shuffle: bool = False, + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + **kwargs, + ): + # Remove for PyTorch Lightning: + kwargs.pop('collate_fn', None) + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=GNANCollater(follow_batch, exclude_keys), + **kwargs, + ) diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 269fed1da780..2a9679a01f47 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -40,6 +40,7 @@ from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) from .attract_repel import ARLinkPredictor +from .gnan import TensorGNAN __all__ = classes = [ 'MLP', @@ -93,4 +94,5 @@ 'SGFormer', 'Polynormer', 'ARLinkPredictor', + 'TensorGNAN', ] diff --git a/torch_geometric/nn/models/gnan.py b/torch_geometric/nn/models/gnan.py new file mode 100644 index 000000000000..e28a3ea1e9bc --- /dev/null +++ b/torch_geometric/nn/models/gnan.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import torch +from torch import nn +from torch_geometric.typing import OptTensor +from torch_geometric.data import Data, Batch +from torch_geometric.utils import scatter + +__all__ = [ + 'TensorGNAN', +] + + +def _init_weights(module: nn.Module, std: float = 1.0): + """Utility that mimics Xavier initialisation with configurable std.""" + for name, param in module.named_parameters(): + if 'weight' in name: + nn.init.xavier_normal_(param, gain=std) + elif 'bias' in name: + nn.init.constant_(param, 0.0) + + +class _PerFeatureMLP(nn.Module): + """Simple MLP that is applied to a single scalar feature. + + Args: + out_channels (int): Output dimension per feature ("f" in the paper). + n_layers (int): Number of layers. If ``1``, the MLP is a single Linear. + hidden_channels (int, optional): Hidden dimension. Required when + ``n_layers > 1``. + bias (bool, optional): Use bias terms. (default: ``True``) + dropout (float, optional): Dropout probability after hidden layers. 
+ (default: ``0.0``) + """ + def __init__( + self, + out_channels: int, + n_layers: int, + hidden_channels: int | None = None, + *, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + + if n_layers == 1: + self.net = nn.Linear(1, out_channels, bias=bias) + else: + assert hidden_channels is not None + layers: list[nn.Module] = [ + nn.Linear(1, hidden_channels, bias=bias), + nn.ReLU(), + nn.Dropout(dropout), + ] + for _ in range(1, n_layers - 1): + layers += [ + nn.Linear(hidden_channels, hidden_channels, bias=bias), + nn.ReLU(), + nn.Dropout(dropout), + ] + layers.append(nn.Linear(hidden_channels, out_channels, bias=bias)) + self.net = nn.Sequential(*layers) + + _init_weights(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [N] + return self.net(x.view(-1, 1)) # [N, out_channels] + + +class _RhoMLP(nn.Module): + """MLP that turns a scalar distance into a scalar or vector weight.""" + def __init__( + self, + out_channels: int, + n_layers: int, + hidden_channels: int | None = None, + *, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + + if n_layers == 1: + self.net = nn.Linear(1, out_channels, bias=bias) + else: + assert hidden_channels is not None + layers: list[nn.Module] = [ + nn.Linear(1, hidden_channels, bias=bias), + nn.ReLU() + ] + for _ in range(1, n_layers - 1): + layers += [ + nn.Linear(hidden_channels, hidden_channels, bias=bias), + nn.ReLU(), + ] + layers.append(nn.Linear(hidden_channels, out_channels, bias=bias)) + self.net = nn.Sequential(*layers) + + _init_weights(self) + + def forward(self, d: torch.Tensor) -> torch.Tensor: # [...] + return self.net(d.view(-1, 1)) + + +class TensorGNAN(nn.Module): + r"""Dense, tensorised GNAN variant. + + By default it aggregates node scores to produce *graph‐level* predictions + (shape ``[batch_size, out_channels]``). Set ``graph_level=False`` to + obtain *node‐level* predictions instead, in which case the forward returns + a tensor of shape ``[num_nodes, out_channels]`` or ``[len(node_ids), + out_channels]`` if ``node_ids`` is provided. 
+ """ + def __init__( + self, + in_channels: int, + out_channels: int, + n_layers: int, + *, + hidden_channels: int | None = None, + bias: bool = True, + dropout: float = 0.0, + normalize_rho: bool = True, + graph_level: bool = True, + ) -> None: + super().__init__() + + self.normalize_rho = normalize_rho + self.graph_level = graph_level + self.out_channels = out_channels + + self.fs = nn.ModuleList([ + _PerFeatureMLP(out_channels, n_layers, hidden_channels, bias=bias, + dropout=dropout) for _ in range(in_channels) + ]) + self.rho = _RhoMLP(out_channels, n_layers, hidden_channels, bias=True) + + # -------------------------------------------------------------- + def forward(self, data: Data | Batch, + node_ids: OptTensor = None) -> torch.Tensor: + x: torch.Tensor = data.x # type: ignore # [N, F] + dist: torch.Tensor = data.node_distances # type: ignore # [N, N] + norm: torch.Tensor = data.normalization_matrix # type: ignore # [N, N] + + # f_k(x_k) + fx = torch.stack([mlp(x[:, k]) for k, mlp in enumerate(self.fs)], + dim=1) # [N, F, C] + + # Compute ρ on the inverted distances as suggested in the paper + inv_dist = 1.0 / (1.0 + dist) # [N, N] + rho = self.rho(inv_dist.flatten().view(-1, 1)) # [(N*N), C] + rho = rho.view(x.size(0), x.size(0), self.out_channels) # [N, N, C] + + if self.normalize_rho: + norm[norm == 0] = 1.0 + rho = rho / norm.unsqueeze(-1) # broadcast + + # Apply a mask to ρ to prevent information leakage between + # graphs in the same batch. + if hasattr(data, 'batch') and data.batch is not None: + batch_i = data.batch.view(-1, 1) # type: ignore + batch_j = data.batch.view(1, -1) # type: ignore + mask = (batch_i == batch_j).unsqueeze(-1) + rho = rho * mask + + # Perform Σ_i Σ_j ρ(d_ij) Σ_k f_k(x_jk) + f_sum = fx.sum(dim=1) # [N, C] + out = torch.einsum('ijc,jc->ic', rho, f_sum) # [N, C] + + if self.graph_level: + # # Use batch information for proper graph-level aggregation + batch = data.batch # type: ignore + if batch is not None: + graph_out = scatter(out, batch, dim=0, reduce='add') + else: + # Single graph case + graph_out = out.sum(dim=0, keepdim=True) # [1, C] + return graph_out + + # --- node-level mode ------------------------------------------------- + if node_ids is not None: + return out[node_ids] + return out From 6d245856f6f25990b3203ea568d84c42f2752702 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 19 Jul 2025 17:52:24 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/gnan_graph_mutagenicity.py | 10 +++++----- test/nn/models/test_gnan.py | 1 + torch_geometric/loader/gnan_dataloader.py | 5 ++--- torch_geometric/nn/models/gnan.py | 3 ++- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py index 15ed3f4aed33..c3de1f79a777 100644 --- a/examples/gnan_graph_mutagenicity.py +++ b/examples/gnan_graph_mutagenicity.py @@ -19,22 +19,22 @@ from __future__ import annotations +import pickle import random from pathlib import Path -from typing import Tuple import networkx as nx import torch from torch import nn +from tqdm import tqdm + from torch_geometric.datasets import TUDataset -from torch_geometric.nn.models import TensorGNAN from torch_geometric.loader.gnan_dataloader import GNANDataLoader +from torch_geometric.nn.models import TensorGNAN from torch_geometric.utils import to_networkx -from tqdm import tqdm -import pickle -def 
compute_dist_and_norm(data) -> Tuple[torch.Tensor, torch.Tensor]: +def compute_dist_and_norm(data) -> tuple[torch.Tensor, torch.Tensor]: """Return (distance_matrix, normalisation_matrix) for a PyG Data graph.""" def norm_from_dist(dist: torch.Tensor) -> torch.Tensor: N = dist.size(0) diff --git a/test/nn/models/test_gnan.py b/test/nn/models/test_gnan.py index 79086ae9fdbc..6fbd17d824bf 100644 --- a/test/nn/models/test_gnan.py +++ b/test/nn/models/test_gnan.py @@ -1,4 +1,5 @@ import torch + from torch_geometric.data import Data from torch_geometric.nn.models import TensorGNAN diff --git a/torch_geometric/loader/gnan_dataloader.py b/torch_geometric/loader/gnan_dataloader.py index 8fb1e5637d0a..7352896555e4 100644 --- a/torch_geometric/loader/gnan_dataloader.py +++ b/torch_geometric/loader/gnan_dataloader.py @@ -67,10 +67,9 @@ def __call__(self, data_list: List[BaseData]) -> Batch: for i, data in enumerate(data_list): if has_node_distances: - setattr(data, 'node_distances', node_distances_list[i]) + data.node_distances = node_distances_list[i] if has_normalization_matrix: - setattr(data, 'normalization_matrix', - normalization_matrix_list[i]) + data.normalization_matrix = normalization_matrix_list[i] if node_distances_list: batch.node_distances = _create_block_diagonal_matrix( diff --git a/torch_geometric/nn/models/gnan.py b/torch_geometric/nn/models/gnan.py index e28a3ea1e9bc..66d7585515eb 100644 --- a/torch_geometric/nn/models/gnan.py +++ b/torch_geometric/nn/models/gnan.py @@ -2,8 +2,9 @@ import torch from torch import nn + +from torch_geometric.data import Batch, Data from torch_geometric.typing import OptTensor -from torch_geometric.data import Data, Batch from torch_geometric.utils import scatter __all__ = [ From 393b6aba744b7bf22db0aa3404898b7d57334df6 Mon Sep 17 00:00:00 2001 From: yoavk Date: Mon, 4 Aug 2025 15:26:55 +0000 Subject: [PATCH 03/13] add feature groups --- examples/gnan_graph_mutagenicity.py | 2 + torch_geometric/nn/models/gnan.py | 123 ++++++++++++++++++++++++++-- 2 files changed, 118 insertions(+), 7 deletions(-) diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py index c3de1f79a777..b996b75a946f 100644 --- a/examples/gnan_graph_mutagenicity.py +++ b/examples/gnan_graph_mutagenicity.py @@ -153,6 +153,8 @@ def main(): hidden_channels=64, dropout=0.3, normalize_rho=False, + # Uncomment the following line to group features together: + # feature_groups=[list(range(in_channels))] ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, diff --git a/torch_geometric/nn/models/gnan.py b/torch_geometric/nn/models/gnan.py index 66d7585515eb..fbac1436a390 100644 --- a/torch_geometric/nn/models/gnan.py +++ b/torch_geometric/nn/models/gnan.py @@ -68,6 +68,55 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [N] return self.net(x.view(-1, 1)) # [N, out_channels] +class _MultiFeatureMLP(nn.Module): + """MLP that processes multiple features together. + + Args: + in_channels (int): Number of input features to process together. + out_channels (int): Output dimension per feature group. + n_layers (int): Number of layers. If ``1``, the MLP is a single Linear. + hidden_channels (int, optional): Hidden dimension. Required when + ``n_layers > 1``. + bias (bool, optional): Use bias terms. (default: ``True``) + dropout (float, optional): Dropout probability after hidden layers. 
+ (default: ``0.0``) + """ + def __init__( + self, + in_channels: int, + out_channels: int, + n_layers: int, + hidden_channels: int | None = None, + *, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + + if n_layers == 1: + self.net = nn.Linear(in_channels, out_channels, bias=bias) + else: + assert hidden_channels is not None + layers: list[nn.Module] = [ + nn.Linear(in_channels, hidden_channels, bias=bias), + nn.ReLU(), + nn.Dropout(dropout), + ] + for _ in range(1, n_layers - 1): + layers += [ + nn.Linear(hidden_channels, hidden_channels, bias=bias), + nn.ReLU(), + nn.Dropout(dropout), + ] + layers.append(nn.Linear(hidden_channels, out_channels, bias=bias)) + self.net = nn.Sequential(*layers) + + _init_weights(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [N, in_channels] + return self.net(x) # [N, out_channels] + + class _RhoMLP(nn.Module): """MLP that turns a scalar distance into a scalar or vector weight.""" def __init__( @@ -111,6 +160,22 @@ class TensorGNAN(nn.Module): obtain *node‐level* predictions instead, in which case the forward returns a tensor of shape ``[num_nodes, out_channels]`` or ``[len(node_ids), out_channels]`` if ``node_ids`` is provided. + + Args: + in_channels (int): Number of input node features. + out_channels (int): Output dimension. + n_layers (int): Number of layers in the MLPs. + hidden_channels (int, optional): Hidden dimension in the MLPs. + bias (bool, optional): Use bias terms. (default: ``True``) + dropout (float, optional): Dropout probability. (default: ``0.0``) + normalize_rho (bool, optional): Whether to normalize rho weights. + (default: ``True``) + graph_level (bool, optional): Whether to produce graph-level predictions. + (default: ``True``) + feature_groups (List[List[int]], optional): Groups of feature indices to + process together. Each group will be processed by a single MLP that + takes multiple features as input. If None, each feature is processed + by its own MLP (default behavior). 
(default: ``None``) """ def __init__( self, @@ -123,17 +188,50 @@ def __init__( dropout: float = 0.0, normalize_rho: bool = True, graph_level: bool = True, + feature_groups: list[list[int]] | None = None, ) -> None: super().__init__() self.normalize_rho = normalize_rho self.graph_level = graph_level self.out_channels = out_channels + self.in_channels = in_channels + + # Set up feature groups - default is each feature in its own group + if feature_groups is None: + self.feature_groups = [[i] for i in range(in_channels)] + else: + self.feature_groups = feature_groups + # Validate feature groups + all_features = set() + for group in feature_groups: + if not group: + raise ValueError("Feature groups cannot be empty") + for feat_idx in group: + if feat_idx < 0 or feat_idx >= in_channels: + raise ValueError(f"Feature index {feat_idx} out of range [0, {in_channels})") + if feat_idx in all_features: + raise ValueError(f"Feature index {feat_idx} appears in multiple groups") + all_features.add(feat_idx) + + if len(all_features) != in_channels: + missing = set(range(in_channels)) - all_features + raise ValueError(f"Missing feature indices in groups: {missing}") + + # Create MLPs for each feature group + self.fs = nn.ModuleList() + for group in self.feature_groups: + group_size = len(group) + if group_size == 1: + # Single feature - use original MLP + mlp = _PerFeatureMLP(out_channels, n_layers, hidden_channels, + bias=bias, dropout=dropout) + else: + # Multiple features - use new multi-feature MLP + mlp = _MultiFeatureMLP(group_size, out_channels, n_layers, + hidden_channels, bias=bias, dropout=dropout) + self.fs.append(mlp) - self.fs = nn.ModuleList([ - _PerFeatureMLP(out_channels, n_layers, hidden_channels, bias=bias, - dropout=dropout) for _ in range(in_channels) - ]) self.rho = _RhoMLP(out_channels, n_layers, hidden_channels, bias=True) # -------------------------------------------------------------- @@ -143,9 +241,20 @@ def forward(self, data: Data | Batch, dist: torch.Tensor = data.node_distances # type: ignore # [N, N] norm: torch.Tensor = data.normalization_matrix # type: ignore # [N, N] - # f_k(x_k) - fx = torch.stack([mlp(x[:, k]) for k, mlp in enumerate(self.fs)], - dim=1) # [N, F, C] + # Process features according to groups + fx_list = [] + for group, mlp in zip(self.feature_groups, self.fs): + if len(group) == 1: + # Single feature - extract and process + feat_tensor = x[:, group[0]] # [N] + group_output = mlp(feat_tensor) # [N, C] + else: + # Multiple features - extract and process together + feat_tensor = x[:, group] # [N, group_size] + group_output = mlp(feat_tensor) # [N, C] + fx_list.append(group_output) + + fx = torch.stack(fx_list, dim=1) # [N, num_groups, C] # Compute ρ on the inverted distances as suggested in the paper inv_dist = 1.0 / (1.0 + dist) # [N, N] From edba05fad088320792a364310da8ce90bde203b8 Mon Sep 17 00:00:00 2001 From: yoavk Date: Tue, 5 Aug 2025 09:03:43 +0000 Subject: [PATCH 04/13] add another example to grouping features --- examples/gnan_graph_mutagenicity.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py index b996b75a946f..80821ad38d0a 100644 --- a/examples/gnan_graph_mutagenicity.py +++ b/examples/gnan_graph_mutagenicity.py @@ -155,6 +155,7 @@ def main(): normalize_rho=False, # Uncomment the following line to group features together: # feature_groups=[list(range(in_channels))] + # feature_groups=[[0,1,2,3,4,5,6,7,8,9,10,11],[12,13]] ).to(device) optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-4, From ab1924be574e61512fb8474d77cb321a2c6ef371 Mon Sep 17 00:00:00 2001 From: yoavk Date: Tue, 5 Aug 2025 09:04:50 +0000 Subject: [PATCH 05/13] add node importance to gnan model --- examples/gnan_graph_mutagenicity.py | 19 ++++++++ torch_geometric/nn/models/gnan.py | 73 +++++++++++++++++------------ 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py index 80821ad38d0a..8971e6cebfcd 100644 --- a/examples/gnan_graph_mutagenicity.py +++ b/examples/gnan_graph_mutagenicity.py @@ -146,6 +146,10 @@ def main(): val_loader = GNANDataLoader(val_dataset, batch_size=32, shuffle=False) test_loader = GNANDataLoader(test_dataset, batch_size=32, shuffle=False) + # Pick a sample graph from the *test* split to track during training. + sample_graph = test_dataset[0] + sample_graph = sample_graph.to(device) + model = TensorGNAN( in_channels=in_channels, out_channels=1 if num_classes == 2 else num_classes, @@ -198,6 +202,21 @@ def evaluate(loader): print(f"Epoch {epoch:03d} Loss {avg_loss:.4f} " f"ValAcc {val_acc:.4f} TestAcc {test_acc:.4f}") + # -------------------------------------------------------------- + # Print prediction & node importance for the tracked sample graph + # -------------------------------------------------------------- + with torch.no_grad(): + sample_pred = model(sample_graph) + sample_imp = model.node_importance(sample_graph) + + # Flatten tensors for nicer printing (binary classification → 1 logit) + pred_value = sample_pred.squeeze().item() + node_imp_values = sample_imp.squeeze(-1).cpu().tolist() + + print("Sample graph prediction:", f"{pred_value:.4f}") + print("Node importance contributions:") + print(node_imp_values) + # Step the scheduler based on validation accuracy scheduler.step(val_acc) diff --git a/torch_geometric/nn/models/gnan.py b/torch_geometric/nn/models/gnan.py index fbac1436a390..d465c05376a1 100644 --- a/torch_geometric/nn/models/gnan.py +++ b/torch_geometric/nn/models/gnan.py @@ -234,60 +234,75 @@ def __init__( self.rho = _RhoMLP(out_channels, n_layers, hidden_channels, bias=True) - # -------------------------------------------------------------- - def forward(self, data: Data | Batch, - node_ids: OptTensor = None) -> torch.Tensor: - x: torch.Tensor = data.x # type: ignore # [N, F] - dist: torch.Tensor = data.node_distances # type: ignore # [N, N] - norm: torch.Tensor = data.normalization_matrix # type: ignore # [N, N] - - # Process features according to groups - fx_list = [] + def _process_feature_groups(self, x: torch.Tensor): + """Process features according to groups and return fx and f_sum.""" + fx_list: list[torch.Tensor] = [] for group, mlp in zip(self.feature_groups, self.fs): if len(group) == 1: - # Single feature - extract and process feat_tensor = x[:, group[0]] # [N] - group_output = mlp(feat_tensor) # [N, C] else: - # Multiple features - extract and process together - feat_tensor = x[:, group] # [N, group_size] - group_output = mlp(feat_tensor) # [N, C] - fx_list.append(group_output) - - fx = torch.stack(fx_list, dim=1) # [N, num_groups, C] - - # Compute ρ on the inverted distances as suggested in the paper + feat_tensor = x[:, group] # [N, |group|] + fx_list.append(mlp(feat_tensor)) # [N, C] + fx = torch.stack(fx_list, dim=1) # [N, num_groups, C] + f_sum = fx.sum(dim=1) # [N, C] + return fx, f_sum + + def _compute_rho(self, dist: torch.Tensor, norm: torch.Tensor, data) -> torch.Tensor: + """Compute rho tensor, including 
normalization and masking."""
+        x = data.x  # type: ignore
         inv_dist = 1.0 / (1.0 + dist)  # [N, N]
         rho = self.rho(inv_dist.flatten().view(-1, 1))  # [(N*N), C]
         rho = rho.view(x.size(0), x.size(0), self.out_channels)  # [N, N, C]
-
         if self.normalize_rho:
-            norm[norm == 0] = 1.0
-            rho = rho / norm.unsqueeze(-1)  # broadcast
-
-        # Apply a mask to ρ to prevent information leakage between
-        # graphs in the same batch.
+            norm_safe = norm.clone()
+            norm_safe[norm_safe == 0] = 1.0
+            rho = rho / norm_safe.unsqueeze(-1)  # broadcast division
         if hasattr(data, 'batch') and data.batch is not None:
             batch_i = data.batch.view(-1, 1)  # type: ignore
             batch_j = data.batch.view(1, -1)  # type: ignore
             mask = (batch_i == batch_j).unsqueeze(-1)
             rho = rho * mask
+        return rho
+
+    def forward(self, data: Data | Batch,
+                node_ids: OptTensor = None) -> torch.Tensor:
+        x: torch.Tensor = data.x  # type: ignore # [N, F]
+        dist: torch.Tensor = data.node_distances  # type: ignore # [N, N]
+        norm: torch.Tensor = data.normalization_matrix  # type: ignore # [N, N]
+
+        _, f_sum = self._process_feature_groups(x)
+        rho = self._compute_rho(dist, norm, data)
 
         # Perform Σ_i Σ_j ρ(d_ij) Σ_k f_k(x_jk)
-        f_sum = fx.sum(dim=1)  # [N, C]
         out = torch.einsum('ijc,jc->ic', rho, f_sum)  # [N, C]
 
         if self.graph_level:
-            # # Use batch information for proper graph-level aggregation
             batch = data.batch  # type: ignore
             if batch is not None:
                 graph_out = scatter(out, batch, dim=0, reduce='add')
             else:
-                # Single graph case
                 graph_out = out.sum(dim=0, keepdim=True)  # [1, C]
             return graph_out
 
-        # --- node-level mode -------------------------------------------------
         if node_ids is not None:
             return out[node_ids]
         return out
+
+    def node_importance(self, data: Data | Batch) -> torch.Tensor:
+        """Returns the contribution of every node to the
+        graph‐level prediction, using Eq. (3) in the paper.
+        """
+        x: torch.Tensor = data.x  # type: ignore # [N, F]
+        dist: torch.Tensor = data.node_distances  # type: ignore # [N, N]
+        norm: torch.Tensor = data.normalization_matrix  # type: ignore # [N, N]
+
+        _, f_sum = self._process_feature_groups(x)
+        rho = self._compute_rho(dist, norm, data)
+
+        # Aggregate over *receiver* nodes i to obtain \sum_i rho(d_{ij}).
+        rho_sum_over_i = rho.sum(dim=0)  # [N, C]
+
+        # Node contribution s_j = (sum_k f_k(x_jk)) * (sum_i rho(d_{ij})). 
+ node_contrib = f_sum * rho_sum_over_i # [N, C] + + return node_contrib From 9c33037dcfddf990281e897dd6061c74defb5cb1 Mon Sep 17 00:00:00 2001 From: yoavk Date: Tue, 5 Aug 2025 09:05:18 +0000 Subject: [PATCH 06/13] add tests for feature groups and node importance --- test/nn/models/test_gnan.py | 59 ++++++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/test/nn/models/test_gnan.py b/test/nn/models/test_gnan.py index 6fbd17d824bf..df0213a14375 100644 --- a/test/nn/models/test_gnan.py +++ b/test/nn/models/test_gnan.py @@ -10,7 +10,8 @@ def _dummy_data(num_nodes: int = 5, num_feats: int = 4): # full distance matrix: rand_dist = torch.rand(num_nodes, num_nodes) rand_dist = (rand_dist + rand_dist.t()) / 2 # symmetric - norm = rand_dist.sum(dim=1) + # Use a simple normalisation matrix of ones to avoid division issues + norm = torch.ones_like(rand_dist) data = Data(x=x, edge_index=edge_index) data.node_distances = rand_dist data.normalization_matrix = norm @@ -23,3 +24,59 @@ def test_tensor_gnan_graph_level(): n_layers=2, hidden_channels=8) out = model(data) # [1, 2] assert out.shape == (1, 2) + + +# ----------------------------------------------------------------------------- +# New tests for feature grouping and node importance +# ----------------------------------------------------------------------------- + + +def test_tensor_gnan_feature_groups(): + """Ensure model works correctly with custom feature grouping.""" + data = _dummy_data(num_nodes=6, num_feats=4) + + # Group features [0,1] together; keep [2] and [3] separate. + feature_groups = [[0, 1], [2], [3]] + + model = TensorGNAN( + in_channels=data.num_features, + out_channels=3, + n_layers=1, # single Linear layer per MLP for easier inspection + feature_groups=feature_groups, + normalize_rho=False, # avoid dependence on normalisation matrix shape + ) + + out = model(data) # [1, 3] + + # Forward pass shape check + assert out.shape == (1, 3) + + # There should be exactly len(feature_groups) MLPs + assert len(model.fs) == len(feature_groups) + + # The first MLP should accept 2 input features (because group size == 2) + first_mlp = model.fs[0] + assert isinstance(first_mlp.net, torch.nn.Linear) + assert first_mlp.net.in_features == 2 + + +def test_tensor_gnan_node_importance(): + """Node contributions should sum to graph‐level prediction.""" + data = _dummy_data(num_nodes=5, num_feats=3) + + model = TensorGNAN( + in_channels=data.num_features, + out_channels=4, + n_layers=1, + normalize_rho=False, # simplifies the equality check + ) + + graph_out = model(data) # [1, 4] + node_contrib = model.node_importance(data) # [N, 4] + + # Shape checks + assert node_contrib.shape == (data.num_nodes, 4) + + # Sum of node contributions equals the graph‐level prediction (Eq. 
3) + contrib_sum = node_contrib.sum(dim=0, keepdim=True) # [1, 4] + assert torch.allclose(contrib_sum, graph_out, atol=1e-5) From 8f0ee54ab8991a7faf109b04ea43e587286073ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Aug 2025 09:09:05 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/gnan_graph_mutagenicity.py | 4 +-- test/nn/models/test_gnan.py | 2 +- torch_geometric/nn/models/gnan.py | 39 +++++++++++++++++------------ 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py index 8971e6cebfcd..e049db1072f5 100644 --- a/examples/gnan_graph_mutagenicity.py +++ b/examples/gnan_graph_mutagenicity.py @@ -158,8 +158,8 @@ def main(): dropout=0.3, normalize_rho=False, # Uncomment the following line to group features together: - # feature_groups=[list(range(in_channels))] - # feature_groups=[[0,1,2,3,4,5,6,7,8,9,10,11],[12,13]] + # feature_groups=[list(range(in_channels))] + # feature_groups=[[0,1,2,3,4,5,6,7,8,9,10,11],[12,13]] ).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, diff --git a/test/nn/models/test_gnan.py b/test/nn/models/test_gnan.py index df0213a14375..6d2cecbc51c0 100644 --- a/test/nn/models/test_gnan.py +++ b/test/nn/models/test_gnan.py @@ -71,7 +71,7 @@ def test_tensor_gnan_node_importance(): normalize_rho=False, # simplifies the equality check ) - graph_out = model(data) # [1, 4] + graph_out = model(data) # [1, 4] node_contrib = model.node_importance(data) # [N, 4] # Shape checks diff --git a/torch_geometric/nn/models/gnan.py b/torch_geometric/nn/models/gnan.py index d465c05376a1..5868e89486f6 100644 --- a/torch_geometric/nn/models/gnan.py +++ b/torch_geometric/nn/models/gnan.py @@ -160,7 +160,7 @@ class TensorGNAN(nn.Module): obtain *node‐level* predictions instead, in which case the forward returns a tensor of shape ``[num_nodes, out_channels]`` or ``[len(node_ids), out_channels]`` if ``node_ids`` is provided. - + Args: in_channels (int): Number of input node features. out_channels (int): Output dimension. 
@@ -209,14 +209,19 @@ def __init__( raise ValueError("Feature groups cannot be empty") for feat_idx in group: if feat_idx < 0 or feat_idx >= in_channels: - raise ValueError(f"Feature index {feat_idx} out of range [0, {in_channels})") + raise ValueError( + f"Feature index {feat_idx} out of range [0, {in_channels})" + ) if feat_idx in all_features: - raise ValueError(f"Feature index {feat_idx} appears in multiple groups") + raise ValueError( + f"Feature index {feat_idx} appears in multiple groups" + ) all_features.add(feat_idx) - + if len(all_features) != in_channels: missing = set(range(in_channels)) - all_features - raise ValueError(f"Missing feature indices in groups: {missing}") + raise ValueError( + f"Missing feature indices in groups: {missing}") # Create MLPs for each feature group self.fs = nn.ModuleList() @@ -224,12 +229,13 @@ def __init__( group_size = len(group) if group_size == 1: # Single feature - use original MLP - mlp = _PerFeatureMLP(out_channels, n_layers, hidden_channels, - bias=bias, dropout=dropout) + mlp = _PerFeatureMLP(out_channels, n_layers, hidden_channels, + bias=bias, dropout=dropout) else: # Multiple features - use new multi-feature MLP - mlp = _MultiFeatureMLP(group_size, out_channels, n_layers, - hidden_channels, bias=bias, dropout=dropout) + mlp = _MultiFeatureMLP(group_size, out_channels, n_layers, + hidden_channels, bias=bias, + dropout=dropout) self.fs.append(mlp) self.rho = _RhoMLP(out_channels, n_layers, hidden_channels, bias=True) @@ -241,13 +247,14 @@ def _process_feature_groups(self, x: torch.Tensor): if len(group) == 1: feat_tensor = x[:, group[0]] # [N] else: - feat_tensor = x[:, group] # [N, |group|] - fx_list.append(mlp(feat_tensor)) # [N, C] - fx = torch.stack(fx_list, dim=1) # [N, num_groups, C] - f_sum = fx.sum(dim=1) # [N, C] + feat_tensor = x[:, group] # [N, |group|] + fx_list.append(mlp(feat_tensor)) # [N, C] + fx = torch.stack(fx_list, dim=1) # [N, num_groups, C] + f_sum = fx.sum(dim=1) # [N, C] return fx, f_sum - def _compute_rho(self, dist: torch.Tensor, norm: torch.Tensor, data) -> torch.Tensor: + def _compute_rho(self, dist: torch.Tensor, norm: torch.Tensor, + data) -> torch.Tensor: """Compute rho tensor, including normalization and masking.""" x = data.x # type: ignore inv_dist = 1.0 / (1.0 + dist) # [N, N] @@ -300,9 +307,9 @@ def node_importance(self, data: Data | Batch) -> torch.Tensor: rho = self._compute_rho(dist, norm, data) # Aggregate over *receiver* nodes i to obtain \sum_i rho(d_{ij}). - rho_sum_over_i = rho.sum(dim=0) # [N, C] + rho_sum_over_i = rho.sum(dim=0) # [N, C] # Node contribution s_j = (sum_k f_k(x_jk)) * (sum_i rho(d_{ij})). 
-        node_contrib = f_sum * rho_sum_over_i # [N, C]
+        node_contrib = f_sum * rho_sum_over_i  # [N, C]
 
         return node_contrib

From 976f076cb9af8f1456ae08ac8bdd028a0c1a1f5c Mon Sep 17 00:00:00 2001
From: yoavk
Date: Wed, 6 Aug 2025 11:09:51 +0000
Subject: [PATCH 08/13] remove loading dataset from pickle for security reasons

---
 examples/gnan_graph_mutagenicity.py | 19 ++-----------------
 1 file changed, 2 insertions(+), 17 deletions(-)

diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py
index e049db1072f5..1f47101e4e79 100644
--- a/examples/gnan_graph_mutagenicity.py
+++ b/examples/gnan_graph_mutagenicity.py
@@ -106,23 +106,8 @@ def main():
     root = Path("data/Mutagenicity")
 
     # Save/load preprocessed dataset to avoid recomputation
-    import os
-    processed_path = root / "mutagenicity_gnan_preprocessed.pt"
-    if processed_path.exists():
-        print(f"Loading preprocessed dataset from {processed_path}")
-        try:
-            dataset = torch.load(processed_path)
-        except (pickle.UnpicklingError, RuntimeError):
-            print("Could not load preprocessed file, re-creating...")
-            os.remove(processed_path)
-            dataset = TUDataset(root=str(root), name="Mutagenicity",
-                                transform=PreprocessDistances())
-            torch.save(dataset, processed_path)
-    else:
-        print("Preprocessing dataset and saving to disk...")
-        dataset = TUDataset(root=str(root), name="Mutagenicity",
-                            transform=PreprocessDistances())
-        torch.save(dataset, processed_path)
+    dataset = TUDataset(root=str(root), name="Mutagenicity",
+                        transform=PreprocessDistances())
 
     num_classes = dataset.num_classes
     in_channels = dataset.num_features

From b9f3cf6e6756db48a090ac479a318a8dfde0f9db Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 6 Aug 2025 11:12:16 +0000
Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 examples/gnan_graph_mutagenicity.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/gnan_graph_mutagenicity.py b/examples/gnan_graph_mutagenicity.py
index 1f47101e4e79..72576cb46cc4 100644
--- a/examples/gnan_graph_mutagenicity.py
+++ b/examples/gnan_graph_mutagenicity.py
@@ -19,7 +19,6 @@
 
 from __future__ import annotations
 
-import pickle
 import random
 from pathlib import Path

From eddc6a77c8cfd9de57a122c916052b5b11e33cb9 Mon Sep 17 00:00:00 2001
From: yoavk
Date: Sat, 23 Aug 2025 09:56:23 +0000
Subject: [PATCH 10/13] add test and fixed pre-commit errors

---
 test/nn/models/test_gnan.py       | 84 +++++++++++++++++++++++++++++--
 torch_geometric/nn/models/gnan.py | 32 ++++++------
 2 files changed, 95 insertions(+), 21 deletions(-)

diff --git a/test/nn/models/test_gnan.py b/test/nn/models/test_gnan.py
index 6d2cecbc51c0..ed6624ea2a02 100644
--- a/test/nn/models/test_gnan.py
+++ b/test/nn/models/test_gnan.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 
 from torch_geometric.data import Data
@@ -26,11 +27,6 @@ def test_tensor_gnan_graph_level():
     assert out.shape == (1, 2)
 
 
-# -----------------------------------------------------------------------------
-# New tests for feature grouping and node importance
-# -----------------------------------------------------------------------------
-
-
 def test_tensor_gnan_feature_groups():
     """Ensure model works correctly with custom feature grouping."""
     data = _dummy_data(num_nodes=6, num_feats=4)
@@ -76,3 +76,81 @@ def test_tensor_gnan_node_importance():
     # Sum of node contributions equals the graph‐level prediction (Eq. 
3) contrib_sum = node_contrib.sum(dim=0, keepdim=True) # [1, 4] assert torch.allclose(contrib_sum, graph_out, atol=1e-5) + + +def test_tensor_gnan_multiple_layers(): + """Model runs with multiple layers in the MLPs.""" + data = _dummy_data(num_nodes=7, num_feats=3) + + model = TensorGNAN( + in_channels=data.num_features, + out_channels=5, + n_layers=3, + hidden_channels=8, + dropout=0.0, + ) + + out = model(data) + assert out.shape == (1, 5) + + +def test_tensor_gnan_batched_data(): + """Batched graphs should be processed independently and aggregated + per-graph. + """ + g1 = _dummy_data(num_nodes=3, num_feats=4) + g2 = _dummy_data(num_nodes=4, num_feats=4) + + # Build a single Data with block-diagonal distances and norms and a + # batch vector + x = torch.cat([g1.x, g2.x], dim=0) + dist = torch.block_diag(g1.node_distances, g2.node_distances) + norm = torch.block_diag(g1.normalization_matrix, g2.normalization_matrix) + batch = torch.tensor([0] * g1.num_nodes + [1] * g2.num_nodes) + + batched = Data(x=x) + batched.node_distances = dist + batched.normalization_matrix = norm + batched.batch = batch + + model = TensorGNAN( + in_channels=4, + out_channels=3, + n_layers=2, + hidden_channels=6, + dropout=0.0, + graph_level=True, + ) + model.eval() + + out_batched = model(batched) # [2, 3] + assert out_batched.shape == (2, 3) + + # Compare against processing each graph separately with the same model + out_g1 = model(g1) # [1, 3] + out_g2 = model(g2) # [1, 3] + stacked = torch.cat([out_g1, out_g2], dim=0) + + assert torch.allclose(out_batched, stacked, atol=1e-5) + + +def test_tensor_gnan_invalid_feature_groups_empty(): + data = _dummy_data(num_nodes=5, num_feats=3) + with pytest.raises(ValueError, match="cannot be empty"): + TensorGNAN( + in_channels=data.num_features, + out_channels=2, + n_layers=1, + feature_groups=[[0], [], [2]], + ) + + +def test_tensor_gnan_invalid_feature_groups_duplicate(): + data = _dummy_data(num_nodes=5, num_feats=3) + with pytest.raises(ValueError, match="appears in multiple groups"): + TensorGNAN( + in_channels=data.num_features, + out_channels=2, + n_layers=1, + feature_groups=[[0, 1], [1], [2]], + ) diff --git a/torch_geometric/nn/models/gnan.py b/torch_geometric/nn/models/gnan.py index 5868e89486f6..38198f6029c7 100644 --- a/torch_geometric/nn/models/gnan.py +++ b/torch_geometric/nn/models/gnan.py @@ -4,7 +4,6 @@ from torch import nn from torch_geometric.data import Batch, Data -from torch_geometric.typing import OptTensor from torch_geometric.utils import scatter __all__ = [ @@ -170,12 +169,12 @@ class TensorGNAN(nn.Module): dropout (float, optional): Dropout probability. (default: ``0.0``) normalize_rho (bool, optional): Whether to normalize rho weights. (default: ``True``) - graph_level (bool, optional): Whether to produce graph-level predictions. - (default: ``True``) - feature_groups (List[List[int]], optional): Groups of feature indices to - process together. Each group will be processed by a single MLP that - takes multiple features as input. If None, each feature is processed - by its own MLP (default behavior). (default: ``None``) + graph_level (bool, optional): Whether to produce graph-level + predictions. (default: ``True``) + feature_groups (List[List[int]], optional): Groups of feature indices + to process together. Each group will be processed by a single MLP + that takes multiple features as input. If None, each feature is + processed by its own MLP (default behavior). 
(default: ``None``) """ def __init__( self, @@ -210,12 +209,12 @@ def __init__( for feat_idx in group: if feat_idx < 0 or feat_idx >= in_channels: raise ValueError( - f"Feature index {feat_idx} out of range [0, {in_channels})" - ) + f"Feature index {feat_idx} out of range " + f"[0, {in_channels})") if feat_idx in all_features: raise ValueError( - f"Feature index {feat_idx} appears in multiple groups" - ) + f"Feature index {feat_idx} appears in " + "multiple groups") all_features.add(feat_idx) if len(all_features) != in_channels: @@ -271,8 +270,10 @@ def _compute_rho(self, dist: torch.Tensor, norm: torch.Tensor, rho = rho * mask return rho - def forward(self, data: Data | Batch, - node_ids: OptTensor = None) -> torch.Tensor: + def forward( + self, + data: Data | Batch, + ) -> torch.Tensor: x: torch.Tensor = data.x # type: ignore # [N, F] dist: torch.Tensor = data.node_distances # type: ignore # [N, N] norm: torch.Tensor = data.normalization_matrix # type: ignore # [N, N] @@ -291,8 +292,6 @@ def forward(self, data: Data | Batch, graph_out = out.sum(dim=0, keepdim=True) # [1, C] return graph_out - if node_ids is not None: - return out[node_ids] return out def node_importance(self, data: Data | Batch) -> torch.Tensor: @@ -301,7 +300,8 @@ def node_importance(self, data: Data | Batch) -> torch.Tensor: """ x: torch.Tensor = data.x # type: ignore # [N, F] dist: torch.Tensor = data.node_distances # type: ignore # [N, N] - norm: torch.Tensor = data.normalization_matrix # type: ignore # [N, N] + norm: torch.Tensor = data.normalization_matrix # type: ignore + # [N, N] _, f_sum = self._process_feature_groups(x) rho = self._compute_rho(dist, norm, data) From 79c799d5132f3c64abe59de4e2185a26130ebb31 Mon Sep 17 00:00:00 2001 From: yoavk Date: Sat, 23 Aug 2025 10:16:52 +0000 Subject: [PATCH 11/13] add tests for gnan dataloader --- test/loader/test_gnan_loader.py | 95 +++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 test/loader/test_gnan_loader.py diff --git a/test/loader/test_gnan_loader.py b/test/loader/test_gnan_loader.py new file mode 100644 index 000000000000..bbcfabb4c600 --- /dev/null +++ b/test/loader/test_gnan_loader.py @@ -0,0 +1,95 @@ +import torch + +from torch_geometric.data import Data, Batch +from torch_geometric.loader.gnan_dataloader import ( + GNANCollater, + GNANDataLoader, +) +from torch_geometric.nn.models import TensorGNAN + + +def _dummy_data(num_nodes: int = 5, num_feats: int = 4): + x = torch.randn(num_nodes, num_feats) + edge_index = torch.combinations(torch.arange(num_nodes), r=2).t() + # full distance matrix: + rand_dist = torch.rand(num_nodes, num_nodes) + rand_dist = (rand_dist + rand_dist.t()) / 2 # symmetric + # Use a simple normalisation matrix of ones to avoid division issues + norm = torch.ones_like(rand_dist) + data = Data(x=x, edge_index=edge_index) + data.node_distances = rand_dist + data.normalization_matrix = norm + return data + + +def test_gnan_collater_block_diag_and_restore(): + g1 = _dummy_data(num_nodes=3, num_feats=4) + g2 = _dummy_data(num_nodes=4, num_feats=4) + g3 = _dummy_data(num_nodes=2, num_feats=4) + + # Keep copies to verify restoration on originals + d1 = g1.node_distances.clone() + n1 = g1.normalization_matrix.clone() + d2 = g2.node_distances.clone() + n2 = g2.normalization_matrix.clone() + d3 = g3.node_distances.clone() + n3 = g3.normalization_matrix.clone() + + collate = GNANCollater() + batch = collate([g1, g2, g3]) + + assert isinstance(batch, Batch) + N = g1.num_nodes + g2.num_nodes + g3.num_nodes + assert 
batch.node_distances.shape == (N, N) + assert batch.normalization_matrix.shape == (N, N) + + expected_dist = torch.block_diag(d1, d2, d3) + expected_norm = torch.block_diag(n1, n2, n3) + assert torch.allclose(batch.node_distances, expected_dist) + assert torch.allclose(batch.normalization_matrix, expected_norm) + + # Original Data objects should have attributes restored and unchanged + assert torch.allclose(g1.node_distances, d1) + assert torch.allclose(g1.normalization_matrix, n1) + assert torch.allclose(g2.node_distances, d2) + assert torch.allclose(g2.normalization_matrix, n2) + assert torch.allclose(g3.node_distances, d3) + assert torch.allclose(g3.normalization_matrix, n3) + + +def test_gnan_dataloader_batch_content(): + g1 = _dummy_data(num_nodes=3, num_feats=3) + g2 = _dummy_data(num_nodes=5, num_feats=3) + loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False) + batch = next(iter(loader)) + + assert isinstance(batch, Batch) + assert batch.x.size(0) == g1.num_nodes + g2.num_nodes + assert hasattr(batch, 'node_distances') and hasattr(batch, 'normalization_matrix') + + expected_dist = torch.block_diag(g1.node_distances, g2.node_distances) + expected_norm = torch.block_diag(g1.normalization_matrix, g2.normalization_matrix) + assert torch.allclose(batch.node_distances, expected_dist) + assert torch.allclose(batch.normalization_matrix, expected_norm) + + +def test_gnan_dataloader_with_tensor_gnan(): + g1 = _dummy_data(num_nodes=3, num_feats=4) + g2 = _dummy_data(num_nodes=4, num_feats=4) + loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False) + batch = next(iter(loader)) + + model = TensorGNAN( + in_channels=4, + out_channels=3, + n_layers=2, + hidden_channels=8, + graph_level=True, + ) + model.eval() + + # Batched forward vs. separate forwards + out_batched = model(batch) # [2, 3] + out_sep = torch.cat([model(g1), model(g2)], dim=0) + assert out_batched.shape == (2, 3) + assert torch.allclose(out_batched, out_sep, atol=1e-5) From ff95f62525479208cf9273b00180765e92b277f1 Mon Sep 17 00:00:00 2001 From: yoavk Date: Sat, 23 Aug 2025 10:19:14 +0000 Subject: [PATCH 12/13] add to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 91fba77f8012..dd34398e8157 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) - Add ComplexWebQuestions (CWQ) dataset ([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950)) +- Add GNAN model and dedicated dataloader plus an example (https://github.com/pyg-team/pytorch_geometric/pull/10371) ### Changed From 4dcc161f8eed97d304f8426d0c70c16b31212292 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 23 Aug 2025 10:20:41 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/loader/test_gnan_loader.py | 155 ++++++++++++++++---------------- 1 file changed, 77 insertions(+), 78 deletions(-) diff --git a/test/loader/test_gnan_loader.py b/test/loader/test_gnan_loader.py index bbcfabb4c600..4d99bd4cbfa6 100644 --- a/test/loader/test_gnan_loader.py +++ b/test/loader/test_gnan_loader.py @@ -1,95 +1,94 @@ import torch -from torch_geometric.data import Data, Batch -from torch_geometric.loader.gnan_dataloader import ( - GNANCollater, - GNANDataLoader, -) +from torch_geometric.data import Batch, Data +from torch_geometric.loader.gnan_dataloader import GNANCollater, GNANDataLoader from torch_geometric.nn.models import TensorGNAN def _dummy_data(num_nodes: int = 5, num_feats: int = 4): - x = torch.randn(num_nodes, num_feats) - edge_index = torch.combinations(torch.arange(num_nodes), r=2).t() - # full distance matrix: - rand_dist = torch.rand(num_nodes, num_nodes) - rand_dist = (rand_dist + rand_dist.t()) / 2 # symmetric - # Use a simple normalisation matrix of ones to avoid division issues - norm = torch.ones_like(rand_dist) - data = Data(x=x, edge_index=edge_index) - data.node_distances = rand_dist - data.normalization_matrix = norm - return data + x = torch.randn(num_nodes, num_feats) + edge_index = torch.combinations(torch.arange(num_nodes), r=2).t() + # full distance matrix: + rand_dist = torch.rand(num_nodes, num_nodes) + rand_dist = (rand_dist + rand_dist.t()) / 2 # symmetric + # Use a simple normalisation matrix of ones to avoid division issues + norm = torch.ones_like(rand_dist) + data = Data(x=x, edge_index=edge_index) + data.node_distances = rand_dist + data.normalization_matrix = norm + return data def test_gnan_collater_block_diag_and_restore(): - g1 = _dummy_data(num_nodes=3, num_feats=4) - g2 = _dummy_data(num_nodes=4, num_feats=4) - g3 = _dummy_data(num_nodes=2, num_feats=4) - - # Keep copies to verify restoration on originals - d1 = g1.node_distances.clone() - n1 = g1.normalization_matrix.clone() - d2 = g2.node_distances.clone() - n2 = g2.normalization_matrix.clone() - d3 = g3.node_distances.clone() - n3 = g3.normalization_matrix.clone() - - collate = GNANCollater() - batch = collate([g1, g2, g3]) - - assert isinstance(batch, Batch) - N = g1.num_nodes + g2.num_nodes + g3.num_nodes - assert batch.node_distances.shape == (N, N) - assert batch.normalization_matrix.shape == (N, N) - - expected_dist = torch.block_diag(d1, d2, d3) - expected_norm = torch.block_diag(n1, n2, n3) - assert torch.allclose(batch.node_distances, expected_dist) - assert torch.allclose(batch.normalization_matrix, expected_norm) - - # Original Data objects should have attributes restored and unchanged - assert 
torch.allclose(g1.node_distances, d1) - assert torch.allclose(g1.normalization_matrix, n1) - assert torch.allclose(g2.node_distances, d2) - assert torch.allclose(g2.normalization_matrix, n2) - assert torch.allclose(g3.node_distances, d3) - assert torch.allclose(g3.normalization_matrix, n3) + g1 = _dummy_data(num_nodes=3, num_feats=4) + g2 = _dummy_data(num_nodes=4, num_feats=4) + g3 = _dummy_data(num_nodes=2, num_feats=4) + + # Keep copies to verify restoration on originals + d1 = g1.node_distances.clone() + n1 = g1.normalization_matrix.clone() + d2 = g2.node_distances.clone() + n2 = g2.normalization_matrix.clone() + d3 = g3.node_distances.clone() + n3 = g3.normalization_matrix.clone() + + collate = GNANCollater() + batch = collate([g1, g2, g3]) + + assert isinstance(batch, Batch) + N = g1.num_nodes + g2.num_nodes + g3.num_nodes + assert batch.node_distances.shape == (N, N) + assert batch.normalization_matrix.shape == (N, N) + + expected_dist = torch.block_diag(d1, d2, d3) + expected_norm = torch.block_diag(n1, n2, n3) + assert torch.allclose(batch.node_distances, expected_dist) + assert torch.allclose(batch.normalization_matrix, expected_norm) + + # Original Data objects should have attributes restored and unchanged + assert torch.allclose(g1.node_distances, d1) + assert torch.allclose(g1.normalization_matrix, n1) + assert torch.allclose(g2.node_distances, d2) + assert torch.allclose(g2.normalization_matrix, n2) + assert torch.allclose(g3.node_distances, d3) + assert torch.allclose(g3.normalization_matrix, n3) def test_gnan_dataloader_batch_content(): - g1 = _dummy_data(num_nodes=3, num_feats=3) - g2 = _dummy_data(num_nodes=5, num_feats=3) - loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False) - batch = next(iter(loader)) + g1 = _dummy_data(num_nodes=3, num_feats=3) + g2 = _dummy_data(num_nodes=5, num_feats=3) + loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False) + batch = next(iter(loader)) - assert isinstance(batch, Batch) - assert batch.x.size(0) == g1.num_nodes + g2.num_nodes - assert hasattr(batch, 'node_distances') and hasattr(batch, 'normalization_matrix') + assert isinstance(batch, Batch) + assert batch.x.size(0) == g1.num_nodes + g2.num_nodes + assert hasattr(batch, 'node_distances') and hasattr( + batch, 'normalization_matrix') - expected_dist = torch.block_diag(g1.node_distances, g2.node_distances) - expected_norm = torch.block_diag(g1.normalization_matrix, g2.normalization_matrix) - assert torch.allclose(batch.node_distances, expected_dist) - assert torch.allclose(batch.normalization_matrix, expected_norm) + expected_dist = torch.block_diag(g1.node_distances, g2.node_distances) + expected_norm = torch.block_diag(g1.normalization_matrix, + g2.normalization_matrix) + assert torch.allclose(batch.node_distances, expected_dist) + assert torch.allclose(batch.normalization_matrix, expected_norm) def test_gnan_dataloader_with_tensor_gnan(): - g1 = _dummy_data(num_nodes=3, num_feats=4) - g2 = _dummy_data(num_nodes=4, num_feats=4) - loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False) - batch = next(iter(loader)) - - model = TensorGNAN( - in_channels=4, - out_channels=3, - n_layers=2, - hidden_channels=8, - graph_level=True, - ) - model.eval() - - # Batched forward vs. 
separate forwards - out_batched = model(batch) # [2, 3] - out_sep = torch.cat([model(g1), model(g2)], dim=0) - assert out_batched.shape == (2, 3) - assert torch.allclose(out_batched, out_sep, atol=1e-5) + g1 = _dummy_data(num_nodes=3, num_feats=4) + g2 = _dummy_data(num_nodes=4, num_feats=4) + loader = GNANDataLoader([g1, g2], batch_size=2, shuffle=False) + batch = next(iter(loader)) + + model = TensorGNAN( + in_channels=4, + out_channels=3, + n_layers=2, + hidden_channels=8, + graph_level=True, + ) + model.eval() + + # Batched forward vs. separate forwards + out_batched = model(batch) # [2, 3] + out_sep = torch.cat([model(g1), model(g2)], dim=0) + assert out_batched.shape == (2, 3) + assert torch.allclose(out_batched, out_sep, atol=1e-5)
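
For reviewers trying the branch locally, the following is a minimal, self-contained usage sketch of the API this series adds. It mirrors the test fixtures above rather than the Mutagenicity example: the toy distance and normalisation matrices are placeholders (real code would precompute shortest-path distances as in PreprocessDistances), and the shapes in the comments follow the docstrings in gnan.py.

import torch

from torch_geometric.data import Data
from torch_geometric.loader.gnan_dataloader import GNANDataLoader
from torch_geometric.nn.models import TensorGNAN


def toy_graph(num_nodes: int = 4, num_feats: int = 3) -> Data:
    # Attach the two dense attributes GNAN expects. Zero distances and an
    # all-ones normalisation matrix are placeholders, not shortest paths.
    data = Data(x=torch.randn(num_nodes, num_feats))
    data.node_distances = torch.zeros(num_nodes, num_nodes)
    data.normalization_matrix = torch.ones(num_nodes, num_nodes)
    return data


# The loader builds block-diagonal distance/normalisation matrices so the
# per-graph mask in TensorGNAN keeps the two graphs independent.
loader = GNANDataLoader([toy_graph(), toy_graph(5)], batch_size=2)
model = TensorGNAN(in_channels=3, out_channels=2, n_layers=2,
                   hidden_channels=8)
model.eval()

batch = next(iter(loader))
out = model(batch)                  # graph-level logits, shape [2, 2]
imp = model.node_importance(batch)  # per-node contributions, shape [9, 2]

# Per Eq. (3), the contributions of a graph's nodes sum to its prediction,
# which is exactly the property test_tensor_gnan_node_importance checks.
assert torch.allclose(imp[:4].sum(0), out[0], atol=1e-5)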