Skip to content

Conversation

jank324
Copy link
Member

@jank324 jank324 commented Sep 5, 2025

Description

Removes or otherwise fixes all branching computations that would break the compute graph and inhibit differentiability in some special cases.

Sub- / replacement-PR for #538.

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Those branches inhibit differentiability in some special cases.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@jank324
Copy link
Member Author

jank324 commented Sep 19, 2025

Did #550 really do everything on the cavity? There seems to still be branching to base_rmatrix. @Hespe?

outgoing_beam.particles[:, :6].cpu().numpy(),
outgoing_p_array.rparticles.transpose(),
atol=1e-6,
outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check all the changed atols, if they are too high. I think some might be.

@Hespe
Copy link
Member

Hespe commented Sep 22, 2025

Did #550 really do everything on the cavity? There seems to still be branching to base_rmatrix. @Hespe?

#550 did only remove the branching with respect to off-crest phase. We did not touch the singularity for cavities with 0 accelerating voltage.

@jank324
Copy link
Member Author

jank324 commented Sep 22, 2025

I'm currently trying to figure out how to get rid of the division by zero (Ei / dE) causing the tests to fail in cavity.py line 301.

f = Ei / dE * torch.log(1 + (dE / Ei))

Does this need a new autograd function? What would that look like? How generic should we make it?

Comment on lines +416 to +417
rotation = rotation_matrix(self.tilt)
R = rotation.mT @ R @ rotation
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be worth it to build an autograd function for this sort of function that only computes the tilt optionally in a forward pass, but always computes the gradient as if it had been in there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be, but that also depends on how expensive the multiplication is. The overhead of the custom autograd functions seems to be non-negligable from the Cavity example we implemented.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants