     "groupwise_affine_dequantize_tensor_from_qparams",
     "groupwise_affine_quantize_tensor",
     "groupwise_affine_dequantize_tensor",
+    "choose_qparams_affine",
+    "quantize_affine",
+    "dequantize_affine",
     # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
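The three newly exported names are the affine quantization primitives that the refactor in the next hunk builds on. Below is a minimal sketch of a groupwise 4-bit-range round trip using them, assuming they live in `torchao.quantization.quant_primitives`; the `choose_qparams_affine` call mirrors the one added in this diff, while the positional argument order shown for `quantize_affine` and `dequantize_affine` is an assumption and should be checked against the module.

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

w = torch.randn(32, 256)
block_size = (1, 128)         # one (scale, zero_point) pair per group of 128 columns
quant_min, quant_max = -8, 7  # 4-bit symmetric range

# Parameter selection, mirroring the call added further down in this diff.
scale, zero_point = choose_qparams_affine(
    w, MappingType.SYMMETRIC, block_size, target_dtype=torch.int8,
    quant_min=quant_min, quant_max=quant_max,
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32, zero_point_dtype=torch.float32,
)

# Assumed argument order (input, block_size, scale, zero_point, dtype, quant_min, quant_max).
w_q = quantize_affine(w, block_size, scale, zero_point, torch.int8, quant_min, quant_max)
w_dq = dequantize_affine(w_q, block_size, scale, zero_point, torch.int8, quant_min, quant_max)
```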
@@ -728,26 +731,19 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"unsupported n_bit: {n_bit}"
 
-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-    max_val_abs = torch.max(-min_val_neg, max_val_pos)
-    max_int = 2 ** (n_bit - 1) - 1
-    min_int = -(2 ** (n_bit - 1))
-
-    scales = max_val_abs / (float(max_int - min_int) / 2)
-    scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
-    # TODO: make sure abs(scales) is not too small?
-    zeros = torch.full_like(scales, 0)
-    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
-        w.shape[0], -1
-    )
+    block_size = (1, groupsize)
+    mapping_type = MappingType.SYMMETRIC
+    eps = torch.finfo(torch.float32).eps
+    ranges = {}
+    ranges[1] = (-1, 0)
+    # generate (quant_min, quant_max) ranges for bit widths 2 to 8
+    for i in range(2, 9):
+        ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
+    quant_min, quant_max = ranges[n_bit]
+    scale, zero_point = choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+    return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1)
 
 
 if TORCH_VERSION_AFTER_2_3:
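For reference, the refactored helper keeps the same external contract as before: per-group scales and zero_points, each reshaped to `(w.shape[0], n_groups)`. A small usage sketch follows, assuming `get_group_qparams_symmetric` is importable from `torchao.quantization.quant_primitives` (the import path is an assumption based on where this diff appears to live).

```python
import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric

w = torch.randn(32, 256)  # 2D weight; last dim must be divisible by groupsize
scales, zeros = get_group_qparams_symmetric(w, n_bit=4, groupsize=128)

# 256 / 128 = 2 groups per row, so both outputs are reshaped to (32, 2).
print(scales.shape, zeros.shape)
# n_bit=4 selects (quant_min, quant_max) = (-8, 7) from the ranges table above.
```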