Skip to content

Commit 8d7bb43

Browse files
authored
Merge pull request #311 from bigict/attn
refactor: add kabsch_tranform func
2 parents e3ffc52 + 87cd3fa commit 8d7bb43

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

profold2/model/functional.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,32 @@ def kabsch_rotation(
14471447
return r
14481448

14491449

1450+
def kabsch_transform(
1451+
x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor] = None
1452+
):
1453+
"""Calculate the best rotation that minimises the RMSD between x and y.
1454+
Args:
1455+
x (torch.Tensor): source atom coordinates with shape `(num_res, 3)`.
1456+
y (torch.Tensor): target atom coordinates with shape `(num_res, 3)`.
1457+
mask (torch.Tensor optional): with shape `(num_res, )`.
1458+
Returns:
1459+
r (torch.Tensor): rotation matrix with shape `(3, 3)`
1460+
"""
1461+
assert x.shape == y.shape
1462+
1463+
R = kabsch_rotation(x, y, mask=mask) # pylint: disable=invalid-name
1464+
1465+
if exists(mask):
1466+
x_center = masked_mean(value=x, mask=mask, dim=-2, keepdim=True)
1467+
y_center = masked_mean(value=y, mask=mask, dim=-2, keepdim=True)
1468+
else:
1469+
x_center = torch.mean(x, dim=-2, keepdim=True)
1470+
y_center = torch.mean(y, dim=-2, keepdim=True)
1471+
t = x_center - torch.einsum('... h w, ... w -> ... h', R, y_center)
1472+
1473+
return R, t
1474+
1475+
14501476
def kabsch_align(x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor] = None):
14511477
""" Kabsch alignment of x into y. Assumes x, y are both (num_res, 3).
14521478
"""
@@ -1500,11 +1526,7 @@ def optimal_transform_create(pred_points, true_points, points_mask):
15001526
with torch.no_grad():
15011527
pred_ca = true_ca
15021528

1503-
R = kabsch_rotation(pred_ca, true_ca) # pylint: disable=invalid-name
1504-
1505-
pred_center = torch.mean(pred_ca, dim=-2, keepdim=True)
1506-
true_center = torch.mean(true_ca, dim=-2, keepdim=True)
1507-
t = pred_center - torch.einsum('... h w, ... w -> ... h', R, true_center)
1529+
return kabsch_transform(pred_ca, true_ca)
15081530

15091531
return R, t
15101532

0 commit comments

Comments
 (0)