Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
- Add ComplexWebQuestions (CWQ) dataset ([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950))
- Fixed `HypergraphConv` TorchScript compilation errors ([#10400](https://github.com/pyg-team/pytorch_geometric/pull/10400))

### Changed

Expand Down
10 changes: 10 additions & 0 deletions test/nn/conv/test_hypergraph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,13 @@ def test_hypergraph_conv_with_more_edges_than_nodes():
assert out.size() == (num_nodes, out_channels)
out = conv(x, hyperedge_index, hyperedge_weight)
assert out.size() == (num_nodes, out_channels)


def test_hypergraph_jit():
in_channels, out_channels = (2, 3)
conv = HypergraphConv(in_channels, out_channels, use_attention=True)
script = torch.jit.script(conv)
output = script(torch.randn(4, in_channels),
torch.tensor([[0, 1, 2], [0, 0, 1]]),
hyperedge_attr=torch.randn(2, in_channels))
assert output.size() == (4, out_channels * conv.heads)
Comment on lines +55 to +58
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
output = script(torch.randn(4, in_channels),
torch.tensor([[0, 1, 2], [0, 0, 1]]),
hyperedge_attr=torch.randn(2, in_channels))
assert output.size() == (4, out_channels * conv.heads)
x = torch.randn(4, in_channels)
out_1 = conv(x, torch.tensor([[0, 1, 2], [0, 0, 1]]),
hyperedge_attr=torch.randn(2, in_channels))
out_2 = script(x,
torch.tensor([[0, 1, 2], [0, 0, 1]]),
hyperedge_attr=torch.randn(2, in_channels))
assert torch.allclose(out_1, out_2)

Also add the onlyFullTest decorator like

.
Good to merge after this.

15 changes: 9 additions & 6 deletions torch_geometric/nn/conv/hypergraph_conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

Expand Down Expand Up @@ -83,7 +82,7 @@ def __init__(
heads: int = 1,
concat: bool = True,
negative_slope: float = 0.2,
dropout: float = 0,
dropout: float = 0.0,
bias: bool = True,
**kwargs,
):
Expand All @@ -108,8 +107,11 @@ def __init__(
else:
self.heads = 1
self.concat = True
self.negative_slope = negative_slope
self.dropout = dropout
self.lin = Linear(in_channels, out_channels, bias=False,
weight_initializer='glorot')
self.att = torch.empty(0)

if bias and concat:
self.bias = Parameter(torch.empty(heads * out_channels))
Expand Down Expand Up @@ -162,7 +164,7 @@ def forward(self, x: Tensor, hyperedge_index: Tensor,

x = self.lin(x)

alpha = None
alpha = torch.empty(0)
if self.use_attention:
assert hyperedge_attr is not None
x = x.view(-1, self.heads, self.out_channels)
Expand All @@ -172,12 +174,13 @@ def forward(self, x: Tensor, hyperedge_index: Tensor,
x_i = x[hyperedge_index[0]]
x_j = hyperedge_attr[hyperedge_index[1]]
alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = torch.nn.functional.leaky_relu(alpha, self.negative_slope)
if self.attention_mode == 'node':
alpha = softmax(alpha, hyperedge_index[1], num_nodes=num_edges)
else:
alpha = softmax(alpha, hyperedge_index[0], num_nodes=num_nodes)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
alpha = torch.nn.functional.dropout(alpha, p=self.dropout,
training=self.training)

D = scatter(hyperedge_weight[hyperedge_index[1]], hyperedge_index[0],
dim=0, dim_size=num_nodes, reduce='sum')
Expand Down Expand Up @@ -209,7 +212,7 @@ def message(self, x_j: Tensor, norm_i: Tensor, alpha: Tensor) -> Tensor:

out = norm_i.view(-1, 1, 1) * x_j.view(-1, H, F)

if alpha is not None:
if alpha.numel() > 0:
out = alpha.view(-1, self.heads, 1) * out

return out
Loading