@@ -1447,6 +1447,32 @@ def kabsch_rotation(
1447
1447
return r
1448
1448
1449
1449
1450
+ def kabsch_transform (
1451
+ x : torch .Tensor , y : torch .Tensor , mask : Optional [torch .Tensor ] = None
1452
+ ):
1453
+ """Calculate the best rotation that minimises the RMSD between x and y.
1454
+ Args:
1455
+ x (torch.Tensor): source atom coordinates with shape `(num_res, 3)`.
1456
+ y (torch.Tensor): target atom coordinates with shape `(num_res, 3)`.
1457
+ mask (torch.Tensor optional): with shape `(num_res, )`.
1458
+ Returns:
1459
+ r (torch.Tensor): rotation matrix with shape `(3, 3)`
1460
+ """
1461
+ assert x .shape == y .shape
1462
+
1463
+ R = kabsch_rotation (x , y , mask = mask ) # pylint: disable=invalid-name
1464
+
1465
+ if exists (mask ):
1466
+ x_center = masked_mean (value = x , mask = mask , dim = - 2 , keepdim = True )
1467
+ y_center = masked_mean (value = y , mask = mask , dim = - 2 , keepdim = True )
1468
+ else :
1469
+ x_center = torch .mean (x , dim = - 2 , keepdim = True )
1470
+ y_center = torch .mean (y , dim = - 2 , keepdim = True )
1471
+ t = x_center - torch .einsum ('... h w, ... w -> ... h' , R , y_center )
1472
+
1473
+ return R , t
1474
+
1475
+
1450
1476
def kabsch_align (x : torch .Tensor , y : torch .Tensor , mask : Optional [torch .Tensor ] = None ):
1451
1477
""" Kabsch alignment of x into y. Assumes x, y are both (num_res, 3).
1452
1478
"""
@@ -1500,11 +1526,7 @@ def optimal_transform_create(pred_points, true_points, points_mask):
1500
1526
with torch .no_grad ():
1501
1527
pred_ca = true_ca
1502
1528
1503
- R = kabsch_rotation (pred_ca , true_ca ) # pylint: disable=invalid-name
1504
-
1505
- pred_center = torch .mean (pred_ca , dim = - 2 , keepdim = True )
1506
- true_center = torch .mean (true_ca , dim = - 2 , keepdim = True )
1507
- t = pred_center - torch .einsum ('... h w, ... w -> ... h' , R , true_center )
1529
+ return kabsch_transform (pred_ca , true_ca )
1508
1530
1509
1531
return R , t
1510
1532
0 commit comments