@@ -41,12 +41,18 @@ class MappingType(Enum):
41
41
we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
42
42
e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
43
43
44
+ SYMMETRIC_MAX_POS_NEG is a variant of symmetric mapping, where the scale is the max of smin
45
+ and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating
46
+ smin and smax individually, there can be less rounding error on negative values, and no floating point
47
+ value falls out of range after quantization.
48
+
44
49
asymmetric mapping means we just directly map the floating point range to integer range,
45
50
for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter
46
51
based on this mapping
47
52
e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
48
53
"""
49
54
SYMMETRIC = auto ()
55
+ SYMMETRIC_MAX_POS_NEG = auto ()
50
56
ASYMMETRIC = auto ()
51
57
52
58
class ZeroPointDomain (Enum ):
@@ -695,7 +701,7 @@ def _choose_qparams_affine(
695
701
and `zero_point_domain`
696
702
"""
697
703
quant_min , quant_max = _get_and_check_qmin_qmax (target_dtype , quant_min , quant_max )
698
- assert mapping_type in [MappingType .SYMMETRIC .name , MappingType .ASYMMETRIC .name ], f"Unsupported mapping type: { mapping_type } "
704
+ assert mapping_type in [MappingType .SYMMETRIC .name , MappingType .SYMMETRIC_MAX_POS_NEG . name , MappingType . ASYMMETRIC .name ], f"Unsupported mapping type: { mapping_type } "
699
705
700
706
if input is not None :
701
707
if scale_dtype is None :
@@ -729,11 +735,25 @@ def _choose_qparams_affine(
729
735
min_val_neg = min_val
730
736
max_val_pos = max_val
731
737
732
- if mapping_type == MappingType .SYMMETRIC .name :
733
- smin = min_val_neg / float (quant_min )
734
- smax = max_val_pos / float (quant_max )
735
- mask = smin > smax
736
- scale = torch .where (mask , smin , smax )
738
+ if mapping_type == MappingType .SYMMETRIC .name or mapping_type == MappingType .SYMMETRIC_MAX_POS_NEG .name :
739
+ # scales
740
+ if mapping_type == MappingType .SYMMETRIC .name :
741
+ max_val_pos = torch .max (- min_val_neg , max_val_pos )
742
+ scale = max_val_pos / (float (quant_max - quant_min ) / 2 )
743
+ else :
744
+ assert mapping_type == MappingType .SYMMETRIC_MAX_POS_NEG .name
745
+ # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
746
+ # quant_max = 7:
747
+ # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding
748
+ # error than the existing SYMMETRIC case.
749
+ # - If smax is bigger: it covers the positive values up to 7. The rounding
750
+ # error may be bigger than the existing SYMMETRIC case. Either way, there are no out-of-range fp values after
751
+ # quantization.
752
+ smin = min_val_neg / float (quant_min )
753
+ smax = max_val_pos / float (quant_max )
754
+ mask = smin > smax
755
+ scale = torch .where (mask , smin , smax )
756
+ # zeros
737
757
if not preserve_zero :
738
758
raise ValueError ("preserve_zero == False is not supported for symmetric quantization" )
739
759
if zero_point_domain is not None and zero_point_domain != ZeroPointDomain .INT .name :
0 commit comments