diff --git a/CHANGELOG.md b/CHANGELOG.md index c79260367..f057a7e03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ ### 🚀 Features -- Implement second-order tracking for `Drift`, `Dipole` and `Quadrupole` elements, and add a convenient method to set tracking methods for an entire segment. This comes with an overhaul of the overall tracking system. Rename the tracking method `"cheetah"` to `"linear"` and `"bmadx"` to `"drift_kick_drift"`. The existing methods `"cheetah"` and `"bmadx"` will remain supported with a `DeprecationWarning`. (see #476) (@cr-xu, @jank324, @Hespe) +- Implement second-order tracking for `Drift`, `Dipole`, `Quadrupole`, and `CustomTransferMap` elements, and add a convenient method to set tracking methods for an entire segment. This comes with an overhaul of the overall tracking system. Rename the tracking method `"cheetah"` to `"linear"` and `"bmadx"` to `"drift_kick_drift"`. The existing methods `"cheetah"` and `"bmadx"` will remain supported with a `DeprecationWarning`. (see #476, #530) (@cr-xu, @jank324, @Hespe) ### 🐛 Bug fixes diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 6768893d6..385530d8a 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -1,3 +1,5 @@ +from typing import Literal + import matplotlib.pyplot as plt import torch from matplotlib.patches import Rectangle @@ -21,12 +23,13 @@ class CustomTransferMap(Element): access the element in a segment. """ - supported_tracking_methods = ["linear"] + supported_tracking_methods = ["linear", "second_order"] def __init__( self, predefined_transfer_map: torch.Tensor, length: torch.Tensor | None = None, + tracking_method: Literal["linear", "second_order"] = "linear", name: torch.Tensor | None = None, sanitize_name: bool = False, device: torch.device | None = None, @@ -38,19 +41,28 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name, sanitize_name=sanitize_name, **factory_kwargs) + self.tracking_method = tracking_method if length is not None: self.length = torch.as_tensor(length, **factory_kwargs) - assert (predefined_transfer_map[..., -1, :-2] == 0.0).all() and ( - predefined_transfer_map[..., -1, -1] == 1.0 - ).all(), "The seventh row of the transfer map must be [0, 0, 0, 0, 0, 0, 1]." + if self.tracking_method == "second_order": + assert predefined_transfer_map.shape[-3:] == (7, 7, 7) + assert ( + (predefined_transfer_map[..., -1, :, :-1] == 0.0).all() + and (predefined_transfer_map[..., -1, :-1, -1] == 0.0).all() + and (predefined_transfer_map[..., -1, -1, -1] == 1.0).all() + ), "The final plane of the output dimension must only contain a single 1." + else: + assert predefined_transfer_map.shape[-2:] == (7, 7) + assert (predefined_transfer_map[..., -1, :-1] == 0.0).all() and ( + predefined_transfer_map[..., -1, -1] == 1.0 + ).all(), "The final row of the transfer map must be [0, 0, 0, 0, 0, 0, 1]." + self.register_buffer_or_parameter( "predefined_transfer_map", torch.as_tensor(predefined_transfer_map, **factory_kwargs), ) - assert self.predefined_transfer_map.shape[-2:] == (7, 7) - @classmethod def from_merging_elements( cls, elements: list[Element], incoming_beam: Beam @@ -98,26 +110,70 @@ def from_merging_elements( tm, length=combined_length, device=device, dtype=dtype, name=combined_name ) + def track(self, incoming: Beam) -> Beam: + """ + Track particles through the custom transfer map. + + :param incoming: Beam entering the element. + :return: Beam exiting the element. + """ + if self.tracking_method == "linear": + return super()._track_first_order(incoming) + elif self.tracking_method == "second_order": + return super()._track_second_order(incoming) + else: + raise ValueError( + f"Invalid tracking method {self.tracking_method}. For element of" + f" type {self.__class__.__name__}, supported methods are " + f"{self.supported_tracking_methods}." + ) + def first_order_transfer_map( self, energy: torch.Tensor, species: Species ) -> torch.Tensor: - return self.predefined_transfer_map + if self.tracking_method == "linear": + return self.predefined_transfer_map + else: + transfer_map = self.predefined_transfer_map[..., :, 6, :] + transfer_map[..., :6, :] += self.predefined_transfer_map[..., :6, :, 6] + + return transfer_map + + def second_order_transfer_map( + self, energy: torch.Tensor, species: Species + ) -> torch.Tensor: + if self.tracking_method == "second_order": + return self.predefined_transfer_map + else: + transfer_map = torch.zeros( + (*self.predefined_transfer_map.shape[:-2], 7, 7, 7), + device=self.predefined_transfer_map.device, + dtype=self.predefined_transfer_map.dtype, + ) + transfer_map[..., :, 6, :] = self.predefined_transfer_map + + return transfer_map @property def is_skippable(self) -> bool: - return True + return self.tracking_method == "linear" def __repr__(self): return ( f"{self.__class__.__name__}(" + f"predefined_transfer_map={repr(self.predefined_transfer_map)}, " + f"length={repr(self.length)}, " + + f"tracking_method={repr(self.tracking_method)}, " + f"name={repr(self.name)})" ) @property def defining_features(self) -> list[str]: - return super().defining_features + ["length", "predefined_transfer_map"] + return super().defining_features + [ + "length", + "predefined_transfer_map", + "tracking_method", + ] def plot( self, s: float, vector_idx: tuple | None = None, ax: plt.Axes | None = None diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 72a66d8f9..60630033c 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -77,7 +77,7 @@ def second_order_transfer_map( def track(self, incoming: Beam) -> Beam: """ - Track particles through the dipole element. + Track particles through the drift element. :param incoming: Beam entering the element. :return: Beam exiting the element. diff --git a/cheetah/converters/elegant.py b/cheetah/converters/elegant.py index 1c8138a7b..3abcaf2e6 100644 --- a/cheetah/converters/elegant.py +++ b/cheetah/converters/elegant.py @@ -222,13 +222,11 @@ def convert_element( return cheetah.BPM(name=name, sanitize_name=sanitize_name) elif parsed["element_type"] == "ematrix": validate_understood_properties( - shared_properties + ["l", "order", "c[1-6]", "r[1-6][1-6]"], + shared_properties + + ["l", "order", "c[1-6]", "r[1-6][1-6]", "t[1-6][1-6][1-6]"], parsed, ) - if parsed.get("order", 1) != 1: - raise ValueError("Only first order modelling is supported") - # Initially zero in elegant by convention R = torch.zeros((7, 7), **factory_kwargs) # Add linear component @@ -246,12 +244,37 @@ def convert_element( # Ensure the affine component is passed along R[6, 6] = 1.0 - return cheetah.CustomTransferMap( - length=torch.tensor(parsed.get("l", 0.0), **factory_kwargs), - predefined_transfer_map=R, - name=name, - sanitize_name=sanitize_name, - ) + if parsed.get("order", 2) == 2: + T = torch.zeros((7, 7, 7), **factory_kwargs) + T[:6, :6, :6] = torch.tensor( + [ + [ + [ + parsed.get(f"t{i + 1}{j + 1}{k + 1}", 0.0) + for k in range(6) + ] + for j in range(6) + ] + for i in range(6) + ], + **factory_kwargs, + ) + T[:, 6, :] = R + return cheetah.CustomTransferMap( + length=torch.tensor(parsed.get("l", 0.0), **factory_kwargs), + predefined_transfer_map=T, + tracking_method="second_order", + name=name, + sanitize_name=sanitize_name, + ) + else: + return cheetah.CustomTransferMap( + length=torch.tensor(parsed.get("l", 0.0), **factory_kwargs), + predefined_transfer_map=R, + tracking_method="linear", + name=name, + sanitize_name=sanitize_name, + ) elif parsed["element_type"] == "rfca": validate_understood_properties( shared_properties + ["l", "phase", "volt", "freq"], parsed diff --git a/tests/conftest.py b/tests/conftest.py index 9bc19dd1a..caabbe275 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,18 @@ cheetah.Aperture: {"inactive": {"is_active": False}, "active": {"is_active": True}}, cheetah.BPM: {"inactive": {"is_active": False}, "active": {"is_active": True}}, cheetah.Cavity: {"default": {"length": torch.tensor(1.0)}}, - cheetah.CustomTransferMap: {"identity": {"predefined_transfer_map": torch.eye(7)}}, + cheetah.CustomTransferMap: { + "linear": { + "predefined_transfer_map": torch.eye(7), + "tracking_method": "linear", + }, + "second_order": { + "predefined_transfer_map": torch.cat( + [torch.zeros(7, 7, 6), torch.eye(7).unsqueeze(-1)], dim=-1 + ), + "tracking_method": "second_order", + }, + }, cheetah.Dipole: { "linear": { "length": torch.tensor(1.0), diff --git a/tests/test_device_dtype.py b/tests/test_device_dtype.py index 1abb426d2..2bcc6cdf1 100644 --- a/tests/test_device_dtype.py +++ b/tests/test_device_dtype.py @@ -76,6 +76,14 @@ def test_conflicting_element_dtype(element): and feature not in required_arguments } + # Collect remaining optional arguments + optional_non_tensor_arguments = { + feature: getattr(element, feature) + for feature in element.defining_features + if not isinstance(getattr(element, feature), torch.Tensor) + and feature not in required_arguments + } + # Ensure that at least one tensor is part of the arguments that are passed each call if ( not any( @@ -90,11 +98,18 @@ def test_conflicting_element_dtype(element): for name, value in optional_tensor_arguments.items(): with pytest.raises(AssertionError): # Contains conflicting dtype - element.__class__(**{name: value.double()}, **required_arguments) + element.__class__( + **{name: value.double()}, + **optional_non_tensor_arguments, + **required_arguments, + ) # Conflict can be overriden by manual dtype selection element.__class__( - **{name: value.double()}, **required_arguments, dtype=torch.float32 + **{name: value.double()}, + **optional_non_tensor_arguments, + **required_arguments, + dtype=torch.float32, ) diff --git a/tests/test_elements.py b/tests/test_elements.py index 4b021fb2c..36a8f5b20 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -102,6 +102,7 @@ def test_particle_beam_tracking_with_device_and_dtype(element, device, dtype): isinstance( element, ( + cheetah.CustomTransferMap, cheetah.Dipole, cheetah.Drift, cheetah.Quadrupole,