 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
+from typing import Any, Tuple

 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.library import impl

-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
 from torchao.quantization.unified import TwoStepQuantizer


-if TORCH_VERSION_AFTER_2_3:
+if TORCH_VERSION_AFTER_2_4:
     from torchao.quantization.GPTQ import (
         _replace_linear_8da4w,
         Int8DynActInt4WeightLinear,
@@ -54,7 +54,7 @@ def prepare(
                 self.precision,
                 self.scales_precision,
                 Int8DynActInt4WeightQATLinear,
-                copy_weights = True,
+                copy_weights=True,
             )
             return model

@@ -95,7 +95,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
                 quantized_linear.zeros = zp
             else:
                 _convert_qat_linear_8da4w(child)
-
+
     class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
         """
         This module implements a linear layer with int8 dynamic per token fake
@@ -131,6 +131,8 @@ def __init__(
             self.groupsize = groupsize
             self.precision = precision
             self.scales_precision = scales_precision
+            # TODO: make this configurable?
+            self.zero_points_precision = torch.int32
             self._fake_quant_enabled = True

         def enable_fake_quant(self, enabled: bool = True):
@@ -142,8 +144,8 @@ def disable_fake_quant(self):
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             # activations: int8 dynamic asymmetric quant
             if self._fake_quant_enabled:
-                (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
-                    x, torch.int8,  # dtype not used
+                (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+                    x, self.scales_precision, self.zero_points_precision,
                 )
                 (act_qmin, act_qmax) = self._get_qmin_qmax(8)
                 x_fq = fake_quantize_per_token(
@@ -157,6 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             (weight_scales, weight_zp) = get_group_qparams_symmetric(
                 self.weight, 4, self.groupsize, self.scales_precision,
             )
+            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+            weight_zp = weight_zp.to(self.zero_points_precision)
             (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
             w_fq = fake_quantize_per_channel_group(
                 self.weight,
@@ -190,6 +194,20 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         if isinstance(mod, Int8DynActInt4WeightQATLinear):
             mod.disable_fake_quant()

+else:  # not TORCH_VERSION_AFTER_2_4
+
+    class Int8DynActInt4WeightQATQuantizer:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+"
+            )
+
+    class Int8DynActInt4WeightQATLinear:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATLinear is only supported after PyTorch 2.4+"
+            )
+

 # ========================
 # |   QUANT PRIMITIVES   |
@@ -205,13 +223,14 @@ class _GenericFakeQuantize(torch.autograd.Function):

     @staticmethod
     def forward(ctx, input, scales, zero_points, quant_min, quant_max):
-        # Note: this diverges from `torch.fake_quantize_per_channel_affine`,
-        # which rounds first before adding the zero points. However, this
-        # is what `quantize_per_channel_group` and `quantize_per_token`
-        # do and here we try to match that behavior as closely as possible.
-        q = input.mul(1.0 / scales).add(zero_points).round()
+        # Note: for bf16 inputs, casting them to fp32 has the unexpected
+        # side effect of reducing memory footprint significantly, presumably
+        # because bf16 * fp32 kernels are not as memory efficient
+        assert input.dtype == torch.float32
+        assert scales.dtype == torch.float32
+        assert zero_points.dtype == torch.int32
+        q = input.mul(1.0 / scales).round().add(zero_points)
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
-        # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))
         ctx.save_for_backward(mask)
         return dq
@@ -239,14 +258,13 @@ def fake_quantize_per_channel_group(
     assert group_size > 1
     assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
-    assert torch.isnan(input).sum() == 0
-    grouped_input = input.reshape(-1, group_size)
+    grouped_input = input.reshape(-1, group_size).to(torch.float32)
     scales = scales.reshape(-1, 1)
     zero_points = zero_points.reshape(-1, 1)
     fq = _GenericFakeQuantize.apply(
         grouped_input, scales, zero_points, quant_min, quant_max,
     )
-    return fq.reshape_as(input)
+    return fq.reshape_as(input).to(input.dtype)

 # TODO: move this to core
 quantized_decomposed_lib.define(
@@ -266,17 +284,20 @@ def fake_quantize_per_token(
     from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

     _per_token_quant_qparam_dim_check(input, scales, zero_points)
-    return _GenericFakeQuantize.apply(
-        input, scales, zero_points, quant_min, quant_max,
+    fq_input = input.to(torch.float32)
+    fq = _GenericFakeQuantize.apply(
+        fq_input, scales, zero_points, quant_min, quant_max,
     )
+    return fq.reshape_as(input).to(input.dtype)

 # TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
 # The version in pytorch does not have backward support yet so we add
 # it here for now until https://github.com/pytorch/pytorch/pull/123452
 # is landed.
 def _choose_qparams_per_token_asymmetric(
     input: torch.Tensor,
-    dtype: torch.dtype,
+    scales_precision: torch.dtype = torch.float32,
+    zero_points_precision: torch.dtype = torch.float32,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
     (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
@@ -285,7 +306,8 @@ def _choose_qparams_per_token_asymmetric(

     Args:
        input (torch.Tensor): original float32/float16 Tensor
-       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       scales_precision (torch.dtype): precision of returned scales
+       zero_points_precision (torch.dtype): precision of returned zero points

     Returns:
        scales and zero_points, both float32 Tensors
@@ -314,4 +336,4 @@ def _choose_qparams_per_token_asymmetric(
     )
     zero_point = torch.clamp(zero_point, qmin, qmax).round()

-    return scale.to(torch.float32), zero_point.to(torch.float32)
+    return scale.to(scales_precision), zero_point.to(zero_points_precision)
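
For reference, a standalone sketch of the fake-quant round trip this diff settles on (round before adding the zero point, fp32 scales, int32 zero points, and an fp32 upcast for bf16 inputs). The shapes and values below are illustrative only, not taken from the source.

```python
import torch

# Sketch of the math in _GenericFakeQuantize.forward after this change:
#   q  = round(x / s) + zp
#   dq = (clamp(q, qmin, qmax) - zp) * s
x = torch.randn(4, 32, dtype=torch.bfloat16)        # e.g. bf16 activations
s = torch.full((4, 1), 0.02, dtype=torch.float32)   # fp32 scales, as asserted above
zp = torch.zeros(4, 1, dtype=torch.int32)           # int32 zero points, as asserted above
qmin, qmax = -8, 7                                  # 4-bit signed range

x_fp32 = x.to(torch.float32)                        # callers upcast bf16 before apply()
q = x_fp32.mul(1.0 / s).round().add(zp)
dq = q.clamp(qmin, qmax).sub(zp).mul(s).to(x.dtype) # cast back to the input dtype
```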
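For context, a rough sketch of how the prepare/convert flow touched in this diff is typically driven. The import path, constructor arguments, and model below are assumptions for illustration, not part of the diff.

```python
import torch
# Hypothetical import path; adjust to wherever this file lives in your install.
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# a toy model; any module containing nn.Linear layers works
model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# prepare() swaps nn.Linear for Int8DynActInt4WeightQATLinear (fake quant enabled)
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = quantizer.prepare(model)

# ... fine-tune with fake quantization in the training loop ...

# convert() swaps in the real Int8DynActInt4WeightLinear for inference
model = quantizer.convert(model)
```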