
Commit 3bd151a

iseeyuan (Martin Yuan) authored and committed
Update quant_primitives.py
1 parent 43492c8 commit 3bd151a

File tree

3 files changed: +47 -11 lines changed

test/quantization/test_quant_primitives.py

Lines changed: 18 additions & 2 deletions
@@ -164,7 +164,7 @@ def test_get_group_qparams_symmetric(self):
         scale_obs = scale_obs.reshape(weight.shape[0], -1)
 
         # assert that scales are identical
-        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
+        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16, mapping_type=MappingType.SYMMETRIC)
         torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)
 
     def test_choose_qparams_group_sym(self):
@@ -179,11 +179,27 @@ def test_choose_qparams_group_sym(self):
         precision = torch.float32
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
 
-        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision)
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type)
 
         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
 
+    def test_choose_qparams_group_sym_pos_neg(self):
+        """
+        Test the added MappingType.SYMMETRIC_MAX_POS_NEG
+        """
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC_MAX_POS_NEG
+        dtype = torch.int8
+        block_size = (1, 2)
+        eps = torch.finfo(torch.float32).eps
+        precision = torch.float32
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type)
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower")
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_token_asym(self):

torchao/quantization/quant_primitives.py

Lines changed: 26 additions & 6 deletions
@@ -41,12 +41,18 @@ class MappingType(Enum):
     we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
     e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
 
+    SYMMETRIC_MAX_POS_NEG is a variant of symmetric mapping, where the scale is the max of smin
+    and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating
+    smin and smax individually, there can be less rounding error on negative values, and no
+    floating point values fall out of range.
+
     asymmetric mapping means we just directly map the floating point range to integer range,
     for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter
     based on this mapping
     e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
     """
     SYMMETRIC = auto()
+    SYMMETRIC_MAX_POS_NEG = auto()
     ASYMMETRIC = auto()
 
 class ZeroPointDomain(Enum):
@@ -695,7 +701,7 @@ def _choose_qparams_affine(
     and `zero_point_domain`
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.SYMMETRIC_MAX_POS_NEG.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
     if input is not None:
         if scale_dtype is None:
@@ -729,11 +735,25 @@ def _choose_qparams_affine(
         min_val_neg = min_val
         max_val_pos = max_val
 
-    if mapping_type == MappingType.SYMMETRIC.name:
-        smin = min_val_neg / float(quant_min)
-        smax = max_val_pos / float(quant_max)
-        mask = smin > smax
-        scale = torch.where(mask, smin, smax)
+    if mapping_type == MappingType.SYMMETRIC.name or mapping_type == MappingType.SYMMETRIC_MAX_POS_NEG.name:
+        # scales
+        if mapping_type == MappingType.SYMMETRIC.name:
+            max_val_pos = torch.max(-min_val_neg, max_val_pos)
+            scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        else:
+            assert mapping_type == MappingType.SYMMETRIC_MAX_POS_NEG.name
+            # Calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
+            # quant_max = 7:
+            # - If smin is bigger: there is coverage on negative values down to -8, and less rounding
+            #   error than the existing SYMMETRIC case.
+            # - If smax is bigger: it covers the positive values up to 7. The rounding
+            #   error may be bigger than the existing SYMMETRIC case.
+            # Either way, there are no out-of-range fp values after quantization.
+            smin = min_val_neg / float(quant_min)
+            smax = max_val_pos / float(quant_max)
+            mask = smin > smax
+            scale = torch.where(mask, smin, smax)
+        # zeros
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
         if zero_point_domain is not None and zero_point_domain != ZeroPointDomain.INT.name:
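
To make the two symmetric mappings concrete, here is a minimal sketch (not part of the commit) that reproduces both scale formulas by hand for a single group whose float range is (-3.5, 10.2), using the (-8, 7) integer bounds from the MappingType docstring example; the variable names mirror _choose_qparams_affine above.

import torch

# One group with an asymmetric float range, and the (-8, 7) integer range
# used in the MappingType docstring example.
min_val_neg = torch.tensor(-3.5)
max_val_pos = torch.tensor(10.2)
quant_min, quant_max = -8, 7

# SYMMETRIC: symmetrize the float range around zero, then divide by half the int range.
sym_max = torch.max(-min_val_neg, max_val_pos)                   # 10.2
scale_symmetric = sym_max / (float(quant_max - quant_min) / 2)   # 10.2 / 7.5 = 1.36

# SYMMETRIC_MAX_POS_NEG: derive a candidate scale from each side and keep the larger one.
smin = min_val_neg / float(quant_min)                 # -3.5 / -8 = 0.4375
smax = max_val_pos / float(quant_max)                 # 10.2 / 7 ~= 1.457
scale_pos_neg = torch.where(smin > smax, smin, smax)  # ~= 1.457

# With scale_symmetric, round(10.2 / 1.36) = 8 falls outside [-8, 7] and must be clamped;
# with scale_pos_neg, round(10.2 / 1.457) = 7 stays in range, matching the docstring claim.
print(scale_symmetric.item(), scale_pos_neg.item())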

torchao/quantization/utils.py

Lines changed: 3 additions & 3 deletions
@@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor(
 
 
 # TODO: separate scale and zero point precision
-def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
+def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32, mapping_type=MappingType.SYMMETRIC_MAX_POS_NEG):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
@@ -427,7 +427,6 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
     assert w.dim() == 2
     assert n_bit <= 8, f"unsupported n_bit: {n_bit}"
 
-    mapping_type = MappingType.SYMMETRIC
     block_size = (1, groupsize)
     eps = torch.finfo(torch.float32).eps
     ranges = {}
@@ -445,8 +444,9 @@ def group_quantize_tensor_symmetric(
     n_bit=4,
     group_size=128,
     precision=torch.float32,
+    mapping_type=MappingType.SYMMETRIC_MAX_POS_NEG
 ):
-    scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
+    scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision, mapping_type)
     n_bit = 4
     max_int = 2 ** (n_bit - 1) - 1
     min_int = -(2 ** (n_bit - 1))
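
For context on the default change above, here is a brief usage sketch (assumed, not from the commit; module paths follow the file paths in the diff and the weight tensor is made up for illustration). It calls get_group_qparams_symmetric once with the new default mapping type and once with the previous SYMMETRIC behavior passed explicitly.

import torch
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.utils import get_group_qparams_symmetric

w = torch.randn(4, 128)  # toy 2D weight, grouped along the last dim

# New default: per-group scale chosen as max(smin, smax) (SYMMETRIC_MAX_POS_NEG).
scale_pos_neg, zp = get_group_qparams_symmetric(w, n_bit=4, groupsize=64)

# Previous behavior is still available by passing the mapping type explicitly.
scale_sym, _ = get_group_qparams_symmetric(
    w, n_bit=4, groupsize=64, mapping_type=MappingType.SYMMETRIC
)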
