
Commit be30a7f

Fix FQ mask in 8da4w QAT (#199)
Co-authored-by: Jerry Zhang <[email protected]>
1 parent 5364de6 commit be30a7f

File tree: 1 file changed (+1, -1 lines changed)

  • torchao/quantization/prototype/qat.py


torchao/quantization/prototype/qat.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def forward(ctx, input, scales, zero_points, quant_min, quant_max):
         q = input.div(scales).add(zero_points).round()
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
         # TODO: do we need this mask?
-        mask = torch.logical_and((q >= quant_min), (dq <= quant_max))
+        mask = torch.logical_and((q >= quant_min), (q <= quant_max))
         ctx.save_for_backward(mask)
         return dq

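For context, the saved mask is what a straight-through estimator uses in the backward pass: gradients should flow only for elements whose integer value q landed inside [quant_min, quant_max], and comparing the dequantized float dq against the integer bound quant_max gated the wrong elements. Below is a minimal, self-contained sketch of such a fake-quantize autograd function with the corrected mask. The class name, the backward implementation, and the 4-bit example bounds are illustrative assumptions, not the torchao code; only the forward lines shown in the diff come from the source.

import torch

class FakeQuantizeSketch(torch.autograd.Function):
    # Minimal sketch of per-group fake quantization with a
    # straight-through estimator; not the torchao implementation.

    @staticmethod
    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
        # Quantize to (unclamped) integer values.
        q = input.div(scales).add(zero_points).round()
        # Clamp to the integer range, then dequantize back to float.
        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
        # Gradient mask: True where q is inside [quant_min, quant_max],
        # i.e. where no clamping happened. Comparing dq (a float on the
        # original scale) against quant_max masks the wrong elements,
        # which is what this commit fixes.
        mask = torch.logical_and(q >= quant_min, q <= quant_max)
        ctx.save_for_backward(mask)
        return dq

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through only
        # for elements that were not clamped; scales, zero_points and
        # the integer bounds receive no gradient.
        (mask,) = ctx.saved_tensors
        return grad_output * mask, None, None, None, None

Hypothetical usage with a signed 4-bit range (matching the "4w" in 8da4w):

x = torch.randn(4, 8, requires_grad=True)
scales = torch.full((4, 1), 0.1)
zero_points = torch.zeros(4, 1)
y = FakeQuantizeSketch.apply(x, scales, zero_points, -8, 7)
y.sum().backward()  # x.grad is 1 where q stayed in [-8, 7], else 0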
