
Commit b969c58

deduplicate code for get_group_qparams_symmetric

Summary: This just removes the duplicated implementation; follow-up PRs can remove the call altogether once all implementations have been replaced with the new blockwise quant code.

Test Plan: CI

1 parent f05c215 commit b969c58
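For context, here is a minimal self-contained sketch of the per-group symmetric qparams math that the removed implementation performed (plain PyTorch, mirroring the deleted code in the diff below; `group_symmetric_qparams_sketch` is a hypothetical name, not part of torchao):

```python
import torch

def group_symmetric_qparams_sketch(w, n_bit=4, groupsize=128, precision=torch.float32):
    # Each row is split into groups of `groupsize` columns.
    grouped = w.reshape(-1, groupsize)
    # Symmetric quantization centers the range at zero, so only the
    # largest absolute value per group matters.
    min_val = grouped.amin(dim=1, keepdim=True).clamp(max=0)
    max_val = grouped.amax(dim=1, keepdim=True).clamp(min=0)
    max_abs = torch.max(-min_val, max_val)
    quant_min, quant_max = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    # The scale maps max_abs onto half of the signed integer range;
    # the zero point is always 0 for symmetric quantization.
    scale = max_abs / (float(quant_max - quant_min) / 2)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = torch.zeros_like(scale)
    return (scale.to(precision).reshape(w.shape[0], -1),
            zero_point.to(precision).reshape(w.shape[0], -1))

# Example: int8 qparams over groups of 2 columns.
scales, zps = group_symmetric_qparams_sketch(torch.randn(4, 8), n_bit=8, groupsize=2)
print(scales.shape)  # torch.Size([4, 4])
```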

File tree

2 files changed: +19 −21 lines


test/quantization/test_quant_primitives.py

Lines changed: 4 additions & 2 deletions

@@ -67,9 +67,11 @@ def test_choose_qparams_group_sym(self):
         mapping_type = MappingType.SYMMETRIC
         dtype = torch.int8
         block_size = (1, 2)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        eps = torch.finfo(torch.float32).eps
+        precision = torch.float32
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
 
-        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision)
 
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
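The test compares with `torch.equal`, which requires bit-for-bit identical values and dtypes, so `precision` is now threaded through both paths explicitly. A standalone sketch of the same comparison, assuming `MappingType`, `choose_qparams_affine`, and `get_group_qparams_symmetric` are all importable from `torchao.quantization.quant_primitives`:

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    get_group_qparams_symmetric,
)

input = torch.randn(10, 10)
precision = torch.float32
eps = torch.finfo(torch.float32).eps

# New blockwise path: one (1, 2) block per pair of columns.
scale, zero_point = choose_qparams_affine(
    input, MappingType.SYMMETRIC, (1, 2), torch.int8,
    eps=eps, scale_dtype=precision, zero_point_dtype=precision,
)

# Legacy path, now a thin wrapper over the same computation.
scale_ref, zp_ref = get_group_qparams_symmetric(
    input, n_bit=8, groupsize=2, precision=precision
)

print(torch.equal(scale, scale_ref), torch.equal(zero_point, zp_ref))
```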

torchao/quantization/quant_primitives.py

Lines changed: 15 additions & 19 deletions

@@ -40,6 +40,9 @@
     "groupwise_affine_dequantize_tensor_from_qparams",
     "groupwise_affine_quantize_tensor",
     "groupwise_affine_dequantize_tensor",
+    "choose_qparams_affine",
+    "quantize_affine",
+    "dequantize_affine",
     # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
@@ -728,26 +731,19 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"unsupported n_bit: {n_bit}"
 
-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-    max_val_abs = torch.max(-min_val_neg, max_val_pos)
-    max_int = 2 ** (n_bit - 1) - 1
-    min_int = -(2 ** (n_bit - 1))
-
-    scales = max_val_abs / (float(max_int - min_int) / 2)
-    scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
-    # TODO: make sure abs(scales) is not too small?
-    zeros = torch.full_like(scales, 0)
-    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
-        w.shape[0], -1
-    )
+    block_size = (1, groupsize)
+    mapping_type = MappingType.SYMMETRIC
+    eps = torch.finfo(torch.float32).eps
+    ranges = {}
+    ranges[1] = (-1, 0)
+    # generating ranges for bit 2 to 8
+    for i in range(2, 9):
+        ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
+    quant_min, quant_max = ranges[n_bit]
+    scale, zero_point = choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+    return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1)
 
 
 if TORCH_VERSION_AFTER_2_3:
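For reference, the `ranges` table built in the new function body simply enumerates the standard signed two's-complement bounds per bit width; a quick sanity check:

```python
# Reproduces the ranges dict from the new function body.
ranges = {1: (-1, 0)}
for i in range(2, 9):
    ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)

print(ranges[2])  # (-2, 1)
print(ranges[4])  # (-8, 7)
print(ranges[8])  # (-128, 127)
```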
