@@ -1428,16 +1428,18 @@ def kabsch_rotation(
1428
1428
Returns:
1429
1429
r (torch.Tensor): rotation matrix with shape `(3, 3)`
1430
1430
"""
1431
- assert x .shape == y .shape
1431
+ assert x .shape == y .shape and x . shape [ - 1 ] == 3
1432
1432
1433
1433
if exists (mask ):
1434
+ assert len (x .shape ) == 2
1434
1435
x , y = x [mask > 0 , :], y [mask > 0 , :]
1435
1436
1436
1437
with autocast (enabled = False ):
1437
1438
x , y = x .float (), y .float ()
1438
1439
1439
1440
# optimal rotation matrix via SVD of the convariance matrix {x.T * y}
1440
- v , _ , w = torch .linalg .svd (x .T @ y )
1441
+ # v, _, w = torch.linalg.svd(x.T @ y)
1442
+ v , _ , w = torch .linalg .svd (torch .einsum ('... i c,... i d -> ... c d' , x , y ))
1441
1443
1442
1444
# determinant sign for direction correction
1443
1445
d = torch .sign (torch .det (v ) * torch .det (w ))
@@ -1460,14 +1462,14 @@ def kabsch_transform(
1460
1462
"""
1461
1463
assert x .shape == y .shape
1462
1464
1463
- R = kabsch_rotation (x , y , mask = mask ) # pylint: disable=invalid-name
1464
-
1465
1465
if exists (mask ):
1466
- x_center = masked_mean (value = x , mask = mask , dim = - 2 , keepdim = True )
1467
- y_center = masked_mean (value = y , mask = mask , dim = - 2 , keepdim = True )
1466
+ x_center = masked_mean (value = x , mask = mask [..., None ] , dim = - 2 , keepdim = True )
1467
+ y_center = masked_mean (value = y , mask = mask [..., None ] , dim = - 2 , keepdim = True )
1468
1468
else :
1469
1469
x_center = torch .mean (x , dim = - 2 , keepdim = True )
1470
1470
y_center = torch .mean (y , dim = - 2 , keepdim = True )
1471
+
1472
+ R = kabsch_rotation (x - x_center , y - y_center , mask = mask ) # pylint: disable=invalid-name
1471
1473
t = x_center - torch .einsum ('... h w, ... w -> ... h' , R , y_center )
1472
1474
1473
1475
return R , t
@@ -1478,8 +1480,8 @@ def kabsch_align(x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor]
1478
1480
"""
1479
1481
# center x and y to the origin
1480
1482
if exists (mask ):
1481
- x_ = x - masked_mean (value = x , mask = mask , dim = - 2 , keepdim = True )
1482
- y_ = y - masked_mean (value = y , mask = mask , dim = - 2 , keepdim = True )
1483
+ x_ = x - masked_mean (value = x , mask = mask [..., None ] , dim = - 2 , keepdim = True )
1484
+ y_ = y - masked_mean (value = y , mask = mask [..., None ] , dim = - 2 , keepdim = True )
1483
1485
else :
1484
1486
x_ = x - x .mean (dim = - 2 , keepdim = True )
1485
1487
y_ = y - y .mean (dim = - 2 , keepdim = True )
0 commit comments