Skip to content

Commit a77bc68

Browse files
authored
Merge pull request #313 from bigict/attn
fix: multi_chain_permutation, torch.clone
2 parents 63b3f13 + 20d13fc commit a77bc68

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

profold2/model/functional.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1628,7 +1628,9 @@ def multi_chain_permutation_alignment(value, batch):
16281628
]
16291629
T = optimal_transform_create(pred_points, true_points, points_mask) # pylint: disable=invalid-name
16301630

1631-
coord, coord_mask = batch['coord'][bdx], batch['coord_mask'][bdx]
1631+
coord, coord_mask = map(
1632+
torch.clone, (batch['coord'][bdx], batch['coord_mask'][bdx])
1633+
)
16321634
for seq_color_i, seq_color_j in optimal_permutation_find(
16331635
rigids_apply(T, batch['coord_fgt'][bdx]),
16341636
batch['coord_mask_fgt'][bdx],

0 commit comments

Comments
 (0)