Remove branching #553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Changes from all commits
```diff
@@ -128,7 +128,7 @@ def __init__(

     @property
     def hx(self) -> torch.Tensor:
-        return torch.where(self.length == 0.0, 0.0, self.angle / self.length)
+        return self.angle / self.length  # Zero length not caught because not physical

     @property
     def dipole_e1(self) -> torch.Tensor:
```
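Dropping the `torch.where` guard means a zero-length dipole now yields `nan` for `hx`, which the new inline comment deems acceptable because such an element is not physical. The guard also never fully protected gradients: `torch.where` backpropagates `nan` from the unselected branch. A small standalone sketch of both points (the tensor values are illustrative, not from the PR):

```python
import torch

angle = torch.tensor([0.5, 0.0])
length = torch.tensor([2.0, 0.0], requires_grad=True)

# Old behaviour: the guard maps zero-length elements to hx = 0.0 ...
hx_old = torch.where(length == 0.0, 0.0, angle / length)
print(hx_old)  # tensor([0.2500, 0.0000])

# ... but 0.0 / 0.0 is still evaluated in the unselected branch, so
# backpropagating through torch.where produces nan gradients there anyway.
hx_old.sum().backward()
print(length.grad)  # tensor([-0.1250, nan])

# New behaviour: plain division; zero length simply yields nan for hx.
hx_new = angle / length.detach()
print(hx_new)  # tensor([0.2500, nan])
```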
```diff
@@ -402,82 +402,57 @@ def _bmadx_fringe_linear(
     def first_order_transfer_map(
         self, energy: torch.Tensor, species: Species
     ) -> torch.Tensor:
-        factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}
-
         R_enter = self._transfer_map_enter()
         R_exit = self._transfer_map_exit()

-        if torch.any(self.length != 0.0):  # Bending magnet with finite length
-            R = base_rmatrix(
-                length=self.length,
-                k1=self.k1,
-                hx=self.hx,
-                species=species,
-                energy=energy,
-            )  # Tilt is applied after adding edges
-        else:  # Reduce to Thin-Corrector
-            R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
-            R[..., 0, 1] = self.length
-            R[..., 2, 6] = self.angle
-            R[..., 2, 3] = self.length
+        R = base_rmatrix(
+            length=self.length,
+            k1=self.k1,
+            hx=self.hx,
+            species=species,
+            energy=energy,
+        )  # Tilt is applied after adding edges

         # Apply fringe fields
         R = R_exit @ R @ R_enter

         # Apply rotation for tilted magnets
-        if torch.any(self.tilt != 0):
-            rotation = rotation_matrix(self.tilt)
-            R = rotation.mT @ R @ rotation
+        rotation = rotation_matrix(self.tilt)
+        R = rotation.mT @ R @ rotation
```
Comment on lines +420 to +421

Could it be worth it to build an autograd function for this sort of function that only computes the tilt optionally in a forward pass, but always computes the gradient as if it had been in there?

Could be, but that also depends on how expensive the multiplication is. The overhead of the custom autograd functions seems to be non-negligible from the […]
```diff

         return R
```
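For reference, a rough sketch of the kind of custom autograd function the thread above suggests: skip the tilt multiplication in the forward pass when all tilts are zero, but always differentiate as if it had been applied. The class name and structure are illustrative assumptions, not code from this PR, and `rotation_matrix` is assumed to be in scope as in the diff:

```python
import torch

class TiltSandwich(torch.autograd.Function):
    """Sketch: skip `rotation.mT @ R @ rotation` in the forward pass when
    all tilts are zero (rotation_matrix(0) is the identity, so R would be
    unchanged), but run the backward pass through the full expression so
    that gradients with respect to tilt stay correct."""

    @staticmethod
    def forward(ctx, R, tilt):
        ctx.save_for_backward(R, tilt)
        if torch.any(tilt != 0):
            rotation = rotation_matrix(tilt)
            return rotation.mT @ R @ rotation
        return R  # Shortcut; exact because the rotation is the identity

    @staticmethod
    def backward(ctx, grad_output):
        R, tilt = ctx.saved_tensors
        # Recompute the full expression under autograd, even if the
        # forward pass took the shortcut.
        with torch.enable_grad():
            R_ = R.detach().requires_grad_()
            tilt_ = tilt.detach().requires_grad_()
            rotation = rotation_matrix(tilt_)
            out = rotation.mT @ R_ @ rotation
        return torch.autograd.grad(out, (R_, tilt_), grad_output)
```

Usage would be `R = TiltSandwich.apply(R, self.tilt)`. As the reply notes, whether this pays off depends on how the saved matrix multiplications compare to the recomputation in `backward` plus the per-call overhead of custom autograd functions.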
```diff
     @cache_transfer_map
     def second_order_transfer_map(
         self, energy: torch.Tensor, species: Species
     ) -> torch.Tensor:
-        factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}
-
         R_enter = self._transfer_map_enter()
         R_exit = self._transfer_map_exit()

-        if torch.any(self.length != 0.0):  # Bending magnet with finite length
-            T = base_ttensor(
-                length=self.length,
-                k1=self.k1,
-                k2=torch.tensor(0.0, **factory_kwargs),
-                hx=self.hx,
-                species=species,
-                energy=energy,
-            )
-
-            # Fill the first-order transfer map into the second-order transfer map
-            T[..., :, 6, :] = base_rmatrix(
-                length=self.length,
-                k1=self.k1,
-                hx=self.hx,
-                species=species,
-                energy=energy,
-            )
-        else:  # Reduce to Thin-Corrector
-            R = torch.eye(7, **factory_kwargs).repeat((*self.length.shape, 1, 1))
-            R[..., 0, 1] = self.length
-            R[..., 2, 6] = self.angle
-            R[..., 2, 3] = self.length
-
-            T = torch.zeros((*self.length.shape, 7, 7, 7), **factory_kwargs)
-            T[..., :, 6, :] = R
+        T = base_ttensor(
+            length=self.length,
+            k1=self.k1,
+            k2=self.length.new_zeros(()),
+            hx=self.hx,
+            species=species,
+            energy=energy,
+        )
+
+        # Fill the first-order transfer map into the second-order transfer map
+        T[..., :, 6, :] = base_rmatrix(
+            length=self.length, k1=self.k1, hx=self.hx, species=species, energy=energy
+        )

         # Apply fringe fields
         T = torch.einsum(
             "...ij,...jkl,...kn,...lm->...inm", R_exit, T, R_enter, R_enter
         )

         # Apply rotation for tilted magnets
-        if torch.any(self.tilt != 0):
-            rotation = rotation_matrix(self.tilt)
-            T = torch.einsum(
-                "...ji,...jkl,...kn,...lm->...inm", rotation, T, rotation, rotation
-            )
+        rotation = rotation_matrix(self.tilt)
+        T = torch.einsum(
+            "...ji,...jkl,...kn,...lm->...inm", rotation, T, rotation, rotation
+        )

         return T
```
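The einsum string used here sandwiches the second-order tensor `T` between matrices: one matrix is contracted into the leading index and one into each of the two trailing indices. A small self-contained check of that reading, with hypothetical random tensors in place of the real maps:

```python
import torch

# Hypothetical batch of 2 second-order maps over a 7-dimensional phase space.
A = torch.randn(2, 7, 7, dtype=torch.float64)    # e.g. R_exit
B = torch.randn(2, 7, 7, dtype=torch.float64)    # e.g. R_enter
T = torch.randn(2, 7, 7, 7, dtype=torch.float64)

# The contraction from the diff: T'_inm = sum_{jkl} A_ij T_jkl B_kn B_lm
T_out = torch.einsum("...ij,...jkl,...kn,...lm->...inm", A, T, B, B)

# The same contraction done one index at a time:
T_ref = torch.einsum("...ij,...jkl->...ikl", A, T)      # A into the first index
T_ref = torch.einsum("...ikl,...kn->...inl", T_ref, B)  # B into the second index
T_ref = torch.einsum("...inl,...lm->...inm", T_ref, B)  # B into the third index

assert torch.allclose(T_out, T_ref)
```

The tilt contraction above is identical except that the first matrix enters transposed (`...ji` instead of `...ij`), matching the `rotation.mT @ R @ rotation` sandwich in the first-order map.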