Skip to content
Draft
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 4 additions & 24 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils.autograd import log1pdiv

Expand Down Expand Up @@ -76,18 +75,7 @@ def is_skippable(self) -> bool:
def _compute_first_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
return torch.where(
(self.voltage != 0).unsqueeze(-1).unsqueeze(-1),
self._cavity_rmatrix(energy, species),
base_rmatrix(
length=self.length,
k1=torch.zeros_like(self.length),
hx=torch.zeros_like(self.length),
species=species,
tilt=torch.zeros_like(self.length),
energy=energy,
),
)
return self._cavity_rmatrix(energy, species)

def track(self, incoming: Beam) -> Beam:
gamma0, igamma2, beta0 = compute_relativistic_factors(
Expand Down Expand Up @@ -261,25 +249,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso
beta1 = torch.sqrt(1 - 1 / Ef**2)

r11 = torch.cos(alpha) - math.sqrt(2.0) * torch.cos(phi) * torch.sin(alpha)

# In Ocelot r12 is defined as below only if abs(Ep) > 10, and self.length
# otherwise. This is implemented differently here to achieve results
# closer to Bmad.
r12 = (
math.sqrt(8.0)
* energy
/ effective_voltage
* torch.sin(alpha)
torch.sinc(alpha / torch.pi)
* log1pdiv(delta_energy / energy)
* self.length
)

r21 = -(
effective_voltage
/ ((energy + delta_energy) * math.sqrt(2.0) * self.length)
* (0.5 + torch.cos(phi) ** 2)
* torch.sin(alpha)
)

r22 = (
Ei
/ Ef
Expand Down Expand Up @@ -310,7 +290,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor, species: Species) -> torch.Tenso

elif self.cavity_type == "traveling_wave":
# Reference paper: Rosenzweig and Serafini, PhysRevE, Vol.49, p.1599,(1994)
f = Ei / dE * torch.log(1 + (dE / Ei))
f = log1pdiv(dE / Ei)

vector_shape = torch.broadcast_shapes(
self.length.shape, f.shape, Ei.shape, Ef.shape
Expand Down
85 changes: 30 additions & 55 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(

@property
def hx(self) -> torch.Tensor:
return torch.where(self.length == 0.0, 0.0, self.angle / self.length)
return self.angle / self.length # Zero length not caught because not physical

@property
def dipole_e1(self) -> torch.Tensor:
Expand Down Expand Up @@ -401,85 +401,60 @@ def _bmadx_fringe_linear(
def _compute_first_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

R_enter = self._transfer_map_enter()
R_exit = self._transfer_map_exit()

if torch.any(self.length != 0.0): # Bending magnet with finite length
R = base_rmatrix(
length=self.length,
k1=self.k1,
hx=self.hx,
species=species,
energy=energy,
) # Tilt is applied after adding edges
else: # Reduce to Thin-Corrector
R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
R[..., 0, 1] = self.length
R[..., 2, 6] = self.angle
R[..., 2, 3] = self.length
R = base_rmatrix(
length=self.length,
k1=self.k1,
hx=self.hx,
species=species,
energy=energy,
) # Tilt is applied after adding edges

# Apply fringe fields
R = R_exit @ R @ R_enter

# Apply rotation for tilted magnets
if torch.any(self.tilt != 0):
rotation = rotation_matrix(self.tilt)
R = rotation.transpose(-1, -2) @ R @ rotation
rotation = rotation_matrix(self.tilt)
R = rotation.transpose(-1, -2) @ R @ rotation

return R

def _compute_second_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

R_enter = self._transfer_map_enter()
R_exit = self._transfer_map_exit()

if torch.any(self.length != 0.0): # Bending magnet with finite length
T = base_ttensor(
length=self.length,
k1=self.k1,
k2=torch.tensor(0.0, **factory_kwargs),
hx=self.hx,
species=species,
energy=energy,
)

# Fill the first-order transfer map into the second-order transfer map
T[..., :, 6, :] = base_rmatrix(
length=self.length,
k1=self.k1,
hx=self.hx,
species=species,
energy=energy,
)
else: # Reduce to Thin-Corrector
R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
R[..., 0, 1] = self.length
R[..., 2, 6] = self.angle
R[..., 2, 3] = self.length
T = base_ttensor(
length=self.length,
k1=self.k1,
k2=self.length.new_zeros(()),
hx=self.hx,
species=species,
energy=energy,
)

T = torch.zeros((*self.length.shape, 7, 7), **factory_kwargs)
T[..., :, 6, :] = R
# Fill the first-order transfer map into the second-order transfer map
T[..., :, 6, :] = base_rmatrix(
length=self.length, k1=self.k1, hx=self.hx, species=species, energy=energy
)

# Apply fringe fields
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_enter, R_enter
)

# Apply rotation for tilted magnets
if torch.any(self.tilt != 0):
rotation = rotation_matrix(self.tilt)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm",
rotation.transpose(-1, -2),
T,
rotation,
rotation,
)
rotation = rotation_matrix(self.tilt)
T = torch.einsum(
"...ji,...jkl,...kn,...lm->...inm",
rotation, # Switch index labels in einsum instead of transpose (faster)
T,
rotation,
rotation,
)

return T

Expand Down
14 changes: 6 additions & 8 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ def _compute_first_order_transfer_map(
energy=energy,
)

if torch.any(self.misalignment != 0):
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = R_exit @ R @ R_entry

return R

Expand Down Expand Up @@ -119,11 +118,10 @@ def _compute_second_order_transfer_map(
)

# Apply misalignments to the entire second-order transfer map
if not torch.all(self.misalignment == 0):
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)

return T

Expand Down
9 changes: 4 additions & 5 deletions cheetah/accelerator/sextupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,10 @@ def _compute_second_order_transfer_map(self, energy, species):
)

