Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

### 🚀 Features

- Implement second-order tracking for `Drift`, `Dipole` and `Quadrupole` elements, and add a convenient method to set tracking methods for an entire segment. This comes with an overhaul of the overall tracking system. Rename the tracking method `"cheetah"` to `"linear"` and `"bmadx"` to `"drift_kick_drift"`. The existing methods `"cheetah"` and `"bmadx"` will remain supported with a `DeprecationWarning`. (see #476) (@cr-xu, @jank324, @Hespe)
- Implement second-order tracking for `Drift`, `Dipole`, `Quadrupole`, and `CustomTransferMap` elements, and add a convenient method to set tracking methods for an entire segment. This comes with an overhaul of the overall tracking system. Rename the tracking method `"cheetah"` to `"linear"` and `"bmadx"` to `"drift_kick_drift"`. The existing methods `"cheetah"` and `"bmadx"` will remain supported with a `DeprecationWarning`. (see #476, #530) (@cr-xu, @jank324, @Hespe)

### 🐛 Bug fixes

Expand Down
74 changes: 65 additions & 9 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle
Expand All @@ -21,12 +23,13 @@ class CustomTransferMap(Element):
access the element in a segment.
"""

supported_tracking_methods = ["linear"]
supported_tracking_methods = ["linear", "second_order"]

def __init__(
self,
predefined_transfer_map: torch.Tensor,
length: torch.Tensor | None = None,
tracking_method: Literal["linear", "second_order"] = "linear",
name: torch.Tensor | None = None,
sanitize_name: bool = False,
device: torch.device | None = None,
Expand All @@ -38,19 +41,28 @@ def __init__(
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name, sanitize_name=sanitize_name, **factory_kwargs)

self.tracking_method = tracking_method
if length is not None:
self.length = torch.as_tensor(length, **factory_kwargs)

assert (predefined_transfer_map[..., -1, :-2] == 0.0).all() and (
predefined_transfer_map[..., -1, -1] == 1.0
).all(), "The seventh row of the transfer map must be [0, 0, 0, 0, 0, 0, 1]."
if self.tracking_method == "second_order":
assert predefined_transfer_map.shape[-3:] == (7, 7, 7)
assert (
(predefined_transfer_map[..., -1, :, :-1] == 0.0).all()
and (predefined_transfer_map[..., -1, :-1, -1] == 0.0).all()
and (predefined_transfer_map[..., -1, -1, -1] == 1.0).all()
), "The final plane of the output dimension must only contain a single 1."
else:
assert predefined_transfer_map.shape[-2:] == (7, 7)
assert (predefined_transfer_map[..., -1, :-1] == 0.0).all() and (
predefined_transfer_map[..., -1, -1] == 1.0
).all(), "The final row of the transfer map must be [0, 0, 0, 0, 0, 0, 1]."

self.register_buffer_or_parameter(
"predefined_transfer_map",
torch.as_tensor(predefined_transfer_map, **factory_kwargs),
)

assert self.predefined_transfer_map.shape[-2:] == (7, 7)

@classmethod
def from_merging_elements(
cls, elements: list[Element], incoming_beam: Beam
Expand Down Expand Up @@ -98,26 +110,70 @@ def from_merging_elements(
tm, length=combined_length, device=device, dtype=dtype, name=combined_name
)

def track(self, incoming: Beam) -> Beam:
    """
    Track particles through the custom transfer map.

    :param incoming: Beam entering the element.
    :return: Beam exiting the element.
    :raises ValueError: If `self.tracking_method` is not one of the
        supported tracking methods of this element.
    """
    if self.tracking_method == "second_order":
        return super()._track_second_order(incoming)
    if self.tracking_method == "linear":
        return super()._track_first_order(incoming)

    raise ValueError(
        f"Invalid tracking method {self.tracking_method}. For element of"
        f" type {self.__class__.__name__}, supported methods are "
        f"{self.supported_tracking_methods}."
    )

def first_order_transfer_map(
    self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
    """
    Return the first-order (7x7) transfer map of this element.

    :param energy: Reference energy of the beam. Unused here, since the map
        is predefined rather than computed from beam properties.
    :param species: Particle species of the beam. Unused here for the same
        reason.
    :return: First-order transfer map of shape (..., 7, 7).
    """
    if self.tracking_method == "linear":
        return self.predefined_transfer_map
    else:
        # Extract the linear part of the second-order map T. With the 7th
        # phase space coordinate fixed to 1, the first-order contribution
        # to x'_i = T_ijk x_j x_k is R_ij = T_i6j + T_ij6.
        # NOTE: The .clone() is essential. Basic indexing returns a view
        # into the registered buffer, so an in-place += without cloning
        # would silently corrupt self.predefined_transfer_map on every
        # call to this method.
        transfer_map = self.predefined_transfer_map[..., :, 6, :].clone()
        transfer_map[..., :6, :] += self.predefined_transfer_map[..., :6, :, 6]

        return transfer_map

def second_order_transfer_map(
    self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
    """
    Return the second-order (7x7x7) transfer map of this element.

    :param energy: Reference energy of the beam. Unused here, since the map
        is predefined rather than computed from beam properties.
    :param species: Particle species of the beam. Unused here for the same
        reason.
    :return: Second-order transfer map of shape (..., 7, 7, 7).
    """
    if self.tracking_method == "second_order":
        # The stored map is already second order; hand it out as-is.
        return self.predefined_transfer_map

    # Embed the predefined first-order map R into an otherwise zero
    # second-order tensor T with T_i6j = R_ij, exploiting that the 7th
    # phase space coordinate is the constant 1.
    first_order = self.predefined_transfer_map
    embedded = first_order.new_zeros(*first_order.shape[:-2], 7, 7, 7)
    embedded[..., :, 6, :] = first_order

    return embedded

@property
def is_skippable(self) -> bool:
    """
    Whether tracking through this element can be merged with neighbouring
    elements into a single linear transfer map. Only purely linear tracking
    qualifies; second-order tracking cannot be folded into a 7x7 matrix.
    """
    method_is_linear = self.tracking_method == "linear"
    return method_is_linear

def __repr__(self) -> str:
    """Return an eval-style representation listing the defining arguments."""
    attributes = (
        f"predefined_transfer_map={repr(self.predefined_transfer_map)}",
        f"length={repr(self.length)}",
        f"tracking_method={repr(self.tracking_method)}",
        f"name={repr(self.name)}",
    )
    return f"{self.__class__.__name__}(" + ", ".join(attributes) + ")"

@property
def defining_features(self) -> list[str]:
    """
    Names of the features that fully define this element, extending the
    features inherited from the base element.
    """
    own_features = ["length", "predefined_transfer_map", "tracking_method"]
    return super().defining_features + own_features

def plot(
self, s: float, vector_idx: tuple | None = None, ax: plt.Axes | None = None
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def second_order_transfer_map(

def track(self, incoming: Beam) -> Beam:
"""
Track particles through the dipole element.
Track particles through the drift element.

:param incoming: Beam entering the element.
:return: Beam exiting the element.
Expand Down
43 changes: 33 additions & 10 deletions cheetah/converters/elegant.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,11 @@ def convert_element(
return cheetah.BPM(name=name, sanitize_name=sanitize_name)
elif parsed["element_type"] == "ematrix":
validate_understood_properties(
shared_properties + ["l", "order", "c[1-6]", "r[1-6][1-6]"],
shared_properties
+ ["l", "order", "c[1-6]", "r[1-6][1-6]", "t[1-6][1-6][1-6]"],
parsed,
)

if parsed.get("order", 1) != 1:
raise ValueError("Only first order modelling is supported")

# Initially zero in elegant by convention
R = torch.zeros((7, 7), **factory_kwargs)
# Add linear component
Expand All @@ -246,12 +244,37 @@ def convert_element(
# Ensure the affine component is passed along
R[6, 6] = 1.0

return cheetah.CustomTransferMap(
length=torch.tensor(parsed.get("l", 0.0), **factory_kwargs),
predefined_transfer_map=R,
name=name,
sanitize_name=sanitize_name,
)
if parsed.get("order", 2) == 2:
T = torch.zeros((7, 7, 7), **factory_kwargs)
T[:6, :6, :6] = torch.tensor(
[
[
[
parsed.get(f"t{i + 1}{j + 1}{k + 1}", 0.0)
for k in range(6)
]
for j in range(6)
]
for i in range(6)
],
**factory_kwargs,
)
T[:, 6, :] = R
return cheetah.CustomTransferMap(
length=torch.tensor(parsed.get("l", 0.0), **factory_kwargs),
predefined_transfer_map=T,
tracking_method="second_order",
name=name,
sanitize_name=sanitize_name,
)
else:
return cheetah.CustomTransferMap(
length=torch.tensor(parsed.get("l", 0.0), **factory_kwargs),
predefined_transfer_map=R,
tracking_method="linear",
name=name,
sanitize_name=sanitize_name,
)
elif parsed["element_type"] == "rfca":
validate_understood_properties(
shared_properties + ["l", "phase", "volt", "freq"], parsed
Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@
cheetah.Aperture: {"inactive": {"is_active": False}, "active": {"is_active": True}},
cheetah.BPM: {"inactive": {"is_active": False}, "active": {"is_active": True}},
cheetah.Cavity: {"default": {"length": torch.tensor(1.0)}},
cheetah.CustomTransferMap: {"identity": {"predefined_transfer_map": torch.eye(7)}},
cheetah.CustomTransferMap: {
"linear": {
"predefined_transfer_map": torch.eye(7),
"tracking_method": "linear",
},
"second_order": {
"predefined_transfer_map": torch.cat(
[torch.zeros(7, 7, 6), torch.eye(7).unsqueeze(-1)], dim=-1
),
"tracking_method": "second_order",
},
},
cheetah.Dipole: {
"linear": {
"length": torch.tensor(1.0),
Expand Down
19 changes: 17 additions & 2 deletions tests/test_device_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def test_conflicting_element_dtype(element):
and feature not in required_arguments
}

# Collect remaining optional arguments
optional_non_tensor_arguments = {
feature: getattr(element, feature)
for feature in element.defining_features
if not isinstance(getattr(element, feature), torch.Tensor)
and feature not in required_arguments
}

# Ensure that at least one tensor is part of the arguments that are passed each call
if (
not any(
Expand All @@ -90,11 +98,18 @@ def test_conflicting_element_dtype(element):
for name, value in optional_tensor_arguments.items():
with pytest.raises(AssertionError):
# Contains conflicting dtype
element.__class__(**{name: value.double()}, **required_arguments)
element.__class__(
**{name: value.double()},
**optional_non_tensor_arguments,
**required_arguments,
)

# Conflict can be overriden by manual dtype selection
element.__class__(
**{name: value.double()}, **required_arguments, dtype=torch.float32
**{name: value.double()},
**optional_non_tensor_arguments,
**required_arguments,
dtype=torch.float32,
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def test_particle_beam_tracking_with_device_and_dtype(element, device, dtype):
isinstance(
element,
(
cheetah.CustomTransferMap,
cheetah.Dipole,
cheetah.Drift,
cheetah.Quadrupole,
Expand Down
Loading