Remove branching #553
base: master
Conversation
This reverts commit 402a98b.
…em after removing branches that improved numerical accuracy
outgoing_beam.particles[:, :6].cpu().numpy(),
outgoing_p_array.rparticles.transpose(),
atol=1e-6,
outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5
We need to check all the changed atols to see if they are too high. I think some might be.
I'm currently trying to figure out how to get rid of the division by zero in f = Ei / dE * torch.log(1 + (dE / Ei)). Does this need a new autograd function? What would that look like? How generic should we make it?
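One branch-free option that avoids a custom autograd function is to mask the denominator before dividing. A minimal sketch, assuming the element-wise form above; the helper name and the 1e-12 threshold are illustrative, not from this PR:

```python
import torch


def energy_factor(Ei: torch.Tensor, dE: torch.Tensor) -> torch.Tensor:
    """Branch-free f = Ei / dE * log(1 + dE / Ei), well-defined at dE = 0.

    With x = dE / Ei, the expression becomes log(1 + x) / x, which tends to 1
    as x -> 0. Masking x before the division keeps both the value and the
    gradient finite; a plain torch.where on the final result would still leak
    NaNs from the gradient of the discarded branch.
    """
    x = dE / Ei
    near_zero = x.abs() < 1e-12  # illustrative threshold
    safe_x = torch.where(near_zero, torch.ones_like(x), x)
    ratio = torch.log1p(safe_x) / safe_x
    # First-order series log(1 + x) / x ≈ 1 - x / 2 keeps the gradient exact
    # at x = 0 (df/dx = -1/2 there).
    return torch.where(near_zero, 1.0 - x / 2, ratio)
```

This keeps value and gradient finite at dE = 0 without a custom Function, at the cost of an extra where and an explicit threshold.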
rotation = rotation_matrix(self.tilt)
R = rotation.mT @ R @ rotation
Could it be worth it to build an autograd function for this sort of operation that only applies the tilt rotation in the forward pass when it is needed, but always computes the gradient as if it had been applied?
Could be, but that also depends on how expensive the multiplication is. The overhead of the custom autograd functions seems to be non-negligible from the Cavity example we implemented.
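For concreteness, such a lazily-tilted autograd function could look roughly like the sketch below. `MaybeTilted` and the pass-through of the rotation builder are illustrative, not from this PR, and it assumes `rotation_matrix` is differentiable with respect to the tilt:

```python
import torch


class MaybeTilted(torch.autograd.Function):
    """Sketch: skip the tilt rotation in the forward pass when the tilt is
    all-zero, but always compute gradients as if rotation.mT @ R @ rotation
    had been evaluated."""

    @staticmethod
    def forward(ctx, R, tilt, build_rotation):
        # build_rotation stands in for rotation_matrix and is assumed to be
        # differentiable with respect to tilt.
        ctx.save_for_backward(R, tilt)
        ctx.build_rotation = build_rotation
        if torch.count_nonzero(tilt) == 0:
            return R.clone()  # rotation is the identity, skip both matmuls
        rotation = build_rotation(tilt)
        return rotation.mT @ R @ rotation

    @staticmethod
    def backward(ctx, grad_output):
        R, tilt = ctx.saved_tensors
        # Recompute the full expression with autograd enabled so the gradient
        # is exact even when the forward pass took the shortcut.
        with torch.enable_grad():
            R_ = R.detach().requires_grad_(True)
            tilt_ = tilt.detach().requires_grad_(True)
            rotation = ctx.build_rotation(tilt_)
            full = rotation.mT @ R_ @ rotation
        grad_R, grad_tilt = torch.autograd.grad(full, (R_, tilt_), grad_output)
        return grad_R, grad_tilt, None


# Hypothetical usage inside the element:
# R = MaybeTilted.apply(R, self.tilt, rotation_matrix)
```

Only the forward pass gets cheaper: backward always rebuilds the full expression, and double backward is not supported by this sketch, so whether it pays off depends on how often the tilt is actually zero versus the constant Function overhead mentioned above.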
Description
Removes or otherwise fixes all branching computations that would break the compute graph and inhibit differentiability in some special cases.
Sub- / replacement-PR for #538.
Motivation and Context
Those branches inhibit differentiability in some special cases.
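For context, a toy illustration (not code from this PR) of the kind of special case that breaks gradients: torch.where selects the finite branch for the value, but autograd still evaluates the gradient of the discarded branch, and the resulting NaN leaks through the zero mask.

```python
import torch

Ei = torch.tensor([10.0])
dE = torch.zeros(1, requires_grad=True)

# The discarded branch divides by zero at dE = 0; its gradient is NaN and
# 0 * nan = nan, so the NaN survives the masking done by torch.where.
f = torch.where(dE == 0, torch.ones_like(Ei), Ei / dE * torch.log1p(dE / Ei))
f.backward()

print(f)        # value is 1.0, as intended
print(dE.grad)  # nan, although the true derivative at dE = 0 is -1 / (2 * Ei)
```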
Types of changes
Checklist
- The code passes flake8 (required).
- The pytest tests pass (required).
- I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).

Note: We are using a maximum length of 88 characters per line.