# Apply misalignments to the entire second-order transfer map
if not torch.all(self.misalignment == 0):
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)
R_entry, R_exit = misalignment_matrix(self.misalignment)
T = torch.einsum(
"...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_entry, R_entry
)

return T

Expand Down
14 changes: 5 additions & 9 deletions cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def _compute_first_order_transfer_map(
self.length.shape, self.k.shape, energy.shape
)

r56 = torch.where(
gamma != 0, self.length / (1 - gamma**2), torch.zeros_like(self.length)
)
r56 = self.length / (1 - gamma**2)

R = torch.eye(7, **factory_kwargs).repeat((*vector_shape, 1, 1))
R[..., 0, 0] = c**2
Expand All @@ -96,12 +94,10 @@ def _compute_first_order_transfer_map(

R = R.real

if torch.all(self.misalignment == 0):
return R
else:
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
return R
R_entry, R_exit = misalignment_matrix(self.misalignment)
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)

return R

@property
def is_active(self) -> bool:
Expand Down
12 changes: 6 additions & 6 deletions cheetah/particles/parameter_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,47 +553,47 @@ def mu_x(self) -> torch.Tensor:

@property
def sigma_x(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 0, 0], 1e-20))
return torch.sqrt(self.cov[..., 0, 0])

@property
def mu_px(self) -> torch.Tensor:
return self.mu[..., 1]

@property
def sigma_px(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 1, 1], 1e-20))
return torch.sqrt(self.cov[..., 1, 1])

@property
def mu_y(self) -> torch.Tensor:
return self.mu[..., 2]

@property
def sigma_y(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 2, 2], 1e-20))
return torch.sqrt(self.cov[..., 2, 2])

@property
def mu_py(self) -> torch.Tensor:
return self.mu[..., 3]

@property
def sigma_py(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 3, 3], 1e-20))
return torch.sqrt(self.cov[..., 3, 3])

@property
def mu_tau(self) -> torch.Tensor:
return self.mu[..., 4]

@property
def sigma_tau(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 4, 4], 1e-20))
return torch.sqrt(self.cov[..., 4, 4])

@property
def mu_p(self) -> torch.Tensor:
return self.mu[..., 5]

@property
def sigma_p(self) -> torch.Tensor:
return torch.sqrt(torch.clamp_min(self.cov[..., 5, 5], 1e-20))
return torch.sqrt(self.cov[..., 5, 5])

@property
def cov_xpx(self) -> torch.Tensor:
Expand Down
56 changes: 20 additions & 36 deletions cheetah/track_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,16 @@ def base_rmatrix(
cy = torch.cos(ky * length).real
sx = (torch.sinc(kx * length / torch.pi) * length).real
sy = (torch.sinc(ky * length / torch.pi) * length).real
dx = torch.where(kx2 != 0, hx / kx2 * (1.0 - cx), zero)
r56 = torch.where(kx2 != 0, hx**2 * (length - sx) / kx2 / beta**2, zero)

r56 = r56 - length / beta**2 * igamma2
r = (0.5 * kx * length / torch.pi).sinc()
dx = hx * 0.5 * length.square() * r.square().real

kx2_is_not_zero = kx2 != 0
r56 = (
torch.where(kx2_is_not_zero, hx**2 * (length - sx) / kx2, zero)
* -length
* igamma2
)

vector_shape = torch.broadcast_shapes(
length.shape, k1.shape, hx.shape, tilt.shape, energy.shape
Expand All @@ -66,18 +72,8 @@ def base_rmatrix(
R[..., 4, 1] = dx / beta
R[..., 4, 5] = r56

# Rotate the R matrix for skew / vertical magnets. The rotation only has an effect
# if hx != 0 or k1 != 0. Note that the first if is here to improve speed when no
# rotation needs to be applied accross all vector dimensions. The torch.where is
# here to improve numerical stability for the vector elements where no rotation
# needs to be applied.
if torch.any((tilt != 0) & ((hx != 0) | (k1 != 0))):
rotation = rotation_matrix(tilt)
R = torch.where(
((tilt != 0) & ((hx != 0) | (k1 != 0))).unsqueeze(-1).unsqueeze(-1),
rotation.transpose(-1, -2) @ R @ rotation,
R,
)
rotation = rotation_matrix(tilt)
R = rotation.transpose(-1, -2) @ R @ rotation

return R

Expand Down Expand Up @@ -271,27 +267,15 @@ def base_ttensor(
- 0.25 / beta * (length + cy * sy)
)

# Rotate the T tensor for skew / vertical magnets. The rotation only has an effect
# if hx != 0, k1 != 0 or k2 != 0. Note that the first if is here to improve speed
# when no rotation needs to be applied accross all vector dimensions. The
# torch.where is here to improve numerical stability for the vector elements where
# no rotation needs to be applied.
if torch.any((tilt != 0) & ((hx != 0) | (k1 != 0) | (k2 != 0))):
rotation = rotation_matrix(tilt)
T = torch.where(
((tilt != 0) & ((hx != 0) | (k1 != 0) | (k2 != 0)))
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1),
torch.einsum(
"...ij,...jkl,...kn,...lm->...inm",
rotation.transpose(-1, -2),
T,
rotation,
rotation,
),
T,
)
rotation = rotation_matrix(tilt)
T = torch.einsum(
"...ji,...jkl,...kn,...lm->...inm",
rotation, # Switch index labels in einsum instead of transpose (faster)
T,
rotation,
rotation,
)

return T


Expand Down
Loading
Loading