From b70bdf04f3189b5bb352a5811e11dbdcf21c9a73 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 5 Sep 2025 15:48:09 +0200 Subject: [PATCH 01/13] Reapply "Remove branching that was removed up to now in #538" This reverts commit 402a98b71aafea5e1d70e0a3da6c9b92956dcf41. --- cheetah/accelerator/dipole.py | 91 ++++++++++------------------- cheetah/accelerator/quadrupole.py | 19 +++--- cheetah/accelerator/sextupole.py | 9 ++- cheetah/accelerator/solenoid.py | 8 +-- cheetah/particles/parameter_beam.py | 12 ++-- cheetah/track_methods.py | 53 +++++++---------- cheetah/utils/physics.py | 2 +- 7 files changed, 71 insertions(+), 123 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 90b840435..172d8948d 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -154,7 +154,7 @@ def __init__( @property def hx(self) -> torch.Tensor: - return torch.where(self.length == 0.0, 0.0, self.angle / self.length) + return self.angle / self.length # Zero length not caught because not physical @property def dipole_e1(self) -> torch.Tensor: @@ -427,75 +427,45 @@ def _bmadx_fringe_linear( def first_order_transfer_map( self, energy: torch.Tensor, species: Species ) -> torch.Tensor: - device = self.length.device - dtype = self.length.dtype - R_enter = self._transfer_map_enter() R_exit = self._transfer_map_exit() - if torch.any(self.length != 0.0): # Bending magnet with finite length - R = base_rmatrix( - length=self.length, - k1=self.k1, - hx=self.hx, - species=species, - energy=energy, - ) # Tilt is applied after adding edges - else: # Reduce to Thin-Corrector - R = torch.eye(7, device=device, dtype=dtype).repeat( - (*self.length.shape, 1, 1) - ) - R[..., 0, 1] = self.length - R[..., 2, 6] = self.angle - R[..., 2, 3] = self.length + R = base_rmatrix( + length=self.length, + k1=self.k1, + hx=self.hx, + species=species, + energy=energy, + ) # Tilt is applied after adding edges # Apply fringe fields R = R_exit @ R @ R_enter # Apply rotation for tilted magnets - if torch.any(self.tilt != 0): - rotation = rotation_matrix(self.tilt) - R = rotation.transpose(-1, -2) @ R @ rotation + rotation = rotation_matrix(self.tilt) + R = rotation.transpose(-1, -2) @ R @ rotation return R def second_order_transfer_map( self, energy: torch.Tensor, species: Species ) -> torch.Tensor: - device = self.length.device - dtype = self.length.dtype - R_enter = self._transfer_map_enter() R_exit = self._transfer_map_exit() - if torch.any(self.length != 0.0): # Bending magnet with finite length - T = base_ttensor( - length=self.length, - k1=self.k1, - k2=torch.tensor(0.0, dtype=dtype, device=device), - hx=self.hx, - species=species, - energy=energy, - ) - - # Fill the first-order transfer map into the second-order transfer map - T[..., :, 6, :] = base_rmatrix( - length=self.length, - k1=self.k1, - hx=self.hx, - species=species, - energy=energy, - ) - else: # Reduce to Thin-Corrector - R = torch.eye(7, device=device, dtype=dtype).repeat( - (*self.length.shape, 1, 1) - ) - R[..., 0, 1] = self.length - R[..., 2, 6] = self.angle - R[..., 2, 3] = self.length + T = base_ttensor( + length=self.length, + k1=self.k1, + k2=self.length.new_zeros(()), + hx=self.hx, + species=species, + energy=energy, + ) - T = torch.zeros((*self.length.shape, 7, 7), dtype=dtype, device=device) - T[..., :, 6, :] = R + # Fill the first-order transfer map into the second-order transfer map + T[..., :, 6, :] = base_rmatrix( + length=self.length, k1=self.k1, hx=self.hx, species=species, energy=energy + ) # Apply fringe fields T = torch.einsum( @@ -503,15 +473,14 @@ def second_order_transfer_map( ) # Apply rotation for tilted magnets - if torch.any(self.tilt != 0): - rotation = rotation_matrix(self.tilt) - T = torch.einsum( - "...ij,...jkl,...kn,...lm->...inm", - rotation.transpose(-1, -2), - T, - rotation, - rotation, - ) + rotation = rotation_matrix(self.tilt) + T = torch.einsum( + "...ji,...jkl,...kn,...lm->...inm", + rotation, # Switch index labels in einsum instead of transpose (faster) + T, + rotation, + rotation, + ) return T diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index 9b98d7e2a..09f1c4894 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -96,12 +96,10 @@ def first_order_transfer_map( energy=energy, ) - if torch.all(self.misalignment == 0): - return R - else: - R_entry, R_exit = misalignment_matrix(self.misalignment) - R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) - return R + R_entry, R_exit = misalignment_matrix(self.misalignment) + R = R_exit @ R @ R_entry + + return R def second_order_transfer_map( self, energy: torch.Tensor, species: Species @@ -127,11 +125,10 @@ def second_order_transfer_map( ) # Apply misalignments to the entire second-order transfer map - if not torch.all(self.misalignment == 0): - R_entry, R_exit = misalignment_matrix(self.misalignment) - T = torch.einsum( - "...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry - ) + R_entry, R_exit = misalignment_matrix(self.misalignment) + T = torch.einsum( + "...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry + ) return T diff --git a/cheetah/accelerator/sextupole.py b/cheetah/accelerator/sextupole.py index 15fefa7b1..8e8b2a646 100644 --- a/cheetah/accelerator/sextupole.py +++ b/cheetah/accelerator/sextupole.py @@ -87,11 +87,10 @@ def second_order_transfer_map(self, energy, species): ) # Apply misalignments to the entire second-order transfer map - if not torch.all(self.misalignment == 0): - R_entry, R_exit = misalignment_matrix(self.misalignment) - T = torch.einsum( - "...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry - ) + R_entry, R_exit = misalignment_matrix(self.misalignment) + T = torch.einsum( + "...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry + ) return T diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 4d2bf4fc6..1d1988d65 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -103,12 +103,8 @@ def first_order_transfer_map( R = R.real - if torch.all(self.misalignment == 0): - return R - else: - R_entry, R_exit = misalignment_matrix(self.misalignment) - R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) - return R + R_entry, R_exit = misalignment_matrix(self.misalignment) + R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) @property def is_active(self) -> bool: diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index f28f61d41..0d1b3ed14 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -611,7 +611,7 @@ def mu_x(self) -> torch.Tensor: @property def sigma_x(self) -> torch.Tensor: - return torch.sqrt(torch.clamp_min(self.cov[..., 0, 0], 1e-20)) + return torch.sqrt(self.cov[..., 0, 0]) @property def mu_px(self) -> torch.Tensor: @@ -619,7 +619,7 @@ def mu_px(self) -> torch.Tensor: @property def sigma_px(self) -> torch.Tensor: - return torch.sqrt(torch.clamp_min(self.cov[..., 1, 1], 1e-20)) + return torch.sqrt(self.cov[..., 1, 1]) @property def mu_y(self) -> torch.Tensor: @@ -627,7 +627,7 @@ def mu_y(self) -> torch.Tensor: @property def sigma_y(self) -> torch.Tensor: - return torch.sqrt(torch.clamp_min(self.cov[..., 2, 2], 1e-20)) + return torch.sqrt(self.cov[..., 2, 2]) @property def mu_py(self) -> torch.Tensor: @@ -635,7 +635,7 @@ def mu_py(self) -> torch.Tensor: @property def sigma_py(self) -> torch.Tensor: - return torch.sqrt(torch.clamp_min(self.cov[..., 3, 3], 1e-20)) + return torch.sqrt(self.cov[..., 3, 3]) @property def mu_tau(self) -> torch.Tensor: @@ -643,7 +643,7 @@ def mu_tau(self) -> torch.Tensor: @property def sigma_tau(self) -> torch.Tensor: - return torch.sqrt(torch.clamp_min(self.cov[..., 4, 4], 1e-20)) + return torch.sqrt(self.cov[..., 4, 4]) @property def mu_p(self) -> torch.Tensor: @@ -651,7 +651,7 @@ def mu_p(self) -> torch.Tensor: @property def sigma_p(self) -> torch.Tensor: - return torch.sqrt(torch.clamp_min(self.cov[..., 5, 5], 1e-20)) + return torch.sqrt(self.cov[..., 5, 5]) @property def cov_xpx(self) -> torch.Tensor: diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 1cce73e36..4592beba3 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -43,10 +43,16 @@ def base_rmatrix( cy = torch.cos(ky * length).real sx = (torch.sinc(kx * length / torch.pi) * length).real sy = (torch.sinc(ky * length / torch.pi) * length).real - dx = torch.where(kx2 != 0, hx / kx2 * (1.0 - cx), zero) - r56 = torch.where(kx2 != 0, hx**2 * (length - sx) / kx2 / beta**2, zero) - r56 = r56 - length / beta**2 * igamma2 + r = (0.5 * kx * length / torch.pi).sinc() + dx = hx * 0.5 * length.square() * r.square().real + + kx2_is_not_zero = kx2 != 0 + r56 = ( + torch.where(kx2_is_not_zero, hx**2 * (length - sx) / kx2, zero) + * -length + * igamma2 + ) vector_shape = torch.broadcast_shapes( length.shape, k1.shape, hx.shape, tilt.shape, energy.shape @@ -72,13 +78,8 @@ def base_rmatrix( # rotation needs to be applied accross all vector dimensions. The torch.where is # here to improve numerical stability for the vector elements where no rotation # needs to be applied. - if torch.any((tilt != 0) & ((hx != 0) | (k1 != 0))): - rotation = rotation_matrix(tilt) - R = torch.where( - ((tilt != 0) & ((hx != 0) | (k1 != 0))).unsqueeze(-1).unsqueeze(-1), - rotation.transpose(-1, -2) @ R @ rotation, - R, - ) + rotation = rotation_matrix(tilt) + R = rotation.transpose(-1, -2) @ R @ rotation return R @@ -275,27 +276,15 @@ def base_ttensor( - 0.25 / beta * (length + cy * sy) ) - # Rotate the T tensor for skew / vertical magnets. The rotation only has an effect - # if hx != 0, k1 != 0 or k2 != 0. Note that the first if is here to improve speed - # when no rotation needs to be applied accross all vector dimensions. The - # torch.where is here to improve numerical stability for the vector elements where - # no rotation needs to be applied. - if torch.any((tilt != 0) & ((hx != 0) | (k1 != 0) | (k2 != 0))): - rotation = rotation_matrix(tilt) - T = torch.where( - ((tilt != 0) & ((hx != 0) | (k1 != 0) | (k2 != 0))) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1), - torch.einsum( - "...ij,...jkl,...kn,...lm->...inm", - rotation.transpose(-1, -2), - T, - rotation, - rotation, - ), - T, - ) + rotation = rotation_matrix(tilt) + T = torch.einsum( + "...ji,...jkl,...kn,...lm->...inm", + rotation, # Switch index labels in einsum instead of transpose (faster) + T, + rotation, + rotation, + ) + return T @@ -360,5 +349,3 @@ def misalignment_matrix( R_entry[..., 2, 6] = -misalignment[..., 1] return R_entry, R_exit - return R_entry, R_exit - return R_entry, R_exit diff --git a/cheetah/utils/physics.py b/cheetah/utils/physics.py index cdb09a994..ed0f4fc7e 100644 --- a/cheetah/utils/physics.py +++ b/cheetah/utils/physics.py @@ -13,7 +13,7 @@ def compute_relativistic_factors( :return: gamma, igamma2, beta. """ gamma = energy / particle_mass_eV - igamma2 = torch.where(gamma == 0.0, 0.0, 1 / gamma**2) + igamma2 = 1 / gamma**2 beta = torch.sqrt(1 - igamma2) return gamma, igamma2, beta From 44fb11a8537dbece18232718b3dceaf689bead15 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Mon, 8 Sep 2025 14:24:31 +0200 Subject: [PATCH 02/13] Some work towards fixing failing tests --- cheetah/accelerator/solenoid.py | 2 ++ tests/test_quadrupole.py | 23 ++++++++++++----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index 1d1988d65..7fbde4130 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -106,6 +106,8 @@ def first_order_transfer_map( R_entry, R_exit = misalignment_matrix(self.misalignment) R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) + return R + @property def is_active(self) -> bool: return torch.any(self.k != 0).item() diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index 8a30edb7d..c0b2fc5de 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -270,19 +270,20 @@ def test_tilted_quad_transfer_matrix_precision(dtype): k1 = torch.tensor(0.0, dtype=dtype) tilt = torch.tensor(torch.pi / 4, dtype=dtype) - quad = cheetah.Quadrupole(length=length, k1=k1) - skew_quad = cheetah.Quadrupole(length=length, k1=k1, tilt=tilt) - drift = cheetah.Drift(length=length) + quad = cheetah.Quadrupole(length=length, k1=k1, dtype=dtype) + skew_quad = cheetah.Quadrupole(length=length, k1=k1, tilt=tilt, dtype=dtype) + drift = cheetah.Drift(length=length, dtype=dtype) # Compute the transfer matrices energy = torch.tensor(1e9, dtype=dtype) - spiecies = cheetah.Species("electron") + species = cheetah.Species("electron", dtype=dtype) - tm_quad = quad.first_order_transfer_map(energy, spiecies) - tm_skew_quad = skew_quad.first_order_transfer_map(energy, spiecies) - tm_drift = drift.first_order_transfer_map(energy, spiecies) + tm_quad = quad.first_order_transfer_map(energy, species) + tm_skew_quad = skew_quad.first_order_transfer_map(energy, species) + tm_drift = drift.first_order_transfer_map(energy, species) - # Check that the transfer matrices are equal of the dtype - # NOTE: The `==` is used here over `torch.allclose` on purpose - assert (tm_quad == tm_drift).all() - assert (tm_skew_quad == tm_drift).all() + # Check that the transfer matrices are equal to the precision of the dtype + assert torch.allclose(tm_drift, tm_quad) + assert torch.allclose( + tm_drift, tm_skew_quad, atol=1e-8 if dtype == torch.float64 else 1e-7 + ) From ac899c55ddb9fb3f9e38d038f668d7ce249736f8 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 19 Sep 2025 12:51:28 +0200 Subject: [PATCH 03/13] Increase tolerance to pass test --- tests/test_quadrupole.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_quadrupole.py b/tests/test_quadrupole.py index b8ec1119c..b97e6566c 100644 --- a/tests/test_quadrupole.py +++ b/tests/test_quadrupole.py @@ -268,8 +268,7 @@ def test_quadrupole_clone_tracking_method(tracking_method): def test_tilted_quad_transfer_matrix_precision(dtype): """ Test that the transfer matrix for a tilted quadrupole element with k1=0 matches the - transfer matrices of a normal quadrupole and a drift element to the precision of the - used dtype. + transfer matrices of a normal quadrupole and a drift element. """ # Create three elements that should have the same transfer matrix length = torch.tensor(0.5, dtype=dtype) @@ -289,7 +288,5 @@ def test_tilted_quad_transfer_matrix_precision(dtype): tm_drift = drift.first_order_transfer_map(energy, species) # Check that the transfer matrices are equal to the precision of the dtype - assert torch.allclose(tm_drift, tm_quad) - assert torch.allclose( - tm_drift, tm_skew_quad, atol=1e-8 if dtype == torch.float64 else 1e-7 - ) + assert torch.allclose(tm_drift, tm_quad, atol=2e-7) + assert torch.allclose(tm_drift, tm_skew_quad, atol=2e-7) From 95fad9471287439b9937c0e47ed5763846d1d7dc Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 19 Sep 2025 14:07:07 +0200 Subject: [PATCH 04/13] Selectively increase accuracies in Ocelot comparison tests to pass them after removing branches that improved numerical accuracy --- tests/test_compare_ocelot.py | 85 ++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/tests/test_compare_ocelot.py b/tests/test_compare_ocelot.py index 75084bf91..2ac7d2a0c 100644 --- a/tests/test_compare_ocelot.py +++ b/tests/test_compare_ocelot.py @@ -79,11 +79,14 @@ def test_dipole(tracking_method): navigator = ocelot.Navigator(lattice) _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px(), atol=2e-8) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py(), atol=5e-7) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), - atol=1e-6, + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) def test_dipole_with_float64(): @@ -109,10 +112,14 @@ def test_dipole_with_float64(): navigator = ocelot.Navigator(lattice) _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) def test_dipole_with_fringe_field(): @@ -141,10 +148,14 @@ def test_dipole_with_fringe_field(): navigator = ocelot.Navigator(lattice) _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) def test_dipole_with_fringe_field_and_tilt(): @@ -187,10 +198,14 @@ def test_dipole_with_fringe_field_and_tilt(): navigator = ocelot.Navigator(lattice) _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=8e-3 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) def test_aperture(): @@ -345,7 +360,9 @@ def test_ares_ea(): assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) - assert np.allclose(outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau()) + assert np.allclose( + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 + ) assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) @@ -446,10 +463,14 @@ def test_quadrupole(tracking_method): _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) # Split in order to allow for different tolerances for each particle dimension + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) assert np.allclose( outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) @@ -487,10 +508,14 @@ def test_tilted_quadrupole(): navigator = ocelot.Navigator(lattice) _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) assert np.allclose( outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) @@ -528,10 +553,14 @@ def test_sbend(): lattice, deepcopy(incoming_p_array), navigator, print_progress=False ) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) assert np.allclose( outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) @@ -574,10 +603,14 @@ def test_rbend(): lattice, deepcopy(incoming_p_array), navigator, print_progress=False ) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) assert np.allclose( outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) @@ -612,10 +645,14 @@ def test_convert_rbend(): cheetah_segment = cheetah.Segment.from_ocelot(lattice.sequence) outgoing_beam = cheetah_segment.track(incoming_beam) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) assert np.allclose( outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) @@ -649,10 +686,14 @@ def test_asymmetric_bend(): cheetah_segment = cheetah.Segment.from_ocelot(lattice.sequence) outgoing_beam = cheetah_segment.track(incoming_beam) + assert np.allclose(outgoing_beam.x.cpu().numpy(), outgoing_p_array.x()) + assert np.allclose(outgoing_beam.px.cpu().numpy(), outgoing_p_array.px()) + assert np.allclose(outgoing_beam.y.cpu().numpy(), outgoing_p_array.y()) + assert np.allclose(outgoing_beam.py.cpu().numpy(), outgoing_p_array.py()) assert np.allclose( - outgoing_beam.particles[:, :6].cpu().numpy(), - outgoing_p_array.rparticles.transpose(), + outgoing_beam.tau.cpu().numpy(), outgoing_p_array.tau(), atol=2e-5 ) + assert np.allclose(outgoing_beam.p.cpu().numpy(), outgoing_p_array.p()) assert np.allclose( outgoing_beam.particle_charges.cpu().numpy(), outgoing_p_array.q_array ) From ea75d9c115bcc4e6b6f6054f6cfd518a8bc0f26e Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 19 Sep 2025 14:16:55 +0200 Subject: [PATCH 05/13] Remove obsolete comment --- cheetah/track_methods.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index bd1eaeae2..a0ebb5d79 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -72,11 +72,6 @@ def base_rmatrix( R[..., 4, 1] = dx / beta R[..., 4, 5] = r56 - # Rotate the R matrix for skew / vertical magnets. The rotation only has an effect - # if hx != 0 or k1 != 0. Note that the first if is here to improve speed when no - # rotation needs to be applied accross all vector dimensions. The torch.where is - # here to improve numerical stability for the vector elements where no rotation - # needs to be applied. rotation = rotation_matrix(tilt) R = rotation.transpose(-1, -2) @ R @ rotation From 7c38b80b3388bc631d64f49dd685b6eb592f33ad Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 19 Sep 2025 14:24:46 +0200 Subject: [PATCH 06/13] Remove branch to `base_rmatrix` in `Cavity` --- cheetah/accelerator/cavity.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index 9e8af1357..f30b1dc7f 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -8,7 +8,6 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species -from cheetah.track_methods import base_rmatrix from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors from cheetah.utils.autograd import log1pdiv @@ -76,18 +75,7 @@ def is_skippable(self) -> bool: def _compute_first_order_transfer_map( self, energy: torch.Tensor, species: Species ) -> torch.Tensor: - return torch.where( - (self.voltage != 0).unsqueeze(-1).unsqueeze(-1), - self._cavity_rmatrix(energy, species), - base_rmatrix( - length=self.length, - k1=torch.zeros_like(self.length), - hx=torch.zeros_like(self.length), - species=species, - tilt=torch.zeros_like(self.length), - energy=energy, - ), - ) + return self._cavity_rmatrix(energy, species) def track(self, incoming: Beam) -> Beam: gamma0, igamma2, beta0 = compute_relativistic_factors( From ca4a862e7053ccee0d0be6e4c13db56e2bea71c1 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 19 Sep 2025 18:14:55 +0200 Subject: [PATCH 07/13] Remove branch for non-physical case `gamma == 0.0` --- cheetah/accelerator/solenoid.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index f2da0b3bc..aafb56787 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -71,9 +71,7 @@ def _compute_first_order_transfer_map( self.length.shape, self.k.shape, energy.shape ) - r56 = torch.where( - gamma != 0, self.length / (1 - gamma**2), torch.zeros_like(self.length) - ) + r56 = self.length / (1 - gamma**2) R = torch.eye(7, **factory_kwargs).repeat((*vector_shape, 1, 1)) R[..., 0, 0] = c**2 From 636845fef3a377b57d22468766dc0dbdfe7a7481 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Tue, 23 Sep 2025 09:11:59 +0200 Subject: [PATCH 08/13] Remove one singularity of cavity --- cheetah/accelerator/cavity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index f30b1dc7f..ed40bc2ad 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -298,7 +298,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso elif self.cavity_type == "traveling_wave": # Reference paper: Rosenzweig and Serafini, PhysRevE, Vol.49, p.1599,(1994) - f = Ei / dE * torch.log(1 + (dE / Ei)) + f = log1pdiv(dE / Ei) vector_shape = torch.broadcast_shapes( self.length.shape, f.shape, Ei.shape, Ef.shape From c026a77a96e469e55f688e8fea745cd035d09ba1 Mon Sep 17 00:00:00 2001 From: Christian Hespe Date: Tue, 23 Sep 2025 09:33:04 +0200 Subject: [PATCH 09/13] Remove standing wave singularity --- cheetah/accelerator/cavity.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index ed40bc2ad..23942a17d 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -249,25 +249,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso beta1 = torch.sqrt(1 - 1 / Ef**2) r11 = torch.cos(alpha) - math.sqrt(2.0) * torch.cos(phi) * torch.sin(alpha) - - # In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length - # otherwise. This is implemented differently here to achieve results - # closer to Bmad. r12 = ( - math.sqrt(8.0) - * energy - / effective_voltage - * torch.sin(alpha) + torch.sinc(alpha / torch.pi) + * log1pdiv(delta_energy / energy) * self.length ) - r21 = -( effective_voltage / ((energy + delta_energy) * math.sqrt(2.0) * self.length) * (0.5 + torch.cos(phi) ** 2) * torch.sin(alpha) ) - r22 = ( Ei / Ef From 2cea65d34893b61201affe3b9ae9c5c5c7b5234e Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 24 Sep 2025 14:34:01 +0200 Subject: [PATCH 10/13] Blindly remove all `where`s --- cheetah/accelerator/dipole.py | 5 +--- cheetah/accelerator/solenoid.py | 2 +- cheetah/track_methods.py | 42 +++++++++++---------------------- 3 files changed, 16 insertions(+), 33 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index e99aa78c9..3883d0829 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -293,9 +293,6 @@ def _bmadx_body( phi1 = torch.arcsin(px / px_norm) g = self.angle / self.length gp = g.unsqueeze(-1) / px_norm - gp_safe = torch.where( - gp != 0, gp, torch.tensor(1e-12, dtype=gp.dtype, device=gp.device) - ) alpha = ( 2 @@ -322,7 +319,7 @@ def _bmadx_body( x2_t3 = torch.cos(self.angle.unsqueeze(-1) + phi1) c1 = x2_t1 + alpha / (x2_t2 + x2_t3) - c2 = x2_t1 + (x2_t2 - x2_t3) / gp_safe + c2 = x2_t1 + (x2_t2 - x2_t3) / gp temp = torch.abs(self.angle.unsqueeze(-1) + phi1) x2 = c1 * (temp < torch.pi / 2) + c2 * (temp >= torch.pi / 2) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index aafb56787..420e8a3ab 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -65,7 +65,7 @@ def _compute_first_order_transfer_map( c = torch.cos(self.length * self.k) s = torch.sin(self.length * self.k) - s_k = torch.where(self.k == 0.0, self.length, s / self.k) + s_k = s / self.k vector_shape = torch.broadcast_shapes( self.length.shape, self.k.shape, energy.shape diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index a0ebb5d79..08da66a9f 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -46,12 +46,7 @@ def base_rmatrix( r = (0.5 * kx * length / torch.pi).sinc() dx = hx * 0.5 * length.square() * r.square().real - kx2_is_not_zero = kx2 != 0 - r56 = ( - torch.where(kx2_is_not_zero, hx**2 * (length - sx) / kx2, zero) - * -length - * igamma2 - ) + r56 = hx**2 * (length - sx) / kx2 * -length * igamma2 vector_shape = torch.broadcast_shapes( length.shape, k1.shape, hx.shape, tilt.shape, energy.shape @@ -117,33 +112,24 @@ def base_ttensor( cy = torch.cos(ky * length).real sx = (torch.sinc(kx * length / torch.pi) * length).real sy = (torch.sinc(ky * length / torch.pi) * length).real - dx = torch.where(kx != 0, (1.0 - cx) / kx2, length**2 / 2.0) + dx = (1.0 - cx) / kx2 d2y = 0.5 * sy**2 s2y = sy * cy c2y = torch.cos(2 * ky * length).real - fx = torch.where(kx2 != 0, (length - sx) / kx2, length**3 / 6.0) - f2y = torch.where(ky2 != 0, (length - s2y) / ky2, length**3 / 6.0) - - j1 = torch.where(kx2 != 0, (length - sx) / kx2, length**3 / 6.0) - j2 = torch.where( - kx2 != 0, - (3.0 * length - 4.0 * sx + sx * cx) / (2 * kx2**2), - length**5 / 20.0, - ) - j3 = torch.where( - kx2 != 0, - (15.0 * length - 22.5 * sx + 9.0 * sx * cx - 1.5 * sx * cx**2 + kx2 * sx**3) - / (6.0 * kx2**3), - length**7 / 56.0, - ) + fx = (length - sx) / kx2 + f2y = (length - s2y) / ky2 + + j1 = (length - sx) / kx2 + j2 = (3.0 * length - 4.0 * sx + sx * cx) / (2 * kx2**2) + j3 = ( + 15.0 * length - 22.5 * sx + 9.0 * sx * cx - 1.5 * sx * cx**2 + kx2 * sx**3 + ) / (6.0 * kx2**3) j_denominator = kx2 - 4 * ky2 - jc = torch.where(j_denominator != 0, (c2y - cx) / j_denominator, 0.5 * length**2) - js = torch.where( - j_denominator != 0, (cy * sy - sx) / j_denominator, length**3 / 6.0 - ) - jd = torch.where(j_denominator != 0, (d2y - dx) / j_denominator, length**4 / 24.0) - jf = torch.where(j_denominator != 0, (f2y - fx) / j_denominator, length**5 / 120.0) + jc = (c2y - cx) / j_denominator + js = (cy * sy - sx) / j_denominator + jd = (d2y - dx) / j_denominator + jf = (f2y - fx) / j_denominator khk = k2 + 2 * hx * k1 From 29c8bc0048c79cdc627da95dc249ef7f53628dc6 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 26 Sep 2025 15:35:49 +0200 Subject: [PATCH 11/13] Format fix --- cheetah/accelerator/cavity.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index d7dc0ce40..bca26c909 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -8,7 +8,11 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species -from cheetah.utils import UniqueNameGenerator, cache_transfer_map, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + cache_transfer_map, + compute_relativistic_factors, +) from cheetah.utils.autograd import log1pdiv generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") From 8de0b598f57b495d4116fdec17da69255f4269d6 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 26 Sep 2025 16:26:42 +0200 Subject: [PATCH 12/13] Somewhat Ocelot-like fix for singularity in `Cavity` `r55` --- cheetah/accelerator/cavity.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index bca26c909..fce301884 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -276,15 +276,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso ) ) - r55 = 1.0 + ( - k - * self.length - * beta0 - * effective_voltage - / species.mass_eV - * torch.sin(phi) - * (Ei * Ef * (beta0 * beta1 - 1) + 1) - / (beta1 * Ef * (Ei - Ef) ** 2) + r55 = 1.0 + torch.where( + dE != 0.0, + ( + k + * self.length + * beta0 + * phi.tan() + * (Ei * Ef * (beta0 * beta1 - 1) + 1) + / (beta1 * Ef * dE) + ), + torch.zeros_like(dE), ) r56 = -self.length / (Ef**2 * Ei * beta1) * (Ef + Ei) / (beta1 + beta0) r65 = ( From e01ddb1eaee3a67d742d1286dcbe321e77e0d92c Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Fri, 26 Sep 2025 16:36:08 +0200 Subject: [PATCH 13/13] Revert "Blindly remove all `where`s" This reverts commit 2cea65d34893b61201affe3b9ae9c5c5c7b5234e. --- cheetah/accelerator/dipole.py | 5 +++- cheetah/accelerator/solenoid.py | 2 +- cheetah/track_methods.py | 42 ++++++++++++++++++++++----------- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 77051a4ab..51d9b43cc 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -293,6 +293,9 @@ def _bmadx_body( phi1 = torch.arcsin(px / px_norm) g = self.angle / self.length gp = g.unsqueeze(-1) / px_norm + gp_safe = torch.where( + gp != 0, gp, torch.tensor(1e-12, dtype=gp.dtype, device=gp.device) + ) alpha = ( 2 @@ -319,7 +322,7 @@ def _bmadx_body( x2_t3 = torch.cos(self.angle.unsqueeze(-1) + phi1) c1 = x2_t1 + alpha / (x2_t2 + x2_t3) - c2 = x2_t1 + (x2_t2 - x2_t3) / gp + c2 = x2_t1 + (x2_t2 - x2_t3) / gp_safe temp = torch.abs(self.angle.unsqueeze(-1) + phi1) x2 = c1 * (temp < torch.pi / 2) + c2 * (temp >= torch.pi / 2) diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index a4a949a95..51fc5ff60 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -70,7 +70,7 @@ def first_order_transfer_map( c = torch.cos(self.length * self.k) s = torch.sin(self.length * self.k) - s_k = s / self.k + s_k = torch.where(self.k == 0.0, self.length, s / self.k) vector_shape = torch.broadcast_shapes( self.length.shape, self.k.shape, energy.shape diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 2eea99d6a..92c1f11b8 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -46,7 +46,12 @@ def base_rmatrix( r = (0.5 * kx * length / torch.pi).sinc() dx = hx * 0.5 * length.square() * r.square().real - r56 = hx**2 * (length - sx) / kx2 * -length * igamma2 + kx2_is_not_zero = kx2 != 0 + r56 = ( + torch.where(kx2_is_not_zero, hx**2 * (length - sx) / kx2, zero) + * -length + * igamma2 + ) vector_shape = torch.broadcast_shapes( length.shape, k1.shape, hx.shape, tilt.shape, energy.shape @@ -112,24 +117,33 @@ def base_ttensor( cy = torch.cos(ky * length).real sx = (torch.sinc(kx * length / torch.pi) * length).real sy = (torch.sinc(ky * length / torch.pi) * length).real - dx = (1.0 - cx) / kx2 + dx = torch.where(kx != 0, (1.0 - cx) / kx2, length**2 / 2.0) d2y = 0.5 * sy**2 s2y = sy * cy c2y = torch.cos(2 * ky * length).real - fx = (length - sx) / kx2 - f2y = (length - s2y) / ky2 - - j1 = (length - sx) / kx2 - j2 = (3.0 * length - 4.0 * sx + sx * cx) / (2 * kx2**2) - j3 = ( - 15.0 * length - 22.5 * sx + 9.0 * sx * cx - 1.5 * sx * cx**2 + kx2 * sx**3 - ) / (6.0 * kx2**3) + fx = torch.where(kx2 != 0, (length - sx) / kx2, length**3 / 6.0) + f2y = torch.where(ky2 != 0, (length - s2y) / ky2, length**3 / 6.0) + + j1 = torch.where(kx2 != 0, (length - sx) / kx2, length**3 / 6.0) + j2 = torch.where( + kx2 != 0, + (3.0 * length - 4.0 * sx + sx * cx) / (2 * kx2**2), + length**5 / 20.0, + ) + j3 = torch.where( + kx2 != 0, + (15.0 * length - 22.5 * sx + 9.0 * sx * cx - 1.5 * sx * cx**2 + kx2 * sx**3) + / (6.0 * kx2**3), + length**7 / 56.0, + ) j_denominator = kx2 - 4 * ky2 - jc = (c2y - cx) / j_denominator - js = (cy * sy - sx) / j_denominator - jd = (d2y - dx) / j_denominator - jf = (f2y - fx) / j_denominator + jc = torch.where(j_denominator != 0, (c2y - cx) / j_denominator, 0.5 * length**2) + js = torch.where( + j_denominator != 0, (cy * sy - sx) / j_denominator, length**3 / 6.0 + ) + jd = torch.where(j_denominator != 0, (d2y - dx) / j_denominator, length**4 / 24.0) + jf = torch.where(j_denominator != 0, (f2y - fx) / j_denominator, length**5 / 120.0) khk = k2 + 2 * hx * k1