diff --git a/CHANGELOG.md b/CHANGELOG.md index e5a7f0341..3830aa41d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ### 🚀 Features - Implement `split` method for the `Solenoid` element (see #380) (@cr-xu) +- Add a `Sextupole` element (see #406) (@jank324, @Hespe) ### 🐛 Bug fixes diff --git a/cheetah/__init__.py b/cheetah/__init__.py index 5a37aa789..86ddc7eb7 100644 --- a/cheetah/__init__.py +++ b/cheetah/__init__.py @@ -9,10 +9,12 @@ Element, HorizontalCorrector, Marker, + Octupole, Quadrupole, RBend, Screen, Segment, + Sextupole, Solenoid, SpaceChargeKick, TransverseDeflectingCavity, diff --git a/cheetah/accelerator/__init__.py b/cheetah/accelerator/__init__.py index bc783dcee..5b8d6b5b9 100644 --- a/cheetah/accelerator/__init__.py +++ b/cheetah/accelerator/__init__.py @@ -7,10 +7,12 @@ from .element import Element # noqa: F401 from .horizontal_corrector import HorizontalCorrector # noqa: F401 from .marker import Marker # noqa: F401 +from .octupole import Octupole # noqa: F401 from .quadrupole import Quadrupole # noqa: F401 from .rbend import RBend # noqa: F401 from .screen import Screen # noqa: F401 from .segment import Segment # noqa: F401 +from .sextupole import Sextupole # noqa: F401 from .solenoid import Solenoid # noqa: F401 from .space_charge_kick import SpaceChargeKick # noqa: F401 from .transverse_deflecting_cavity import TransverseDeflectingCavity # noqa: F401 diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index cf72e1dc3..5bf701402 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -5,7 +5,12 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam, Species -from cheetah.utils import UniqueNameGenerator, bmadx, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + bmadx, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -30,6 +35,7 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> None: + device, dtype = verify_device_and_dtype([length], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name, **factory_kwargs) diff --git a/cheetah/accelerator/octupole.py b/cheetah/accelerator/octupole.py new file mode 100644 index 000000000..e760d6bf6 --- /dev/null +++ b/cheetah/accelerator/octupole.py @@ -0,0 +1,137 @@ +import matplotlib.pyplot as plt +import torch + +from cheetah.accelerator.element import Element +from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species +from cheetah.track_methods import base_rmatrix, base_tmatrix, misalignment_matrix +from cheetah.utils import verify_device_and_dtype + + +class Octupole(Element): + """ + An octupole element in a particle accelerator. + + :param length: Length in meters. + :param k3: TODO + :param misalignment: TODO + :param tilt: TODO + :param name: Unique identifier of the element. + """ + + def __init__( + self, + length: torch.Tensor, + k3: torch.Tensor | None = None, + misalignment: torch.Tensor | None = None, + tilt: torch.Tensor | None = None, + name: str | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + device, dtype = verify_device_and_dtype([length, k3], device, dtype) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name, **factory_kwargs) + + self.length = torch.as_tensor(length, **factory_kwargs) + + self.register_buffer_or_parameter( + "k3", torch.as_tensor(k3 if k3 is not None else 0.0, **factory_kwargs) + ) + self.register_buffer_or_parameter( + "misalignment", + torch.as_tensor( + misalignment if misalignment is not None else (0.0, 0.0), + **factory_kwargs, + ), + ) + self.register_buffer_or_parameter( + "tilt", torch.as_tensor(tilt if tilt is not None else 0.0, **factory_kwargs) + ) + + def transfer_map(self, energy: torch.Tensor, species: Species) -> torch.Tensor: + R = base_rmatrix( + length=self.length, + k1=torch.zeros_like(self.length), + hx=torch.zeros_like(self.length), + species=species, + tilt=self.tilt, + energy=energy, + ) + + if torch.all(self.misalignment == 0): + return R + else: + R_entry, R_exit = misalignment_matrix(self.misalignment) + R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) + return R + + def track(self, incoming: Beam) -> Beam: + """ + Track the beam through the sextupole element. + + :param incoming: Beam entering the element. + :return: Beam exiting the element. + """ + first_order_tm = self.transfer_map(incoming.energy, incoming.species) + second_order_tm = base_tmatrix( + length=self.length, + k1=torch.zeros_like(self.length), + k2=self.k2, + hx=torch.zeros_like(self.length), + species=incoming.species, + tilt=self.tilt, + energy=incoming.energy, + ) + + if isinstance(incoming, ParameterBeam): + # For ParameterBeam, only first-order effects are applied + return super().track(incoming) + elif isinstance(incoming, ParticleBeam): + # Apply the transfer map to the incoming particles + first_order_particles = torch.matmul( + incoming.particles, first_order_tm.transpose(-2, -1) + ) + second_order_particles = torch.einsum( + "...ijk,...j,...k->...i", + second_order_tm, + incoming.particles, + incoming.particles, + ) + outgoing_particles = second_order_particles + first_order_particles + + return ParticleBeam( + particles=outgoing_particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, + species=incoming.species, + ) + else: + raise TypeError( + f"Unsupported beam type: {type(incoming)}. Expected ParameterBeam or " + "ParticleBeam." + ) + + @property + def is_skippable(self) -> bool: + return False + + @property + def is_active(self) -> bool: + return torch.any(self.k2 != 0.0).item() + + def split(self, resolution: torch.Tensor) -> list[Element]: + raise NotImplementedError + + def plot(self, ax: plt.Axes, s: float, vector_idx: tuple | None = None) -> None: + raise NotImplementedError + + def defining_features(self) -> list[str]: + return super().defining_features() + ["length", "k2"] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(length={repr(self.length)}, " + f"k2={repr(self.k2)}, " + f"name={repr(self.name)})" + ) diff --git a/cheetah/accelerator/sextupole.py b/cheetah/accelerator/sextupole.py new file mode 100644 index 000000000..be35b51ac --- /dev/null +++ b/cheetah/accelerator/sextupole.py @@ -0,0 +1,137 @@ +import matplotlib.pyplot as plt +import torch + +from cheetah.accelerator.element import Element +from cheetah.particles import Beam, ParameterBeam, ParticleBeam, Species +from cheetah.track_methods import base_rmatrix, base_tmatrix, misalignment_matrix +from cheetah.utils import verify_device_and_dtype + + +class Sextupole(Element): + """ + A sextupole element in a particle accelerator. + + :param length: Length in meters. + :param k2: TODO + :param misalignment: TODO + :param tilt: TODO + :param name: Unique identifier of the element. + """ + + def __init__( + self, + length: torch.Tensor, + k2: torch.Tensor | None = None, + misalignment: torch.Tensor | None = None, + tilt: torch.Tensor | None = None, + name: str | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + device, dtype = verify_device_and_dtype([length, k2], device, dtype) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name, **factory_kwargs) + + self.length = torch.as_tensor(length, **factory_kwargs) + + self.register_buffer_or_parameter( + "k2", torch.as_tensor(k2 if k2 is not None else 0.0, **factory_kwargs) + ) + self.register_buffer_or_parameter( + "misalignment", + torch.as_tensor( + misalignment if misalignment is not None else (0.0, 0.0), + **factory_kwargs, + ), + ) + self.register_buffer_or_parameter( + "tilt", torch.as_tensor(tilt if tilt is not None else 0.0, **factory_kwargs) + ) + + def transfer_map(self, energy: torch.Tensor, species: Species) -> torch.Tensor: + R = base_rmatrix( + length=self.length, + k1=torch.zeros_like(self.length), + hx=torch.zeros_like(self.length), + species=species, + tilt=self.tilt, + energy=energy, + ) + + if torch.all(self.misalignment == 0): + return R + else: + R_entry, R_exit = misalignment_matrix(self.misalignment) + R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry) + return R + + def track(self, incoming: Beam) -> Beam: + """ + Track the beam through the sextupole element. + + :param incoming: Beam entering the element. + :return: Beam exiting the element. + """ + first_order_tm = self.transfer_map(incoming.energy, incoming.species) + second_order_tm = base_tmatrix( + length=self.length, + k1=torch.zeros_like(self.length), + k2=self.k2, + hx=torch.zeros_like(self.length), + species=incoming.species, + tilt=self.tilt, + energy=incoming.energy, + ) + + if isinstance(incoming, ParameterBeam): + # For ParameterBeam, only first-order effects are applied + return super().track(incoming) + elif isinstance(incoming, ParticleBeam): + # Apply the transfer map to the incoming particles + first_order_particles = torch.matmul( + incoming.particles, first_order_tm.transpose(-2, -1) + ) + second_order_particles = torch.einsum( + "...ijk,...j,...k->...i", + second_order_tm, + incoming.particles, + incoming.particles, + ) + outgoing_particles = second_order_particles + first_order_particles + + return ParticleBeam( + particles=outgoing_particles, + energy=incoming.energy, + particle_charges=incoming.particle_charges, + survival_probabilities=incoming.survival_probabilities, + species=incoming.species, + ) + else: + raise TypeError( + f"Unsupported beam type: {type(incoming)}. Expected ParameterBeam or " + "ParticleBeam." + ) + + @property + def is_skippable(self) -> bool: + return False + + @property + def is_active(self) -> bool: + return torch.any(self.k2 != 0.0).item() + + def split(self, resolution: torch.Tensor) -> list[Element]: + raise NotImplementedError + + def plot(self, ax: plt.Axes, s: float, vector_idx: tuple | None = None) -> None: + raise NotImplementedError + + def defining_features(self) -> list[str]: + return super().defining_features() + ["length", "k2"] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(length={repr(self.length)}, " + f"k2={repr(self.k2)}, " + f"name={repr(self.name)})" + ) diff --git a/cheetah/track_methods.py b/cheetah/track_methods.py index 1c687c4de..59b85eb0b 100644 --- a/cheetah/track_methods.py +++ b/cheetah/track_methods.py @@ -6,29 +6,6 @@ from cheetah.utils import compute_relativistic_factors -def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: - """Rotate the transfer map in x-y plane. - - :param angle: Rotation angle in rad, for example `angle = np.pi/2` for vertical = - dipole. - :return: Rotation matrix to be multiplied to the element's transfer matrix. - """ - cs = torch.cos(angle) - sn = torch.sin(angle) - - tm = torch.eye(7, dtype=angle.dtype, device=angle.device).repeat(*angle.shape, 1, 1) - tm[..., 0, 0] = cs - tm[..., 0, 2] = sn - tm[..., 1, 1] = cs - tm[..., 1, 3] = sn - tm[..., 2, 0] = -sn - tm[..., 2, 2] = cs - tm[..., 3, 1] = -sn - tm[..., 3, 3] = cs - - return tm - - def base_rmatrix( length: torch.Tensor, k1: torch.Tensor, @@ -38,15 +15,15 @@ def base_rmatrix( energy: torch.Tensor | None = None, ) -> torch.Tensor: """ - Create a universal transfer matrix for a beamline element. + Create a first order universal transfer matrix for a beamline element. :param length: Length of the element in m. :param k1: Quadrupole strength in 1/m**2. - :param hx: Curvature (1/radius) of the element in 1/m**2. + :param hx: Curvature (1/radius) of the element in 1/m. :param species: Particle species of the beam. :param tilt: Roation of the element relative to the longitudinal axis in rad. :param energy: Beam energy in eV. - :return: Transfer matrix for the element. + :return: First order transfer matrix for the element. """ device = length.device dtype = length.dtype @@ -102,6 +79,230 @@ def base_rmatrix( return R +def base_tmatrix( + length: torch.Tensor, + k1: torch.Tensor, + k2: torch.Tensor, + hx: torch.Tensor, + species: Species, + tilt: torch.Tensor | None = None, + energy: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Create a second order universal transfer matrix for a beamline element. + + :param length: Length of the element in m. + :param k1: Quadrupole strength in 1/m**2. + :param k2: Sextupole strength in 1/m**3. + :param hx: Curvature (1/radius) of the element in 1/m. + :param species: Particle species of the beam. + :param tilt: Roation of the element relative to the longitudinal axis in rad. + :param energy: Beam energy in eV. + :return: Second order transfer matrix for the element. + """ + device = length.device + dtype = length.dtype + + tilt = tilt if tilt is not None else torch.tensor(0.0, device=device, dtype=dtype) + energy = ( + energy if energy is not None else torch.tensor(0.0, device=device, dtype=dtype) + ) + + _, igamma2, beta = compute_relativistic_factors(energy, species.mass_eV) + + kx2 = k1 + hx**2 + ky2 = -k1 + kx = torch.sqrt(torch.complex(kx2, torch.tensor(0.0, device=device, dtype=dtype))) + ky = torch.sqrt(torch.complex(ky2, torch.tensor(0.0, device=device, dtype=dtype))) + cx = torch.cos(kx * length).real + cy = torch.cos(ky * length).real + sx = torch.where(kx != 0, (torch.sin(kx * length) / kx).real, length) + sy = torch.where(ky != 0, (torch.sin(ky * length) / ky).real, length) + dx = torch.where(kx != 0, (1.0 - cx) / kx2, length**2 / 2.0) + + d2y = 0.5 * sy**2 + s2y = sy * cy + c2y = torch.cos(2 * ky * length).real + fx = torch.where(kx2 != 0, (length - sx) / kx2, length**3 / 6.0) + f2y = torch.where(ky2 != 0, (length - s2y) / ky2, length**3 / 6.0) + + j1 = torch.where(kx2 != 0, (length - sx) / kx2, length**3 / 6.0) + j2 = torch.where( + kx2 != 0, + (3.0 * length - 4.0 * sx + sx * cx) / (2 * kx2**2), + length**5 / 20.0, + ) + j3 = torch.where( + kx2 != 0, + (15.0 * length - 22.5 * sx + 9.0 * sx * cx - 1.5 * sx * cx**2 + kx2 * sx**3) + / (6.0 * kx2**3), + length**7 / 56.0, + ) + j_denominator = kx2 - 4 * ky2 + jc = torch.where(j_denominator != 0, (c2y - cx) / j_denominator, 0.5 * length**2) + js = torch.where( + j_denominator != 0, (cy * sy - sx) / j_denominator, length**3 / 6.0 + ) + jd = torch.where(j_denominator != 0, (d2y - dx) / j_denominator, length**4 / 24.0) + jf = torch.where(j_denominator != 0, (f2y - fx) / j_denominator, length**5 / 120.0) + + khk = k2 + 2 * hx * k1 + + vector_shape = torch.broadcast_shapes( + length.shape, k1.shape, hx.shape, tilt.shape, energy.shape + ) + + T = torch.zeros((7, 7, 7), dtype=dtype, device=device).repeat( + *vector_shape, 1, 1, 1 + ) + T[..., 0, 0, 0] = -1 / 6 * khk * (sx**2 + dx) - 0.5 * hx * kx2 * sx**2 + T[..., 0, 0, 1] = 2 * -1 / 6 * khk * sx * dx + 0.5 * hx * sx * cx + T[..., 0, 1, 1] = -1 / 6 * khk * dx**2 + 0.5 * hx * dx * cx + T[..., 0, 0, 5] = ( + 2 * -hx / 12 / beta * khk * (3 * sx * j1 - dx**2) + + 0.5 * hx**2 / beta * sx**2 + + 0.25 / beta * k1 * length * sx + ) + T[..., 0, 1, 5] = ( + 2 * -hx / 12 / beta * khk * (sx * dx**2 - 2 * cx * j2) + + 0.25 * hx**2 / beta * (sx * dx + cx * j1) + - 0.25 / beta * (sx + length * cx) + ) + T[..., 0, 5, 5] = ( + -(hx**2) / 6 / beta**2 * khk * (dx**2 * dx - 2 * sx * j2) + + 0.5 * hx**3 / beta**2 * sx * j1 + - 0.5 * hx / beta**2 * length * sx + - 0.5 * hx / (beta**2) * igamma2 * dx + ) + T[..., 0, 2, 2] = k1 * k2 * jd + 0.5 * (k2 + hx * k1) * dx + T[..., 0, 2, 3] = 2 * 0.5 * k2 * js + T[..., 0, 3, 3] = k2 * jd - 0.5 * hx * dx + T[..., 1, 0, 0] = -1 / 6 * khk * sx * (1 + 2 * cx) + T[..., 1, 0, 1] = 2 * -1 / 6 * khk * dx * (1 + 2 * cx) + T[..., 1, 1, 1] = -1 / 3 * khk * sx * dx - 0.5 * hx * sx + T[..., 1, 0, 5] = 2 * -hx / 12 / beta * khk * ( + 3 * cx * j1 + sx * dx + ) - 0.25 / beta * k1 * (sx - length * cx) + T[..., 1, 1, 5] = ( + 2 * -hx / 12 / beta * khk * (3 * sx * j1 + dx**2) + + 0.25 / beta * k1 * length * sx + ) + T[..., 1, 5, 5] = ( + -(hx**2) / 6 / beta**2 * khk * (sx * dx**2 - 2 * cx * j2) + - 0.5 * hx / beta**2 * k1 * (cx * j1 - sx * dx) + - 0.5 * hx / beta**2 * igamma2 * sx + ) + T[..., 1, 2, 2] = k1 * k2 * js + 0.5 * (k2 + hx * k1) * sx + T[..., 1, 2, 3] = 2 * 0.5 * k2 * jc + T[..., 1, 3, 3] = k2 * js - 0.5 * hx * sx + T[..., 2, 0, 2] = ( + 2 * 0.5 * k2 * (cy * jc - 2 * k1 * sy * js) + 0.5 * hx * k1 * sx * sy + ) + T[..., 2, 0, 3] = 2 * 0.5 * k2 * (sy * jc - 2 * cy * js) + 0.5 * hx * sx * cy + T[..., 2, 1, 2] = ( + 2 * 0.5 * k2 * (cy * js - 2 * k1 * sy * jd) + 0.5 * hx * k1 * dx * sy + ) + T[..., 2, 1, 3] = 2 * 0.5 * k2 * (sy * js - 2 * cy * jd) + 0.5 * hx * dx * cy + T[..., 2, 2, 5] = ( + 2 * 0.5 * hx / beta * k2 * (cy * jd - 2 * k1 * sy * jf) + + 0.5 * hx**2 / beta * k1 * j1 * sy + - 0.25 / beta * k1 * length * sy + ) + T[..., 2, 3, 5] = ( + 2 * 0.5 * hx / beta * k2 * (sy * jd - 2 * cy * jf) + + 0.5 * hx**2 / beta * j1 * cy + - 0.25 / beta * (sy + length * cy) + ) + T[..., 3, 0, 2] = ( + 2 * 0.5 * k1 * k2 * (2 * cy * js - sy * jc) + 0.5 * (k2 + hx * k1) * sx * cy + ) + T[..., 3, 0, 3] = ( + 2 * 0.5 * k2 * (2 * k1 * sy * js - cy * jc) + 0.5 * (k2 + hx * k1) * sx * sy + ) + T[..., 3, 1, 2] = ( + 2 * 0.5 * k1 * k2 * (2 * cy * jd - sy * js) + 0.5 * (k2 + hx * k1) * dx * cy + ) + T[..., 3, 1, 3] = ( + 2 * 0.5 * k2 * (2 * k1 * sy * jd - cy * js) + 0.5 * (k2 + hx * k1) * dx * sy + ) + T[..., 3, 2, 5] = ( + 2 * 0.5 * hx / beta * k1 * k2 * (2 * cy * jf - sy * jd) + + 0.5 * hx / beta * (k2 + hx * k1) * j1 * cy + + 0.25 / beta * k1 * (sy - length * cy) + ) + T[..., 3, 3, 5] = ( + 2 * 0.5 * hx / beta * k2 * (2 * k1 * sy * jf - cy * jd) + + 0.5 * hx / beta * (k2 + hx * k1) * j1 * sy + - 0.25 / beta * k1 * length * sy + ) + T[..., 4, 0, 0] = -1 * hx / 12 / beta * khk * ( + sx * dx + 3 * j1 + ) - 0.25 / beta * k1 * (length - sx * cx) + T[..., 4, 0, 1] = -2 * hx / 12 / beta * khk * dx**2 + 0.25 / beta * k1 * sx**2 + T[..., 4, 1, 1] = ( + -1 * hx / 6 / beta * khk * j2 + - 0.5 / beta * sx + - 0.25 / beta * k1 * (j1 - sx * dx) + ) + T[..., 4, 0, 5] = ( + -2 * hx**2 / 12 / beta**2 * khk * (3 * dx * j1 - 4 * j2) + + 0.25 * hx / beta**2 * k1 * j1 * (1 + cx) + + 0.5 * hx / beta**2 * igamma2 * sx + ) + T[..., 4, 1, 5] = ( + -2 * hx**2 / 12 / beta**2 * khk * (dx * dx**2 - 2 * sx * j2) + + 0.25 * hx / beta**2 * k1 * sx * j1 + + 0.5 * hx / beta**2 * igamma2 * dx + ) + T[..., 4, 5, 5] = ( + -1 * hx**3 / 6 / beta**3 * khk * (3 * j3 - 2 * dx * j2) + + hx**2 / 6 / beta**3 * k1 * (sx * dx**2 - j2 * (1 + 2 * cx)) + + 1.5 / beta**3 * igamma2 * (hx**2 * j1 - length) + ) + T[..., 4, 2, 2] = ( + -1 * -hx / beta * k1 * k2 * jf + - 0.5 * hx / beta * (k2 + hx * k1) * j1 + + 0.25 / beta * k1 * (length - cy * sy) + ) + T[..., 4, 2, 3] = -2 * -0.5 * hx / beta * k2 * jd - 0.25 / beta * k1 * sy**2 + T[..., 4, 3, 3] = ( + -1 * -hx / beta * k2 * jf + + 0.5 * hx**2 / beta * j1 + - 0.25 / beta * (length + cy * sy) + ) + T[..., 6, 6, 6] = 0.0 # Constant term currently handled by first order transfer map + + # Rotate the R matrix for skew / vertical magnets + if torch.any(tilt != 0): + T = torch.einsum( + "...ij,...jk,...kl->...il", rotation_matrix(-tilt), T, rotation_matrix(tilt) + ) + return T + + +def rotation_matrix(angle: torch.Tensor) -> torch.Tensor: + """Rotate the transfer map in x-y plane. + + :param angle: Rotation angle in rad, for example `angle = np.pi/2` for vertical = + dipole. + :return: Rotation matrix to be multiplied to the element's transfer matrix. + """ + cs = torch.cos(angle) + sn = torch.sin(angle) + + tm = torch.eye(7, dtype=angle.dtype, device=angle.device).repeat(*angle.shape, 1, 1) + tm[..., 0, 0] = cs + tm[..., 0, 2] = sn + tm[..., 1, 1] = cs + tm[..., 1, 3] = sn + tm[..., 2, 0] = -sn + tm[..., 2, 2] = cs + tm[..., 3, 1] = -sn + tm[..., 3, 3] = cs + + return tm + + def misalignment_matrix( misalignment: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/tests/test_elements.py b/tests/test_elements.py index 6e5fb3992..771a8bb72 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -11,6 +11,7 @@ cheetah.HorizontalCorrector: {"length": torch.tensor(1.0)}, cheetah.Quadrupole: {"length": torch.tensor(1.0)}, cheetah.Segment: {"elements": [cheetah.Drift(length=torch.tensor(1.0))]}, + cheetah.Sextupole: {"length": torch.tensor(1.0)}, cheetah.Solenoid: {"length": torch.tensor(1.0)}, cheetah.SpaceChargeKick: {"effect_length": torch.tensor(1.0)}, cheetah.TransverseDeflectingCavity: {"length": torch.tensor(1.0)}, diff --git a/tests/test_octupole.py b/tests/test_octupole.py new file mode 100644 index 000000000..3560848c8 --- /dev/null +++ b/tests/test_octupole.py @@ -0,0 +1,104 @@ +from copy import deepcopy + +import ocelot +import torch + +import cheetah + + +def test_compare_octupole_to_ocelot(): + """Compare the results of tracking through a octupole in Cheetah and Ocelot.""" + length = 0.34 + k3 = 0.5 + tilt = 0.1 + + # Track through a octupole in Cheetah + incoming = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + cheetah_octupole = cheetah.Octupole( + length=torch.tensor(length), k3=torch.tensor(k3), tilt=torch.tensor(tilt) + ) + outgoing_cheetah = cheetah_octupole.track(incoming) + + # Convert to Ocelot octupole + incoming_p_array = ocelot.astraBeam2particleArray( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + lattice = ocelot.MagneticLattice( + [ocelot.Octupole(l=length, k3=k3, tilt=tilt)], + method={"global": ocelot.SecondTM}, + ) + navigator = ocelot.Navigator(lattice) + _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + outgoing_ocelot = cheetah.ParticleBeam.from_ocelot(outgoing_p_array) + + # Compare the results + assert torch.allclose( + outgoing_cheetah.particles, outgoing_ocelot.particles, atol=1e-5, rtol=1e-6 + ) + + +def test_octupole_as_drift(): + """Test that a octupole with k3=0 is equivalent to a drift.""" + incoming = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + octupole = cheetah.Octupole(length=torch.tensor(0.34), k3=torch.tensor(0.0)) + drift = cheetah.Drift(length=torch.tensor(0.34)) + + # Track through the octupole and drift + octupole_outgoing = octupole.track(incoming) + drift_outgoing = drift.track(incoming) + + # Check that the results are the same + assert torch.allclose( + octupole_outgoing.particles, drift_outgoing.particles, atol=1e-5, rtol=1e-6 + ) + + +def test_octupole_parameter_beam_particle_beam_agreement(): + """ + Test that the results of tracking an `ParameterBeam` and a `ParticleBeam` through a + octupole agree. + """ + # Create a octupole + length = 0.34 + k3 = 0.5 + tilt = 0.1 + octupole = cheetah.Octupole( + length=torch.tensor(length), k3=torch.tensor(k3), tilt=torch.tensor(tilt) + ) + + # Create an incoming ParticleBeam + incoming_particle_beam = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + # Create an incoming ParameterBeam + incoming_parameter_beam = cheetah.ParameterBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + # Track through the octupole + outgoing_particle_beam = octupole.track(incoming_particle_beam) + outgoing_parameter_beam = octupole.track(incoming_parameter_beam) + + outgoing_particle_beam_as_parameter_beam = ( + outgoing_particle_beam.as_parameter_beam() + ) + + # Check that the results are the same + assert torch.allclose( + outgoing_particle_beam_as_parameter_beam.mu, + outgoing_parameter_beam.mu, + atol=1e-5, + rtol=1e-6, + ) + assert torch.allclose( + outgoing_particle_beam_as_parameter_beam.cov, + outgoing_parameter_beam.cov, + atol=1e-5, + rtol=1e-6, + ) diff --git a/tests/test_sextupole.py b/tests/test_sextupole.py new file mode 100644 index 000000000..309413523 --- /dev/null +++ b/tests/test_sextupole.py @@ -0,0 +1,104 @@ +from copy import deepcopy + +import ocelot +import torch + +import cheetah + + +def test_compare_sextupole_to_ocelot(): + """Compare the results of tracking through a sextupole in Cheetah and Ocelot.""" + length = 0.34 + k2 = 0.5 + tilt = 0.1 + + # Track through a sextupole in Cheetah + incoming = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + cheetah_sextupole = cheetah.Sextupole( + length=torch.tensor(length), k2=torch.tensor(k2), tilt=torch.tensor(tilt) + ) + outgoing_cheetah = cheetah_sextupole.track(incoming) + + # Convert to Ocelot sextupole + incoming_p_array = ocelot.astraBeam2particleArray( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + lattice = ocelot.MagneticLattice( + [ocelot.Sextupole(l=length, k2=k2, tilt=tilt)], + method={"global": ocelot.SecondTM}, + ) + navigator = ocelot.Navigator(lattice) + _, outgoing_p_array = ocelot.track(lattice, deepcopy(incoming_p_array), navigator) + outgoing_ocelot = cheetah.ParticleBeam.from_ocelot(outgoing_p_array) + + # Compare the results + assert torch.allclose( + outgoing_cheetah.particles, outgoing_ocelot.particles, atol=1e-5, rtol=1e-6 + ) + + +def test_sextupole_as_drift(): + """Test that a sextupole with k2=0 is equivalent to a drift.""" + incoming = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + sextupole = cheetah.Sextupole(length=torch.tensor(0.34), k2=torch.tensor(0.0)) + drift = cheetah.Drift(length=torch.tensor(0.34)) + + # Track through the sextupole and drift + sextupole_outgoing = sextupole.track(incoming) + drift_outgoing = drift.track(incoming) + + # Check that the results are the same + assert torch.allclose( + sextupole_outgoing.particles, drift_outgoing.particles, atol=1e-5, rtol=1e-6 + ) + + +def test_sextupole_parameter_beam_particle_beam_agreement(): + """ + Test that the results of tracking an `ParameterBeam` and a `ParticleBeam` through a + sextupole agree. + """ + # Create a sextupole + length = 0.34 + k2 = 0.5 + tilt = 0.1 + sextupole = cheetah.Sextupole( + length=torch.tensor(length), k2=torch.tensor(k2), tilt=torch.tensor(tilt) + ) + + # Create an incoming ParticleBeam + incoming_particle_beam = cheetah.ParticleBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + # Create an incoming ParameterBeam + incoming_parameter_beam = cheetah.ParameterBeam.from_astra( + "tests/resources/ACHIP_EA1_2021.1351.001" + ) + + # Track through the sextupole + outgoing_particle_beam = sextupole.track(incoming_particle_beam) + outgoing_parameter_beam = sextupole.track(incoming_parameter_beam) + + outgoing_particle_beam_as_parameter_beam = ( + outgoing_particle_beam.as_parameter_beam() + ) + + # Check that the results are the same + assert torch.allclose( + outgoing_particle_beam_as_parameter_beam.mu, + outgoing_parameter_beam.mu, + atol=1e-5, + rtol=1e-6, + ) + assert torch.allclose( + outgoing_particle_beam_as_parameter_beam.cov, + outgoing_parameter_beam.cov, + atol=1e-5, + rtol=1e-6, + )