
Commit 01a7b47

cpuhrsch and sunjiweiswift authored and committed

SAM2 Fast AMG: memory profiling and more compile (pytorch#1296)

1 parent 4cd453c commit 01a7b47

File tree

7 files changed: +95 −42 lines changed

examples/sam2_amg_server/README.md

Lines changed: 5 additions & 5 deletions

@@ -24,11 +24,11 @@ Experiments run on H100 and with batch size 1
 | mode           | mIoU               | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch |
 | -------------- | ------------------ | ------------------- | ------------------- | --------------------- | ---------- | ---------------- |
 | baseline       | 1.0                | 0                   | 863                 | 4013MiB (4%)          | 1          | 64               |
-| ao             | 0.9999980926513672 | 6                   | 586                 |                       | 1          | 64               |
-| fast           | 0.9937329888343811 | 191                 | 333                 |                       | 1          | 1024             |
-| fast           | 0.9937219619750977 | 192                 | 324                 |                       | 16         | 1024             |
-| fast + furious | 0.9804400205612183 | 292                 | 131                 |                       | 1          | 1024             |
-| fast + furious | 0.9806423187255859 | 282                 | 130                 |                       | 16         | 1024             |
+| ao             | 0.9999980926513672 | 6                   | 586                 | 3257MiB (3%)          | 1          | 64               |
+| fast           | 0.993732988834381  | 191                 | 326                 | 27197MiB (27%)        | 1          | 1024             |
+| fast           | 0.9937511086463928 | 194                 | 315                 | 27488MiB (28%)        | 16         | 1024             |
+| fast + furious | 0.9817246198654175 | 266                 | 120                 | 13616MiB (13%)        | 1          | 1024             |
+| fast + furious | 0.9794579744338989 | 274                 | 122                 | 13808MiB (14%)        | 16         | 1024             |

 mask count mismatch counts the number of requests where the number of masks differs from the baseline.
 For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
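As a reading aid, a minimal sketch of that metric (a hypothetical helper, not part of this commit):

    # Count requests whose number of predicted masks differs from the baseline's.
    def mask_count_mismatch(baseline_masks_per_request, candidate_masks_per_request):
        return sum(
            len(b) != len(c)
            for b, c in zip(baseline_masks_per_request, candidate_masks_per_request)
        )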

examples/sam2_amg_server/server.py

Lines changed: 34 additions & 8 deletions

@@ -149,6 +149,26 @@ def profiler_runner(path, fn, *args, **kwargs):
     return result


+def memory_runner(path, fn, *args, **kwargs):
+    print("Start memory recording")
+    torch.cuda.synchronize()
+    torch.cuda.memory._record_memory_history(
+        True,
+        trace_alloc_max_entries=100000,
+        trace_alloc_record_context=True
+    )
+    result = fn(*args, **kwargs)
+    torch.cuda.synchronize()
+    snapshot = torch.cuda.memory._snapshot()
+    print("Finish memory recording")
+    import pickle
+    with open(path, 'wb') as f:
+        pickle.dump(snapshot, f)
+    # Use to convert pickle file into html
+    # python torch/cuda/_memory_viz.py trace_plot <snapshot>.pickle -o <snapshot>.html
+    return result
+
+
 def image_tensor_to_masks(example_image, mask_generator):
     masks = mask_generator.generate(example_image)
     return masks
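As a usage sketch (illustrative only; image_tensor and mask_generator are assumed to exist as they do elsewhere in this file), memory_runner wraps any callable and writes a snapshot that torch's _memory_viz.py can render:

    masks = memory_runner("amg_snapshot.pickle", image_tensor_to_masks, image_tensor, mask_generator)
    # Offline, convert the pickle into an interactive timeline:
    #   python torch/cuda/_memory_viz.py trace_plot amg_snapshot.pickle -o amg_snapshot.html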
@@ -187,7 +207,7 @@ def process_batch(batch, mask_generator):
     print(f"Processing batch of len {len(batch)} using generate_batch")
     masks = mask_generator.generate_batch(image_tensors)
     print(f"Took avg. {(time.time() - t) / len(batch)}s per batch entry")
-    # max_memory_allocated()
+    max_memory_allocated()
     return masks


@@ -259,6 +279,7 @@ def main(checkpoint_path,
          unittest=False,
          benchmark=False,
          profile=None,
+         memory_profile=None,
          verbose=False,
          points_per_batch=64,
          port=5000,
@@ -305,13 +326,6 @@ def main(checkpoint_path,
         dynamic=False,
     )

-    mask_generator.predictor.model.sam_prompt_encoder.forward = torch.compile(
-        mask_generator.predictor.model.sam_prompt_encoder.forward,
-        mode="max-autotune",
-        fullgraph=True,
-        dynamic=False,
-    )
-
     mask_generator.predictor._predict_masks = torch.compile(
         mask_generator.predictor._predict_masks,
         mode="max-autotune",
@@ -329,6 +343,7 @@
         mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
         # NOTE: Not baseline feature
         mask_generator.predictor._image_dtype = torch.float16
+        mask_generator.predictor._transforms_device = mask_generator.predictor.device
         torch.set_float32_matmul_precision('high')
         mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16)
         # NOTE: Not baseline feature
@@ -363,11 +378,15 @@
         for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
             print(f"batch size {batch_size} example shapes {i} benchmark")
             random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes]
+            if batch_size > len(random_images):
+                num_repeat = (len(random_images) + batch_size) // batch_size
+                random_images = num_repeat * random_images

             if batch_size == 1:
                 [benchmark_fn(image_tensor_to_masks, r, mask_generator) for r in random_images]
             else:
                 random_images = random_images[:batch_size]
+                print("len(random_images): ", len(random_images))
                 benchmark_fn(image_tensors_to_masks, random_images, mask_generator)

     if profile is not None:
@@ -377,6 +396,13 @@
         else:
             profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

+    if memory_profile is not None:
+        print(f"Saving memory profile under {memory_profile}")
+        if batch_size == 1:
+            memory_runner(memory_profile, image_tensor_to_masks, image_tensor, mask_generator)
+        else:
+            memory_runner(memory_profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)
+
     if dry:
         return
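A hypothetical end-to-end run, assuming main() is wired to the command line the same way the existing profile flag is (the flag below simply mirrors the new keyword argument):

    # python server.py <checkpoint_path> --benchmark --memory_profile amg_snapshot.pickle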

torchao/_models/sam2/automatic_mask_generator.py

Lines changed: 0 additions & 4 deletions

@@ -381,7 +381,6 @@ def _process_crop_batch(
         i = 0
         batch_features = self.predictor._features
         all_crop_data = []
-        all_all_batch_iterator_data = []
         for (cropped_im, crop_box, layer_idx, orig_size) in zip(all_cropped_im, all_crop_box, all_layer_idx, all_orig_size):
             cropped_im_size = cropped_im.shape[:2]
             self.predictor.reset_predictor()
@@ -425,9 +424,6 @@
                 data = self._process_batch_fullgraph(points, im_size, crop_box, crop_box_torch, orig_size, normalize, orig_box_torch)
                 all_batch_iterator_data.append(data)
             self.predictor.reset_predictor()
-            all_all_batch_iterator_data.append(all_batch_iterator_data)
-
-        for all_batch_iterator_data in all_all_batch_iterator_data:

             result_data = None
             with torch.autograd.profiler.record_function("all mask_to_rle_pytorch_2"):

torchao/_models/sam2/modeling/sam/mask_decoder.py

Lines changed: 4 additions & 0 deletions

@@ -219,6 +219,10 @@ def predict_masks(
         # TODO: Not specifying scale kwarg in SDPA will cause NaN here
         # print("hs.isnan().any(): ", hs.isnan().any().item())

+        # TODO: These outputs are being immediately indexed.
+        # Is there something to remove?
+        # TODO: The fact that there's a crop box and we try to find stuff at the
+        # boundary later and there's generally cropping going on smells of padding.
         iou_token_out = hs[:, s, :]
         mask_tokens_out = hs[:, s + 1: (s + 1 + self.num_mask_tokens), :]
torchao/_models/sam2/sam2_image_predictor.py

Lines changed: 9 additions & 6 deletions

@@ -66,6 +66,7 @@ def __init__(
         ]

         self._image_dtype = torch.float32
+        self._transforms_device = "cpu"

     @classmethod
     def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
@@ -110,10 +111,10 @@ def set_image(
             raise NotImplementedError("Image format not supported")

         input_image = self._transforms.to_tensor(image)
-        # TODO: Doing these transforms on the GPU changes the numerics
+        # NOTE: Doing these transforms on the GPU changes the numerics
+        input_image = input_image.to(device=self._transforms_device)
         input_image = self._transforms.transforms(input_image)
         input_image = input_image.to(device=self.device)
-        # TODO: Doing this here instead causes masks to not match reference exactly
         # input_image = self._transforms.transforms(input_image)
         input_image = input_image[None, ...].to(dtype=self._image_dtype)

@@ -167,8 +168,10 @@ def set_image_batch(
             len(img_batch.shape) == 4 and img_batch.shape[1] == 3
         ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
         logging.info("Computing image embeddings for the provided images...")
-        backbone_out = self.model.forward_image(img_batch)
-        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        with torch.autograd.profiler.record_function("forward_image"):
+            backbone_out = self.model.forward_image(img_batch)
+        with torch.autograd.profiler.record_function("_prepare_backbone_features"):
+            _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
         # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
         if self.model.directly_add_no_mem_embed:
             vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
@@ -462,11 +465,11 @@ def _predict_masks_postprocess(self, low_res_masks, img_idx, return_logits, channel_1):
         # Upscale the masks to the original image resolution
         if channel_1:
             masks = self._transforms.postprocess_masks_1_channel(
-                low_res_masks, self._orig_hw[img_idx]
+                low_res_masks, self._orig_hw[img_idx], self._image_dtype
             )
         else:
             masks = self._transforms.postprocess_masks(
-                low_res_masks, self._orig_hw[img_idx]
+                low_res_masks, self._orig_hw[img_idx], self._image_dtype
             )
         low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
         if not return_logits:
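For context, a minimal sketch (assumed setup; predictor and images are placeholders) of how the record_function labels added above surface in a profiler trace:

    import torch
    from torch.profiler import profile, ProfilerActivity

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        predictor.set_image_batch(images)  # assumed: a list of HxWx3 uint8 images
    # "forward_image" and "_prepare_backbone_features" appear as named ranges.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))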

torchao/_models/sam2/utils/amg.py

Lines changed: 39 additions & 17 deletions

@@ -215,7 +215,8 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
     return mask.transpose()  # Put in C order


-def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+@torch.compile(fullgraph=True, dynamic=True)
+def _mask_to_rle_pytorch_2_0_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor):
     """
     Encodes masks to an uncompressed RLE, in the format expected by
     pycoco tools.
@@ -227,33 +228,53 @@ def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
     with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: change indices"):
         # Compute change indices
         diff = tensor[:, 1:] ^ tensor[:, :-1]
-        a = torch.tensor([[True]])
-        if diff.is_cuda:
-            a = a.pin_memory().cuda()
-            # a = a.to(diff.device)
+        # a = torch.tensor([[True]])
+        a = torch.ones((1, 1), dtype=bool, device=diff.device)
+        # if diff.is_cuda:
+        #     a = a.pin_memory().cuda()
+        #     # a = a.to(diff.device)
         a = a.expand_as(diff.narrow(1, 0, 1))
         diff = torch.cat([a, diff, a], dim=1)
-        if diff.numel() > 2147483646:
-            num_chunks = (diff.numel() + 2147483646) // 2147483646
-            change_indices = torch.cat([d.nonzero() for d in diff.chunk(num_chunks)])
-        else:
-            change_indices = diff.nonzero()
+    return diff
+
+
+@torch.compile(fullgraph=True, dynamic=True)
+def _mask_to_rle_pytorch_2_0_1(tensor: torch.Tensor, diff: torch.Tensor, change_indices: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+    tensor = tensor.permute(0, 2, 1).flatten(1)

     with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: all_btw_idx"):
         alt_lens = diff.sum(dim=1)

         all_cur_idx = change_indices[:, 1]
-        all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
+        all_cur_idx_0 = all_cur_idx.narrow(0, 1, all_cur_idx.size(0) - 1)
+        all_cur_idx_1 = all_cur_idx.narrow(0, 0, 1)
+        all_btw_idx = torch.cat([all_cur_idx_0, all_cur_idx_1])
+        all_btw_idx = all_btw_idx - all_cur_idx

     with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: Encode run length"):
         alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens)
         # Encode run length
         counts_init = (tensor[:, 0] == 0)
-    return RLEData(alt_lens_nt=alt_lens_nt,
-                   counts_init=counts_init,
-                   b=b,
-                   h=h,
-                   w=w)
+    return alt_lens_nt, counts_init
+
+
+def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> RLEData:
+    b, h, w = tensor.shape
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_0"):
+        diff = _mask_to_rle_pytorch_2_0_0(tensor)
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: nonzero"):
+        if diff.numel() > 2147483646:
+            num_chunks = (diff.numel() + 2147483646) // 2147483646
+            change_indices = torch.cat([d.nonzero() for d in diff.chunk(num_chunks)])
+        else:
+            change_indices = diff.nonzero()
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_1"):
+        alt_lens_nt, counts_init = _mask_to_rle_pytorch_2_0_1(tensor, diff, change_indices)
+    return RLEData(alt_lens_nt=alt_lens_nt,
+                   counts_init=counts_init,
+                   b=b,
+                   h=h,
+                   w=w)


 def _mask_to_rle_pytorch_2_1(rle_data: RLEData):
@@ -276,7 +297,8 @@ def _mask_to_rle_pytorch_2_1(rle_data: RLEData):


 def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
-    return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor))
+    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2"):
+        return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor))


 def area_from_rle(rle: Dict[str, Any]) -> int:
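The chunked nonzero kept in the new _mask_to_rle_pytorch_2_0 wrapper works around torch.nonzero's element-count ceiling on large CUDA tensors; 2147483646 is INT_MAX - 1. A standalone sketch of the same pattern (hypothetical helper, mirroring the code above):

    import torch

    LIMIT = 2147483646  # same threshold as above (INT_MAX - 1)

    def chunked_nonzero(t: torch.Tensor) -> torch.Tensor:
        # Chunking happens along dim 0, so row indices returned for later
        # chunks are relative to their chunk rather than the full tensor.
        # The RLE code above tolerates this because it only consumes the
        # column coordinate, change_indices[:, 1].
        if t.numel() > LIMIT:
            num_chunks = (t.numel() + LIMIT) // LIMIT
            return torch.cat([c.nonzero() for c in t.chunk(num_chunks)])
        return t.nonzero()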

torchao/_models/sam2/utils/transforms.py

Lines changed: 4 additions & 2 deletions

@@ -73,7 +73,7 @@ def transform_boxes(
         boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
         return boxes

-    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
+    def postprocess_masks(self, masks: torch.Tensor, orig_hw, output_dtype) -> torch.Tensor:
         """
         Perform PostProcessing on output masks.
         """
@@ -114,10 +114,11 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
             )
             masks = input_masks

+        masks = masks.to(output_dtype)
         masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
         return masks

-    def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
+    def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw, output_dtype) -> torch.Tensor:
         """
         Perform PostProcessing on output masks.
         """
@@ -161,5 +162,6 @@ def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
             )
             masks = input_masks

+        masks = masks.to(output_dtype)
         masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
         return masks
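A hypothetical call site for the new parameter, mirroring how sam2_image_predictor.py now threads self._image_dtype through (low_res_masks and orig_hw are placeholders):

    # Upsample to the original resolution in the caller's compute dtype
    # (e.g. torch.float16 when the fast/furious path casts the model down).
    masks = transforms.postprocess_masks(low_res_masks, orig_hw, torch.float16)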
