@@ -215,7 +215,8 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
     return mask.transpose()  # Put in C order
 
 
-def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+@torch.compile(fullgraph=True, dynamic=True)
+def _mask_to_rle_pytorch_2_0_0(tensor: torch.Tensor) -> torch.Tensor:
     """
     Encodes masks to an uncompressed RLE, in the format expected by
     pycoco tools.
@@ -227,33 +228,53 @@ def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tenso
     with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: change indices"):
         # Compute change indices
         diff = tensor[:, 1:] ^ tensor[:, :-1]
-        a = torch.tensor([[True]])
-        if diff.is_cuda:
-            a = a.pin_memory().cuda()
-            # a = a.to(diff.device)
+        # a = torch.tensor([[True]])
+        a = torch.ones((1, 1), dtype=bool, device=diff.device)
+        # if diff.is_cuda:
+        #     a = a.pin_memory().cuda()
+        #     # a = a.to(diff.device)
         a = a.expand_as(diff.narrow(1, 0, 1))
         diff = torch.cat([a, diff, a], dim=1)
-        if diff.numel() > 2147483646:
-            num_chunks = (diff.numel() + 2147483646) // 2147483646
-            change_indices = torch.cat([d.nonzero() for d in diff.chunk(num_chunks)])
-        else:
-            change_indices = diff.nonzero()
+        return diff
+
+
+@torch.compile(fullgraph=True, dynamic=True)
+def _mask_to_rle_pytorch_2_0_1(tensor: torch.Tensor, diff: torch.Tensor, change_indices: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+    tensor = tensor.permute(0, 2, 1).flatten(1)
 
     with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: all_btw_idx"):
         alt_lens = diff.sum(dim=1)
 
         all_cur_idx = change_indices[:, 1]
-        all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
+        all_cur_idx_0 = all_cur_idx.narrow(0, 1, all_cur_idx.size(0) - 1)
+        all_cur_idx_1 = all_cur_idx.narrow(0, 0, 1)
+        all_btw_idx = torch.cat([all_cur_idx_0, all_cur_idx_1])
+        all_btw_idx = all_btw_idx - all_cur_idx
 
     with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: Encode run length"):
         alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens)
         # Encode run length
         counts_init = (tensor[:, 0] == 0)
-    return RLEData(alt_lens_nt=alt_lens_nt,
-                   counts_init=counts_init,
-                   b=b,
-                   h=h,
-                   w=w)
+    return alt_lens_nt, counts_init
+
+
+def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> RLEData:
+    b, h, w = tensor.shape
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_0"):
+        diff = _mask_to_rle_pytorch_2_0_0(tensor)
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: nonzero"):
+        if diff.numel() > 2147483646:
+            num_chunks = (diff.numel() + 2147483646) // 2147483646
+            change_indices = torch.cat([d.nonzero() for d in diff.chunk(num_chunks)])
+        else:
+            change_indices = diff.nonzero()
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_1"):
+        alt_lens_nt, counts_init = _mask_to_rle_pytorch_2_0_1(tensor, diff, change_indices)
+    return RLEData(alt_lens_nt=alt_lens_nt,
+                   counts_init=counts_init,
+                   b=b,
+                   h=h,
+                   w=w)
 
 
 def _mask_to_rle_pytorch_2_1(rle_data: RLEData):
@@ -276,7 +297,8 @@ def _mask_to_rle_pytorch_2_1(rle_data: RLEData):
 
 
 def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
-    return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor))
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2"):
+        return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor))
 
 
 def area_from_rle(rle: Dict[str, Any]) -> int:
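
Reviewer note: the diff itself doesn't say why `_mask_to_rle_pytorch_2_0` is split in two, but the likely reason is that `nonzero()` returns a tensor whose size depends on the data, which `torch.compile(fullgraph=True)` cannot capture without a graph break. Hoisting the `nonzero()` call (and its chunking workaround for the historical ~2**31-element limit of CUDA `nonzero`) into an eager wrapper between two compiled kernels keeps both graphs whole. A minimal sketch of the same pattern, with made-up function names and only public PyTorch APIs:

```python
import torch

@torch.compile(fullgraph=True, dynamic=True)
def compiled_prologue(x: torch.Tensor) -> torch.Tensor:
    # Shape-preserving work: safe to capture with fullgraph=True.
    return x[:, 1:] ^ x[:, :-1]

@torch.compile(fullgraph=True, dynamic=True)
def compiled_epilogue(idx: torch.Tensor) -> torch.Tensor:
    # By the time this graph runs, the data-dependent size of idx is
    # just another dynamic dimension.
    return idx[:, 1].diff()

def pipeline(x: torch.Tensor) -> torch.Tensor:
    d = compiled_prologue(x)
    # nonzero() has a data-dependent output shape, so it runs eagerly
    # between the two compiled regions instead of breaking a graph.
    return compiled_epilogue(d.nonzero())

out = pipeline(torch.rand(4, 16) > 0.5)
```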
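Also worth noting: in the chunked eager path, each per-chunk `nonzero()` restarts its row indices at zero, which looks wrong at first glance but appears harmless here, since only the column indices (`change_indices[:, 1]`) are consumed downstream. A hedged round-trip smoke test for the refactor (batch size and mask shape are illustrative; assumes the encoder also runs on CPU and that `rle_to_mask` from earlier in this file decodes its output):

```python
import torch

# Illustrative round trip: encode a random boolean mask batch to
# uncompressed RLE dicts, decode the first one back, and compare.
masks = torch.rand(4, 32, 32) > 0.5   # (B, H, W) boolean masks
rles = mask_to_rle_pytorch_2(masks)   # one pycocotools-style dict per mask
recovered = rle_to_mask(rles[0])      # numpy array of shape (H, W)
assert (recovered == masks[0].numpy()).all()
```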