Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 15 additions & 33 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species
from cheetah.track_methods import base_rmatrix
from cheetah.utils import (
UniqueNameGenerator,
cache_transfer_map,
Expand Down Expand Up @@ -81,18 +80,7 @@ def is_skippable(self) -> bool:
def first_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
return torch.where(
(self.voltage != 0).unsqueeze(-1).unsqueeze(-1),
self._cavity_rmatrix(energy, species),
base_rmatrix(
length=self.length,
k1=torch.zeros_like(self.length),
hx=torch.zeros_like(self.length),
species=species,
tilt=torch.zeros_like(self.length),
energy=energy,
),
)
return self._cavity_rmatrix(energy, species)

def track(self, incoming: Beam) -> Beam:
gamma0, igamma2, beta0 = compute_relativistic_factors(
Expand Down Expand Up @@ -268,25 +256,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso
beta1 = torch.sqrt(1 - 1 / Ef**2)

r11 = torch.cos(alpha) - math.sqrt(2.0) * torch.cos(phi) * torch.sin(alpha)

# In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length
# otherwise. This is implemented differently here to achieve results
# closer to Bmad.
r12 = (
math.sqrt(8.0)
* energy
/ effective_voltage
* torch.sin(alpha)
torch.sinc(alpha / torch.pi)
* log1pdiv(delta_energy / energy)
* self.length
)

r21 = -(
effective_voltage
/ ((energy + delta_energy) * math.sqrt(2.0) * self.length)
* (0.5 + torch.cos(phi) ** 2)
* torch.sin(alpha)
)

r22 = (
Ei
/ Ef
Expand All @@ -296,15 +276,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso
)
)

r55 = 1.0 + (
k
* self.length
* beta0
* effective_voltage
/ species.mass_eV
* torch.sin(phi)
* (Ei * Ef * (beta0 * beta1 - 1) + 1)
/ (beta1 * Ef * (Ei - Ef) ** 2)
r55 = 1.0 + torch.where(
dE != 0.0,
(
k
* self.length
* beta0
* phi.tan()
* (Ei * Ef * (beta0 * beta1 - 1) + 1)
/ (beta1 * Ef * dE)
),
torch.zeros_like(dE),
)
r56 = -self.length / (Ef**2 * Ei * beta1) * (Ef + Ei) / (beta1 + beta0)
r65 = (
Expand All @@ -317,7 +299,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso

elif self.cavity_type == "traveling_wave":
# Reference paper: Rosenzweig and Serafini, PhysRevE, Vol.49, p.1599,(1994)
f = Ei / dE * torch.log(1 + (dE / Ei))
f = log1pdiv(dE / Ei)

vector_shape = torch.broadcast_shapes(
self.length.shape, f.shape, Ei.shape, Ef.shape
Expand Down
77 changes: 26 additions & 51 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(

@property
def hx(self) -> torch.Tensor:
return torch.where(self.length == 0.0, 0.0, self.angle / self.length)
return self.angle / self.length # Zero length not caught because not physical

@property
def dipole_e1(self) -> torch.Tensor:
Expand Down Expand Up @@ -402,82 +402,57 @@ def _bmadx_fringe_linear(
def first_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

R_enter = self._transfer_map_enter()
R_exit = self._transfer_map_exit()

if torch.any(self.length != 0.0): # Bending magnet with finite length
R = base_rmatrix(
length=self.length,
k1=self.k1,
hx=self.hx,
species=species,
energy=energy,
) # Tilt is applied after adding edges
else: # Reduce to Thin-Corrector
R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
R[..., 0, 1] = self.length
R[..., 2, 6] = self.angle
R[..., 2, 3] = self.length
R = base_rmatrix(
length=self.length,
k1=self.k1,
hx=self.hx,
species=species,
energy=energy,
) # Tilt is applied after adding edges

# Apply fringe fields
R = R_exit @ R @ R_enter

# Apply rotation for tilted magnets
if torch.any(self.tilt != 0):
rotation = rotation_matrix(self.tilt)
R = rotation.mT @ R @ rotation
rotation = rotation_matrix(self.tilt)
R = rotation.mT @ R @ rotation
Comment on lines +420 to +421
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be worth it to build an autograd function for this sort of function that only computes the tilt optionally in a forward pass, but always computes the gradient as if it had been in there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be, but that also depends on how expensive the multiplication is. The overhead of the custom autograd functions seems to be non-negligable from the Cavity example we implemented.


return R

@cache_transfer_map
def second_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

R_enter = self._transfer_map_enter()
R_exit = self._transfer_map_exit()

if torch.any(self.length != 0.0): # Bending magnet with finite length
T = base_ttensor(
length=self.length,
k1=self.k1,
k2=torch.tensor(0.0, **factory_kwargs),
hx=self.hx,
species=species,
energy=energy,
)

# Fill the first-order transfer map into the second-order transfer map
T[..., :, 6, :] = base_rmatrix(
length=self.length,
k1=self.k1,
hx=self.hx,
species=species,
energy=energy,
)
else: # Reduce to Thin-Corrector
R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
R[..., 0, 1] = self.length
R[..., 2, 6] = self.angle
R[..., 2, 3] = self.length
T = base_ttensor(
length=self.length,
k1=self.k1,
k2=self.length.new_zeros(()),
hx=self.hx,
species=species,
energy=energy,
)

T = torch.zeros((*self.length.shape, 7, 7), **factory_kwargs)
T[..., :, 6, :] = R
# Fill the first-order transfer map into the second-order transfer map
T[..., :, 6, :] = base_rmatrix(
length=self.length, k1=self.k1, hx=self.hx, species=species, energy=energy
)

# Apply fringe fields
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_enter, R_enter
)

# Apply rotation for tilted magnets
if torch.any(self.tilt != 0):
rotation = rotation_matrix(self.tilt)
T = torch.einsum(
"...ji,...jkl,...kn,...lm->...inm", rotation, T, rotation, rotation
)
rotation = rotation_matrix(self.tilt)
T = torch.einsum(
"...ji,...jkl,...kn,...lm->...inm", rotation, T, rotation, rotation
)

return T

Expand Down
14 changes: 6 additions & 8 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ def first_order_transfer_map(
energy=energy,
)

if torch.any(self.misalignment != 0):
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = R_exit @ R @ R_entry

return R

Expand Down Expand Up @@ -126,11 +125,10 @@ def second_order_transfer_map(
)

# Apply misalignments to the entire second-order transfer map
if not torch.all(self.misalignment == 0):
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)

return T

Expand Down
9 changes: 4 additions & 5 deletions cheetah/accelerator/sextupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,10 @@ def second_order_transfer_map(self, energy, species):
)

# Apply misalignments to the entire second-order transfer map
if not torch.all(self.misalignment == 0):
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)

return T

Expand Down
14 changes: 5 additions & 9 deletions cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def first_order_transfer_map(
self.length.shape, self.k.shape, energy.shape
)

r56 = torch.where(
gamma != 0, self.length / (1 - gamma**2), torch.zeros_like(self.length)
)
r56 = self.length / (1 - gamma**2)

R = torch.eye(7, **factory_kwargs).repeat((*vector_shape, 1, 1))
R[..., 0, 0] = c**2
Expand All @@ -101,12 +99,10 @@ def first_order_transfer_map(

R = R.real

if torch.all(self.misalignment == 0):
return R
else:
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
return R
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)

return R

@property
def is_active(self) -> bool:
Expand Down
12 changes: 6 additions & 6 deletions cheetah/particles/parameter_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,47 +553,47 @@ def mu_x(self) -> torch.Tensor:

@property
def sigma_x(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 0, 0], 1e-20))
return torch.sqrt(self.cov[..., 0, 0])

@property
def mu_px(self) -> torch.Tensor:
return self.mu[..., 1]

@property
def sigma_px(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 1, 1], 1e-20))
return torch.sqrt(self.cov[..., 1, 1])

@property
def mu_y(self) -> torch.Tensor:
return self.mu[..., 2]

@property
def sigma_y(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 2, 2], 1e-20))
return torch.sqrt(self.cov[..., 2, 2])

@property
def mu_py(self) -> torch.Tensor:
return self.mu[..., 3]

@property
def sigma_py(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 3, 3], 1e-20))
return torch.sqrt(self.cov[..., 3, 3])

@property
def mu_tau(self) -> torch.Tensor:
return self.mu[..., 4]

@property
def sigma_tau(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 4, 4], 1e-20))
return torch.sqrt(self.cov[..., 4, 4])

@property
def mu_p(self) -> torch.Tensor:
return self.mu[..., 5]

@property
def sigma_p(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 5, 5], 1e-20))
return torch.sqrt(self.cov[..., 5, 5])

@property
def cov_xpx(self) -> torch.Tensor:
Expand Down
Loading
Loading