Commit 6082d30

put leading dims logic to _dynamic_int8_mm
1 parent 0d65b26 commit 6082d30

File tree

1 file changed: +7 additions, -12 deletions

torchao/prototype/quantized_training/int8_mixed_precision.py
@@ -168,27 +168,25 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
 
     TODO: check if transpose+quantize are actually fused.
     """
-    A_i8, A_scale_rowwise = quantize_int8_rowwise(A)
+    # A may have more than 2 dims, while B must be exactly 2-dim
+    A_i8, A_scale_rowwise = quantize_int8_rowwise(A.view(-1, A.shape[-1]))
     B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T)
-    return int8_mm_dequant(
+    out = int8_mm_dequant(
         A_i8.contiguous(),
         B_t_i8.contiguous().T,
         A_scale_rowwise.contiguous(),
         B_scale_colwise.contiguous(),
     )
+    return out.view(*A.shape[:-1], out.shape[-1])
 
 
 class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function):
     @staticmethod
     def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]):
         if weight.config.output:
-            batch_dims = input.shape[:-1]
-            input = input.view(-1, weight.shape[1])
             out = _dynamic_int8_mm(input, weight._data.T)
-            out = out.view(*batch_dims, weight.shape[0])
         else:
-            out = input @ weight.T
-
+            out = input @ weight._data.T
         out = out + bias if bias is not None else out
         return out
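
The hunk above moves the flatten/restore handling of leading dimensions out of the caller and into _dynamic_int8_mm itself: flatten all leading dims of A into one, run the 2-D matmul, then view the result back. A minimal sketch of that pattern, with a plain fp32 matmul standing in for the quantize / int8-matmul / dequantize pipeline (the helper name _mm_with_leading_dims is hypothetical, not from the source):

    import torch
    from torch import Tensor

    def _mm_with_leading_dims(A: Tensor, B: Tensor) -> Tensor:
        # A may have any number of leading dims, e.g. (batch, seq_len, K);
        # B must be exactly 2-dim, shape (K, N).
        A_2d = A.view(-1, A.shape[-1])  # flatten leading dims -> (batch * seq_len, K)
        out = A_2d @ B                  # stand-in for quantize + int8 mm + dequant
        # restore leading dims -> (batch, seq_len, N)
        return out.view(*A.shape[:-1], out.shape[-1])

    A = torch.randn(2, 8, 16)  # 3-dim activation
    B = torch.randn(16, 4)     # 2-dim weight
    assert _mm_with_leading_dims(A, B).shape == (2, 8, 4)

With this inside the matmul helper, callers such as forward below no longer need their own batch_dims bookkeeping.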

@@ -204,18 +202,15 @@ def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
 
-        batch_dims = grad_output.shape[:-1]
-        grad_output = grad_output.view(-1, weight.shape[0])
-        input = input.view(-1, weight.shape[1])
-
         if ctx.needs_input_grad[0]:
             if ctx.config.grad_input:
                 grad_input = _dynamic_int8_mm(grad_output, weight)
             else:
                 grad_input = grad_output @ weight
-            grad_input = grad_input.view(*batch_dims, weight.shape[1])
 
         if ctx.needs_input_grad[1]:
+            grad_output = grad_output.view(-1, weight.shape[0])
+            input = input.view(-1, weight.shape[1])
             if ctx.config.grad_weight:
                 # grad_weight = _dynamic_int8_mm(grad_output.T, input)
                 grad_weight = _dynamic_int8_mm(input.T, grad_output).T  # this is slightly faster
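
In backward, the reshapes now happen only inside the needs_input_grad[1] branch, since _dynamic_int8_mm handles leading dims for the grad_input path on its own. The grad_weight line relies on the identity (A^T B)^T = B^T A, so (input.T @ grad_output).T equals grad_output.T @ input; per the in-code comment, that ordering is slightly faster here. A quick plain-PyTorch check of the equivalence (shapes are illustrative, not from the source):

    import torch

    grad_output = torch.randn(32, 8)  # (batch, out_features)
    input = torch.randn(32, 16)       # (batch, in_features)

    # grad_weight for y = input @ weight.T is grad_output.T @ input
    direct = grad_output.T @ input
    # equivalent form used in the commit
    swapped = (input.T @ grad_output).T

    torch.testing.assert_close(direct, swapped)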
