
Commit 7792153

Add default filtering to remove mis-aligned weights
1 parent baa78f2 commit 7792153

File tree

1 file changed: +30 −0 lines changed


torchao/quantization/quant_api.py

Lines changed: 30 additions & 0 deletions
@@ -733,6 +733,32 @@ def _input_activation_quant_func_fp8(
     return activation
 
 
+def _fp8_mm_compat(weight: torch.Tensor) -> bool:
+    """
+    Check if a weight tensor meets float8 quantization requirements.
+
+    Args:
+        weight (torch.Tensor): The weight tensor to check
+
+    Returns:
+        bool: True if the tensor can be quantized to float8, False otherwise
+    """
+    assert (
+        weight.dim() == 2
+    ), f"float8 quantization only works for 2-D tensors, got {weight.dim()}D tensor"
+
+    out_dim, in_dim = weight.shape
+    is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0)
+
+    if not is_compatible:
+        logger.info(
+            f"Skipping float8 quantization: weight shape {weight.shape} is not compatible with _scaled_mm. "
+            f"Both input dimension ({in_dim}) and output dimension ({out_dim}) must be multiples of 16. "
+        )
+
+    return is_compatible
+
+
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
     weight_dtype: torch.dtype = torch.float8_e4m3fn,
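
For illustration, a minimal runnable sketch of the alignment rule the new helper enforces (the function name fp8_mm_compat_sketch and the example shapes are hypothetical, chosen only to show which weights pass; the real check is _fp8_mm_compat above):

import torch

def fp8_mm_compat_sketch(weight: torch.Tensor) -> bool:
    # Mirrors the commit's check: a 2-D weight is float8-compatible for
    # _scaled_mm only if both dimensions are multiples of 16.
    out_dim, in_dim = weight.shape
    return (in_dim % 16 == 0) and (out_dim % 16 == 0)

print(fp8_mm_compat_sketch(torch.randn(4096, 4096)))  # True: both dims are multiples of 16
print(fp8_mm_compat_sketch(torch.randn(4096, 4097)))  # False: 4097 is not a multiple of 16
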
@@ -761,6 +787,8 @@ def float8_dynamic_activation_float8_weight(
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
 
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
+        if not _fp8_mm_compat(weight):
+            return weight
         if isinstance(weight_granularity, PerRow):
             assert (
                 weight.dtype == torch.bfloat16
@@ -818,6 +846,8 @@ def float8_static_activation_float8_weight(
     ), "Static quantization only supports PerTensor granularity"
 
     def apply_float8_static_activation_quant(weight: torch.Tensor):
+        if not _fp8_mm_compat(weight):
+            return weight
         block_size = get_block_size(weight.shape, weight_granularity)
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
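
To show where the filtering takes effect, a hedged usage sketch follows (the quantize_ entry point and the import path are assumed from torchao's public API rather than taken from this diff, and running it requires a CUDA GPU with float8 support):

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

# Hypothetical toy model: the first linear has 16-aligned dimensions and gets
# a float8 weight; the second has out_dim = 100, which is not a multiple of
# 16, so with this commit its weight is returned unchanged instead of being
# quantized to a mis-aligned float8 tensor.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, dtype=torch.bfloat16),
    torch.nn.Linear(4096, 100, dtype=torch.bfloat16),
).cuda()

quantize_(model, float8_dynamic_activation_float8_weight())
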
