Skip to content

Commit 63b3f13

Browse files
authored
Merge pull request #312 from bigict/attn
refactor: centre coords before applying Kabsch algo
2 parents 8d7bb43 + 0960852 commit 63b3f13

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

profold2/model/functional.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,16 +1428,18 @@ def kabsch_rotation(
14281428
Returns:
14291429
r (torch.Tensor): rotation matrix with shape `(3, 3)`
14301430
"""
1431-
assert x.shape == y.shape
1431+
assert x.shape == y.shape and x.shape[-1] == 3
14321432

14331433
if exists(mask):
1434+
assert len(x.shape) == 2
14341435
x, y = x[mask > 0, :], y[mask > 0, :]
14351436

14361437
with autocast(enabled=False):
14371438
x, y = x.float(), y.float()
14381439

14391440
# optimal rotation matrix via SVD of the convariance matrix {x.T * y}
1440-
v, _, w = torch.linalg.svd(x.T @ y)
1441+
# v, _, w = torch.linalg.svd(x.T @ y)
1442+
v, _, w = torch.linalg.svd(torch.einsum('... i c,... i d -> ... c d', x, y))
14411443

14421444
# determinant sign for direction correction
14431445
d = torch.sign(torch.det(v) * torch.det(w))
@@ -1460,14 +1462,14 @@ def kabsch_transform(
14601462
"""
14611463
assert x.shape == y.shape
14621464

1463-
R = kabsch_rotation(x, y, mask=mask) # pylint: disable=invalid-name
1464-
14651465
if exists(mask):
1466-
x_center = masked_mean(value=x, mask=mask, dim=-2, keepdim=True)
1467-
y_center = masked_mean(value=y, mask=mask, dim=-2, keepdim=True)
1466+
x_center = masked_mean(value=x, mask=mask[..., None], dim=-2, keepdim=True)
1467+
y_center = masked_mean(value=y, mask=mask[..., None], dim=-2, keepdim=True)
14681468
else:
14691469
x_center = torch.mean(x, dim=-2, keepdim=True)
14701470
y_center = torch.mean(y, dim=-2, keepdim=True)
1471+
1472+
R = kabsch_rotation(x - x_center, y - y_center, mask=mask) # pylint: disable=invalid-name
14711473
t = x_center - torch.einsum('... h w, ... w -> ... h', R, y_center)
14721474

14731475
return R, t
@@ -1478,8 +1480,8 @@ def kabsch_align(x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor]
14781480
"""
14791481
# center x and y to the origin
14801482
if exists(mask):
1481-
x_ = x - masked_mean(value=x, mask=mask, dim=-2, keepdim=True)
1482-
y_ = y - masked_mean(value=y, mask=mask, dim=-2, keepdim=True)
1483+
x_ = x - masked_mean(value=x, mask=mask[..., None], dim=-2, keepdim=True)
1484+
y_ = y - masked_mean(value=y, mask=mask[..., None], dim=-2, keepdim=True)
14831485
else:
14841486
x_ = x - x.mean(dim=-2, keepdim=True)
14851487
y_ = y - y.mean(dim=-2, keepdim=True)

0 commit comments

Comments
 (0)