Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- First- and second-order transfer maps are now cached, resulting in potential speed-ups of 10x or more (see #532) (@jank324)
- Methods for creating `ParticleBeam` instances from distributions via stochastic sampling now make sure that the statistics of the generated particles match the desired distribution (see #546) (@cr-xu)
- `BPM` elements now support misalignments (see #533) (@roussel-ryan, @jank324)
- Speed up tracking by replacing some PyTorch operations with faster alternatives (see #538, #561) (@jank324, @Hespe)

### 🐛 Bug fixes

Expand Down
13 changes: 13 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Contribution Guidelines

## How to write fast PyTorch code

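### Creating new tensors

Each of the following calls creates a zero-filled tensor with the same device and dtype as an existing tensor `a`. Prefer the variants marked as fastest.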

```python
torch.tensor(0.0, device=a.device, dtype=a.dtype)
torch.zeros((), device=a.device, dtype=a.dtype)
torch.zeros_like(a)  # <-- Fastest when the new tensor should have the same shape as `a` (see #561)
a.new_zeros(())  # <-- Fastest for scalar constants that must match the device and dtype of `a` (see #561)
a.new_zeros(a.shape)
```
12 changes: 6 additions & 6 deletions cheetah/accelerator/bpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ def __init__(
self.is_active = is_active
factory_kwargs = {"device": device, "dtype": dtype}

self.register_buffer(
"reading",
torch.tensor((torch.nan, torch.nan), **factory_kwargs),
persistent=False,
)

self.register_buffer_or_parameter(
"misalignment",
(
Expand All @@ -52,6 +46,12 @@ def __init__(
),
)

self.register_buffer(
"reading",
torch.tensor((torch.nan, torch.nan), **factory_kwargs),
Copy link
Preview

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The buffer registration should use self.misalignment.new_full((2,), torch.nan) instead of torch.tensor((torch.nan, torch.nan), **factory_kwargs) to be consistent with the PR's optimization goals and ensure the tensor uses the same device/dtype as other parameters.

Suggested change
torch.tensor((torch.nan, torch.nan), **factory_kwargs),
self.misalignment.new_full((2,), torch.nan),

Copilot uses AI. Check for mistakes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This really depends on the answer to my earlier question. I'm thinking the answer is probably no, let's not do that.

persistent=False,
)

@property
def is_skippable(self) -> bool:
return not self.is_active
Expand Down
12 changes: 7 additions & 5 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,17 @@ def is_skippable(self) -> bool:
def _compute_first_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
zero = self.length.new_zeros(())

return torch.where(
(self.voltage != 0).unsqueeze(-1).unsqueeze(-1),
self._cavity_rmatrix(energy, species),
base_rmatrix(
length=self.length,
k1=torch.zeros_like(self.length),
hx=torch.zeros_like(self.length),
k1=zero,
hx=zero,
species=species,
tilt=torch.zeros_like(self.length),
tilt=zero,
energy=energy,
),
)
Expand All @@ -107,8 +109,8 @@ def track(self, incoming: Beam) -> Beam:
)

T566 = 1.5 * self.length * igamma2 / beta0**3
T556 = torch.full_like(self.length, 0.0)
T555 = torch.full_like(self.length, 0.0)
T556 = self.length.new_zeros(())
T555 = self.length.new_zeros(())

if torch.any(incoming.energy + delta_energy > 0):
k = 2 * torch.pi * self.frequency / constants.speed_of_light
Expand Down
27 changes: 7 additions & 20 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def _track_drift_kick_drift(self, incoming: ParticleBeam) -> ParticleBeam:
# throughout the function makes it even harder, is bad practice and should
# really be fixed!

# Zero constant for later use to save on tensor allocations
zero = self.tilt.new_zeros(())

# Compute Bmad coordinates and p0c
x = incoming.x
px = incoming.px
Expand All @@ -220,31 +223,15 @@ def _track_drift_kick_drift(self, incoming: ParticleBeam) -> ParticleBeam:
z, pz, p0c = bmadx.cheetah_to_bmad_z_pz(tau, delta, incoming.energy, mc2)

# Begin Bmad-X tracking
x, px, y, py = bmadx.offset_particle_set(
torch.zeros_like(self.tilt),
torch.zeros_like(self.tilt),
self.tilt,
x,
px,
y,
py,
)
x, px, y, py = bmadx.offset_particle_set(zero, zero, self.tilt, x, px, y, py)

if self.fringe_at == "entrance" or self.fringe_at == "both":
px, py = self._bmadx_fringe_linear("entrance", x, px, y, py)
x, px, y, py, z, pz = self._bmadx_body(x, px, y, py, z, pz, p0c, mc2)
if self.fringe_at == "exit" or self.fringe_at == "both":
px, py = self._bmadx_fringe_linear("exit", x, px, y, py)

x, px, y, py = bmadx.offset_particle_unset(
torch.zeros_like(self.tilt),
torch.zeros_like(self.tilt),
self.tilt,
x,
px,
y,
py,
)
x, px, y, py = bmadx.offset_particle_unset(zero, zero, self.tilt, x, px, y, py)
# End of Bmad-X tracking

# Convert back to Cheetah coordinates
Expand Down Expand Up @@ -442,7 +429,7 @@ def _compute_second_order_transfer_map(
T = base_ttensor(
length=self.length,
k1=self.k1,
k2=torch.tensor(0.0, **factory_kwargs),
k2=self.length.new_zeros(()),
hx=self.hx,
species=species,
energy=energy,
Expand All @@ -462,7 +449,7 @@ def _compute_second_order_transfer_map(
R[..., 2, 6] = self.angle
R[..., 2, 3] = self.length

T = torch.zeros((*self.length.shape, 7, 7), **factory_kwargs)
T = self.length.new_zeros((*self.length.shape, 7, 7, 7))
T[..., :, 6, :] = R

# Apply fringe fields
Expand Down
9 changes: 3 additions & 6 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,10 @@ def _compute_first_order_transfer_map(
def _compute_second_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
zero = self.length.new_zeros(())

T = base_ttensor(
self.length,
k1=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
k2=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
hx=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
energy=energy,
species=species,
self.length, k1=zero, k2=zero, hx=zero, energy=energy, species=species
)

# Fill the first-order transfer map into the second-order transfer map
Expand Down
12 changes: 7 additions & 5 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _compute_first_order_transfer_map(
R = base_rmatrix(
length=self.length,
k1=self.k1,
hx=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
hx=self.length.new_zeros(()),
species=species,
tilt=self.tilt,
energy=energy,
Expand All @@ -98,11 +98,13 @@ def _compute_first_order_transfer_map(
def _compute_second_order_transfer_map(
self, energy: torch.Tensor, species: Species
) -> torch.Tensor:
zero = self.length.new_zeros(())

T = base_ttensor(
length=self.length,
k1=self.k1,
k2=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
hx=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
k2=zero,
hx=zero,
tilt=self.tilt,
energy=energy,
species=species,
Expand All @@ -112,7 +114,7 @@ def _compute_second_order_transfer_map(
T[..., :, 6, :] = base_rmatrix(
length=self.length,
k1=self.k1,
hx=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
hx=zero,
species=species,
tilt=self.tilt,
energy=energy,
Expand Down Expand Up @@ -230,7 +232,7 @@ def _track_drift_kick_drift(self, incoming: ParticleBeam) -> ParticleBeam:
)

# pz is unaffected by tracking and therefore needs to be broadcast to match the vector dimensions of x
pz = pz * torch.ones_like(x)
pz, _ = torch.broadcast_tensors(pz, x)
# End of Bmad-X tracking

# Convert back to Cheetah coordinates
Expand Down
6 changes: 2 additions & 4 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,8 @@ def reading(self) -> torch.Tensor:

read_beam = self.get_read_beam()
if read_beam is None:
image = torch.zeros(
(int(self.effective_resolution[1]), int(self.effective_resolution[0])),
device=self.misalignment.device,
dtype=self.misalignment.dtype,
image = self.misalignment.new_zeros(
(int(self.effective_resolution[1]), int(self.effective_resolution[0]))
)
elif isinstance(read_beam, ParameterBeam):
if torch.numel(read_beam.mu[..., 0]) > 1:
Expand Down
6 changes: 4 additions & 2 deletions cheetah/accelerator/sextupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ def _compute_first_order_transfer_map(
return drift_matrix(length=self.length, species=species, energy=energy)

def _compute_second_order_transfer_map(self, energy, species):
zero = self.length.new_zeros(())

T = base_ttensor(
length=self.length,
k1=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
k1=zero,
k2=self.k2,
hx=torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype),
hx=zero,
species=species,
tilt=self.tilt,
energy=energy,
Expand Down
24 changes: 7 additions & 17 deletions cheetah/accelerator/space_charge_kick.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ def _deposit_charge_on_grid(
Deposits the charge density of the beam onto a grid, using the
Cloud-In-Cell (CIC) method. Returns a grid of charge density in C/m^3.
"""
charge = torch.zeros(
beam.particles.shape[:-2] + self.grid_shape,
device=beam.particles.device,
dtype=beam.particles.dtype,
)
charge = beam.particles.new_zeros(beam.particles.shape[:-2] + self.grid_shape)

# Compute inverse cell size (to avoid multiple divisions later on)
inv_cell_size = 1 / cell_size
Expand Down Expand Up @@ -224,10 +220,8 @@ def _array_rho(
new_dims = tuple(2 * dim for dim in self.grid_shape)

# Create a new tensor with the doubled dimensions, filled with zeros
new_charge_density = torch.zeros(
beam.particles.shape[:-2] + new_dims,
device=beam.particles.device,
dtype=beam.particles.dtype,
new_charge_density = beam.particles.new_zeros(
beam.particles.shape[:-2] + new_dims
)

# Copy the original charge_density values to the beginning of the new tensor
Expand Down Expand Up @@ -316,15 +310,13 @@ def _integrated_green_function(
)

# Initialize the grid with double dimensions
green_func_values = torch.zeros(
green_func_values = beam.particles.new_zeros(
(
*beam.particles.shape[:-2],
2 * num_grid_points_x,
2 * num_grid_points_y,
2 * num_grid_points_tau,
),
device=beam.particles.device,
dtype=beam.particles.dtype,
)
)

# Fill the grid with G_values and its periodic copies
Expand Down Expand Up @@ -462,10 +454,8 @@ def _compute_forces(
beam, xp_coordinates, cell_size, grid_dimensions
)
grid_shape = self.grid_shape
interpolated_forces = torch.zeros(
(*beam.particles.shape[:-1], 3),
device=beam.particles.device,
dtype=beam.particles.dtype,
interpolated_forces = beam.particles.new_zeros(
(*beam.particles.shape[:-1], 3)
) # (..., num_particles, 3)

# Get particle positions
Expand Down
8 changes: 3 additions & 5 deletions cheetah/track_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def base_rmatrix(
"""
factory_kwargs = {"device": length.device, "dtype": length.dtype}

zero = torch.tensor(0.0, **factory_kwargs)
zero = length.new_zeros(())

tilt = tilt if tilt is not None else zero
energy = energy if energy is not None else zero
Expand Down Expand Up @@ -104,9 +104,7 @@ def base_ttensor(
:param energy: Beam energy in eV.
:return: Second order transfer map for the element.
"""
factory_kwargs = {"device": length.device, "dtype": length.dtype}

zero = torch.tensor(0.0, **factory_kwargs)
zero = length.new_zeros(())

tilt = tilt if tilt is not None else zero
energy = energy if energy is not None else zero
Expand Down Expand Up @@ -155,7 +153,7 @@ def base_ttensor(
length.shape, k1.shape, k2.shape, hx.shape, tilt.shape, energy.shape
)

T = torch.zeros((7, 7, 7), **factory_kwargs).repeat(*vector_shape, 1, 1, 1)
T = length.new_zeros((7, 7, 7)).repeat(*vector_shape, 1, 1, 1)
T[..., 0, 0, 0] = -1 / 6 * khk * (sx**2 + dx) - 0.5 * hx * kx2 * sx**2
T[..., 0, 0, 1] = 2 * (-1 / 6 * khk * sx * dx + 0.5 * hx * sx * cx)
T[..., 0, 1, 1] = -1 / 6 * khk * dx**2 + 0.5 * hx * dx * cx
Expand Down
4 changes: 1 addition & 3 deletions cheetah/utils/bmadx.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def bmad_to_cheetah_coords(
# TODO This can probably be moved to the `ParticleBeam` class at some point

# Initialize Cheetah coordinates
cheetah_coords = torch.ones(
(*bmad_coords.shape[:-1], 7), dtype=bmad_coords.dtype, device=bmad_coords.device
)
cheetah_coords = bmad_coords.new_ones((*bmad_coords.shape[:-1], 7))
cheetah_coords[..., :6] = bmad_coords.clone()

# Bmad longitudinal coordinates
Expand Down
Loading