-
Notifications
You must be signed in to change notification settings - Fork 3.9k
HypergraphConv TorchScript Compilation Compatibility #10400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #10400 +/- ##
==========================================
- Coverage 86.11% 85.94% -0.18%
==========================================
Files 496 502 +6
Lines 33655 35130 +1475
==========================================
+ Hits 28981 30191 +1210
- Misses 4674 4939 +265 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Co-authored-by: Jinu Sunil <[email protected]>
@chrisrosner thanks for the work. |
output = script(torch.randn(4, in_channels), | ||
torch.tensor([[0, 1, 2], [0, 0, 1]]), | ||
hyperedge_attr=torch.randn(2, in_channels)) | ||
assert output.size() == (4, out_channels * conv.heads) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output = script(torch.randn(4, in_channels), | |
torch.tensor([[0, 1, 2], [0, 0, 1]]), | |
hyperedge_attr=torch.randn(2, in_channels)) | |
assert output.size() == (4, out_channels * conv.heads) | |
x = torch.randn(4, in_channels) | |
out_1 = conv(x, torch.tensor([[0, 1, 2], [0, 0, 1]]), | |
hyperedge_attr=torch.randn(2, in_channels)) | |
out_2 = script(x, | |
torch.tensor([[0, 1, 2], [0, 0, 1]]), | |
hyperedge_attr=torch.randn(2, in_channels)) | |
assert torch.allclose(out_1, out_2) |
Also add the onlyFullTest
decorator like
@onlyFullTest |
Good to merge after this.
Addresses #10399
This was the minimal set of changes I could figure out to get HypergraphConv to compile with torchscript.
Included are unit tests demonstrating the problems.
The following issues were resolved:
type of alpha changing from None -> tensor
torchscript not understanding the import functional as F
non-attention branch does not define some member variables
default value for dropout is wrong type
Test Plan
pytest -k hyper