def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor:
    """Dynamically quantize both operands to INT8, matmul, then dequantize.

    ``A`` may have more than 2 dims, while ``B`` must be exactly 2-dim: the
    leading dims of ``A`` are flattened for the matmul and restored on the
    result.

    NOTE(review): reconstructed from a diff view; the original docstring's
    opening text was not visible here.

    TODO: check if transpose+quantize are actually fused.
    """
    lead_shape = A.shape[:-1]
    A_2d = A.view(-1, A.shape[-1])
    A_i8, A_scale_rowwise = quantize_int8_rowwise(A_2d)
    # Row-wise quantization of B.T yields column-wise scales for B.
    B_t_i8, B_scale_colwise = quantize_int8_rowwise(B.T)
    out_2d = int8_mm_dequant(
        A_i8.contiguous(),
        B_t_i8.contiguous().T,
        A_scale_rowwise.contiguous(),
        B_scale_colwise.contiguous(),
    )
    # Restore the (possibly >1) leading dims of A on the output.
    return out_2d.view(*lead_shape, out_2d.shape[-1])
class _Int8MixedPrecisionTrainingLinear (torch .autograd .Function ):
182
184
@staticmethod
def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]):
    """Linear forward pass: ``input @ weight.T`` (plus ``bias`` when given).

    When ``weight.config.output`` is set, the matmul runs through the dynamic
    INT8 path; otherwise a plain matmul on the raw weight data is used.

    NOTE(review): reconstructed from a diff view — `weight._data` is
    presumably the raw tensor inside the subclass wrapper; confirm against
    the class definition.
    """
    weight_t = weight._data.T
    if weight.config.output:
        # _dynamic_int8_mm flattens/restores any extra leading dims itself,
        # so no explicit reshape of `input` is needed here.
        out = _dynamic_int8_mm(input, weight_t)
    else:
        out = input @ weight_t
    if bias is not None:
        out = out + bias
    return out
@@ -204,18 +202,15 @@ def backward(ctx, grad_output):
204
202
input , weight = ctx .saved_tensors
205
203
grad_input = grad_weight = grad_bias = None
206
204
207
- batch_dims = grad_output .shape [:- 1 ]
208
- grad_output = grad_output .view (- 1 , weight .shape [0 ])
209
- input = input .view (- 1 , weight .shape [1 ])
210
-
211
205
if ctx .needs_input_grad [0 ]:
212
206
if ctx .config .grad_input :
213
207
grad_input = _dynamic_int8_mm (grad_output , weight )
214
208
else :
215
209
grad_input = grad_output @ weight
216
- grad_input = grad_input .view (* batch_dims , weight .shape [1 ])
217
210
218
211
if ctx .needs_input_grad [1 ]:
212
+ grad_output = grad_output .view (- 1 , weight .shape [0 ])
213
+ input = input .view (- 1 , weight .shape [1 ])
219
214
if ctx .config .grad_weight :
220
215
# grad_weight = _dynamic_int8_mm(grad_output.T, input)
221
216
grad_weight = _dynamic_int8_mm (input .T , grad_output ).T # this is slightly faster
0 commit comments