Commit 54c48d1

Match torch.fake_quantize numerics in 8da4w QAT
Summary: There are two subtle differences between the 8da4w quant primitives and `torch.fake_quantize_per_channel_affine` today:

1. 8da4w uses float32 zero points; torch.fake_quantize uses int32 zero points
2. 8da4w uses input.div(scales); torch.fake_quantize uses input.mul(1.0 / scales)

Of these two differences, the second one is smaller and only resulted in 0.1% of elements mismatching in unit tests, but it is a source of numerical divergence nonetheless. This commit changes the 8da4w QAT quant primitives to match the torch.fake_quantize behavior for both of these differences. In a future commit, we will change the 8da4w PTQ quant primitives as well so PTQ and QAT remain consistent.

Note: This commit also has the side effect of significantly reducing the memory footprint for bf16 inputs. We now cast them to fp32 before multiplying them with fp32 scales. This reduces memory usage, presumably because bf16 * fp32 kernels are not as memory efficient.

Test Plan:
python test/quantization/test_qat.py -k test_qat_generic_fake_quantize

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
1 parent 3dd16c9 commit 54c48d1

File tree: 3 files changed (+78, -33 lines)
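
To make the two differences described in the commit message concrete, here is a small standalone sketch. It is an illustration only, not code from this commit: `fake_quantize_old` and `fake_quantize_new` are hypothetical helpers written for this comparison, and the exact-match expectation for the new arithmetic is what the new unit test below asserts.

import torch

def fake_quantize_old(x, scales, zero_points, quant_min, quant_max):
    # old 8da4w-style arithmetic: floating-point zero points and x.div(scales)
    q = x.div(scales).add(zero_points).round()
    return q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)

def fake_quantize_new(x, scales, zero_points, quant_min, quant_max):
    # new arithmetic matching torch.fake_quantize_per_channel_affine:
    # int32 zero points, x.mul(1.0 / scales), and rounding before adding the zero point
    q = x.mul(1.0 / scales).round().add(zero_points)
    return q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)

torch.manual_seed(0)
x = torch.randn(16, 64)
scales = torch.rand(16, 1) + 0.1
zp = torch.randint(0, 7, (16, 1), dtype=torch.int32)
ref = torch.fake_quantize_per_channel_affine(x, scales.flatten(), zp.flatten(), 0, -8, 7)

new_diff = (fake_quantize_new(x, scales, zp, -8, 7) - ref).abs().max().item()
old_diff = (fake_quantize_old(x, scales, zp.float(), -8, 7) - ref).abs().max().item()
print(new_diff)  # expected to be exactly 0.0
print(old_diff)  # typically small, but can be nonzero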


test/quantization/test_qat.py

Lines changed: 34 additions & 11 deletions
@@ -14,6 +14,7 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.quantization.prototype.qat import (
     _choose_qparams_per_token_asymmetric,
+    _GenericFakeQuantize,
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
@@ -58,7 +59,7 @@ def _get_qmin_qmax(self, n_bit: int):
         qmax = 2 ** (n_bit - 1) - 1
         return (qmin, qmax)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
@@ -67,6 +68,7 @@ def test_fake_quantize_per_channel_group(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
+        zp = zp.to(torch.int32)
         x2 = copy.deepcopy(x)

         # fake quant op
@@ -84,18 +86,15 @@ def test_fake_quantize_per_channel_group(self):
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_token(self):
         (qmin, qmax) = self._get_qmin_qmax(8)

         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
         # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(
-            x,
-            torch.int8, # not used
-        )
+        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)

         # fake quant op
         out = fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -130,7 +129,7 @@ def _set_ptq_weight(
         ptq_linear.scales = s
         ptq_linear.zeros = zp

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +154,7 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +188,7 @@ def test_qat_8da4w_quantizer(self):
         for k in ptq_state_dict.keys():
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

@@ -201,7 +200,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +253,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         qat_out2 = qat_model2(*x2)
         torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
@@ -299,6 +298,30 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
         torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_qat_generic_fake_quantize(self):
+        """
+        Test that the generic fake quantize used in 8da4w QAT matches
+        the numerics of existing fake quantize ops in Pytorch in both
+        the forward and the backward passes.
+        """
+        (qmin, qmax) = self._get_qmin_qmax(4)
+        py_input = torch.randn(16, 64).float().requires_grad_()
+        py_s = torch.randn(16).float()
+        py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
+        py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax)
+        py_out.sum().backward()
+
+        ao_input = copy.deepcopy(py_input)
+        ao_input.grad.data.zero_()
+        ao_s = copy.deepcopy(py_s).reshape(-1, 1)
+        ao_zp = copy.deepcopy(py_zp).reshape(-1, 1)
+        ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax)
+        ao_out.sum().backward()
+
+        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
+        torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()
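
As a side note on the `zp = zp.to(torch.int32)` line added above: `get_group_qparams_symmetric` still returns zero points in the same floating-point precision as the scales, so the test casts them to int32 before the exact comparison. A minimal sketch of that preparation step, assuming torchao at this commit is installed:

import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric

torch.manual_seed(0)
x = torch.randn(100, 256)
(s, zp) = get_group_qparams_symmetric(x, n_bit=4, groupsize=128)
zp = zp.to(torch.int32)  # float zeros -> int32, as the updated test does
print(s.dtype, zp.dtype)  # expected: torch.float32 torch.int32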

torchao/quantization/prototype/qat.py

Lines changed: 43 additions & 21 deletions
@@ -4,18 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
+from typing import Any, Tuple

 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.library import impl

-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
 from torchao.quantization.unified import TwoStepQuantizer


-if TORCH_VERSION_AFTER_2_3:
+if TORCH_VERSION_AFTER_2_4:
     from torchao.quantization.GPTQ import (
         _replace_linear_8da4w,
         Int8DynActInt4WeightLinear,
@@ -54,7 +54,7 @@ def prepare(
                 self.precision,
                 self.scales_precision,
                 Int8DynActInt4WeightQATLinear,
-                copy_weights = True,
+                copy_weights=True,
             )
             return model

@@ -95,7 +95,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
                 quantized_linear.zeros = zp
             else:
                 _convert_qat_linear_8da4w(child)
-
+
     class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
         """
         This module implements a linear layer with int8 dynamic per token fake
@@ -131,6 +131,8 @@ def __init__(
             self.groupsize = groupsize
             self.precision = precision
             self.scales_precision = scales_precision
+            # TODO: make this configurable?
+            self.zero_points_precision = torch.int32
             self._fake_quant_enabled = True

         def enable_fake_quant(self, enabled: bool = True):
@@ -142,8 +144,8 @@ def disable_fake_quant(self):
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             # activations: int8 dynamic asymmetric quant
             if self._fake_quant_enabled:
-                (act_scales, act_zp) =_choose_qparams_per_token_asymmetric(
-                    x, torch.int8, # dtype not used
+                (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+                    x, self.scales_precision, self.zero_points_precision,
                 )
                 (act_qmin, act_qmax) = self._get_qmin_qmax(8)
                 x_fq = fake_quantize_per_token(
@@ -157,6 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 (weight_scales, weight_zp) = get_group_qparams_symmetric(
                     self.weight, 4, self.groupsize, self.scales_precision,
                 )
+                # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+                weight_zp = weight_zp.to(self.zero_points_precision)
                 (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
                 w_fq = fake_quantize_per_channel_group(
                     self.weight,
@@ -190,6 +194,20 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         if isinstance(mod, Int8DynActInt4WeightQATLinear):
             mod.disable_fake_quant()

+else:  # not TORCH_VERSION_AFTER_2_4
+
+    class Int8DynActInt4WeightQATQuantizer:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+"
+            )
+
+    class Int8DynActInt4WeightQATLinear:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATLinear is only supported after PyTorch 2.4+"
+            )
+

 # ========================
 # | QUANT PRIMITIVES |
@@ -205,13 +223,14 @@ class _GenericFakeQuantize(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input, scales, zero_points, quant_min, quant_max):
-        # Note: this diverges from `torch.fake_quantize_per_channel_affine`,
-        # which rounds first before adding the zero points. However, this
-        # is what `quantize_per_channel_group` and `quantize_per_token`
-        # do and here we try to match that behavior as closely as possible.
-        q = input.mul(1.0 / scales).add(zero_points).round()
+        # Note: for bf16 inputs, casting them to fp32 has the unexpected
+        # side effect of reducing memory footprint significantly, presumably
+        # because bf16 * fp32 kernels are not as memory efficient
+        assert input.dtype == torch.float32
+        assert scales.dtype == torch.float32
+        assert zero_points.dtype == torch.int32
+        q = input.mul(1.0 / scales).round().add(zero_points)
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
-        # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))
         ctx.save_for_backward(mask)
         return dq
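
For reference, the forward arithmetic above plus a straight-through-estimator backward can be written as a small self-contained sketch. This is a simplified stand-in for `_GenericFakeQuantize`, not the module's exact code: the backward is not part of this hunk, so the version below is inferred from the mask saved via `ctx.save_for_backward`.

import torch

class _FakeQuantizeSketch(torch.autograd.Function):
    """Illustrative stand-in for _GenericFakeQuantize (assumed behavior)."""

    @staticmethod
    def forward(ctx, x, scales, zero_points, quant_min, quant_max):
        # fp32 input/scales and int32 zero points, mirroring the asserts above
        q = x.mul(1.0 / scales).round().add(zero_points)
        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
        # straight-through estimator: only pass gradients where q was in range
        ctx.save_for_backward(torch.logical_and(q >= quant_min, q <= quant_max))
        return dq

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask, None, None, None, None

x = torch.randn(8, 32, requires_grad=True)
scales = torch.rand(8, 1) + 0.1
zero_points = torch.zeros(8, 1, dtype=torch.int32)
out = _FakeQuantizeSketch.apply(x, scales, zero_points, -8, 7)
out.sum().backward()
print(x.grad.abs().max().item())  # gradients flow only through in-range elements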
@@ -239,14 +258,13 @@ def fake_quantize_per_channel_group(
     assert group_size > 1
     assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
-    assert torch.isnan(input).sum() == 0
-    grouped_input = input.reshape(-1, group_size)
+    grouped_input = input.reshape(-1, group_size).to(torch.float32)
     scales = scales.reshape(-1, 1)
     zero_points = zero_points.reshape(-1, 1)
     fq = _GenericFakeQuantize.apply(
         grouped_input, scales, zero_points, quant_min, quant_max,
     )
-    return fq.reshape_as(input)
+    return fq.reshape_as(input).to(input.dtype)

 # TODO: move this to core
 quantized_decomposed_lib.define(
@@ -266,17 +284,20 @@ def fake_quantize_per_token(
     from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

     _per_token_quant_qparam_dim_check(input, scales, zero_points)
-    return _GenericFakeQuantize.apply(
-        input, scales, zero_points, quant_min, quant_max,
+    fq_input = input.to(torch.float32)
+    fq = _GenericFakeQuantize.apply(
+        fq_input, scales, zero_points, quant_min, quant_max,
     )
+    return fq.reshape_as(input).to(input.dtype)

 # TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
 # The version in pytorch does not have backward support yet so we add
 # it here for now until https://github.com/pytorch/pytorch/pull/123452
 # is landed.
 def _choose_qparams_per_token_asymmetric(
     input: torch.Tensor,
-    dtype: torch.dtype,
+    scales_precision: torch.dtype = torch.float32,
+    zero_points_precision: torch.dtype = torch.float32,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
     (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
@@ -285,7 +306,8 @@ def _choose_qparams_per_token_asymmetric(

     Args:
        input (torch.Tensor): original float32/float16 Tensor
-       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       scales_precision (torch.dtype): precision of returned scales
+       zero_points_precision (torch.dtype): precision of returned zero points

     Returns:
        scales and zero_points, both float32 Tensors
@@ -314,4 +336,4 @@ def _choose_qparams_per_token_asymmetric(
     )
     zero_point = torch.clamp(zero_point, qmin, qmax).round()

-    return scale.to(torch.float32), zero_point.to(torch.float32)
+    return scale.to(scales_precision), zero_point.to(zero_points_precision)
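
A minimal usage sketch of the updated per-token path, assuming torchao at this commit is installed. The input here is bfloat16 to illustrate the note above: `fake_quantize_per_token` casts it to fp32 internally and casts the result back to the input dtype.

import torch
from torchao.quantization.prototype.qat import (
    _choose_qparams_per_token_asymmetric,
    fake_quantize_per_token,
)

x = torch.randn(4, 64, dtype=torch.bfloat16)
# new signature: scale and zero point precisions are passed explicitly
(scales, zero_points) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
x_fq = fake_quantize_per_token(x, scales, zero_points, -128, 127)
print(x_fq.dtype)  # expected: torch.bfloat16, matching the input dtype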

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
@@ -764,7 +764,7 @@ def groupwise_affine_dequantize_tensor(
     )


-# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver
+# TODO: separate scale and zero point precision
 def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
