@@ -733,6 +733,32 @@ def _input_activation_quant_func_fp8(
     return activation
 
 
+def _fp8_mm_compat(weight: torch.Tensor) -> bool:
+    """
+    Check if a weight tensor meets float8 quantization requirements.
+
+    Args:
+        weight (torch.Tensor): The weight tensor to check
+
+    Returns:
+        bool: True if the tensor can be quantized to float8, False otherwise
+    """
+    assert (
+        weight.dim() == 2
+    ), f"float8 quantization only works for 2-D tensors, got {weight.dim()}D tensor"
+
+    out_dim, in_dim = weight.shape
+    is_compatible = (in_dim % 16 == 0) and (out_dim % 16 == 0)
+
+    if not is_compatible:
+        logger.info(
+            f"Skipping float8 quantization: weight shape {weight.shape} is not compatible with _scaled_mm. "
+            f"Both input dimension ({in_dim}) and output dimension ({out_dim}) must be multiples of 16."
+        )
+
+    return is_compatible
+
+
 def float8_dynamic_activation_float8_weight(
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
     weight_dtype: torch.dtype = torch.float8_e4m3fn,
@@ -761,6 +787,8 @@ def float8_dynamic_activation_float8_weight(
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
 
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
+        if not _fp8_mm_compat(weight):
+            return weight
         if isinstance(weight_granularity, PerRow):
             assert (
                 weight.dtype == torch.bfloat16
@@ -818,6 +846,8 @@ def float8_static_activation_float8_weight(
     ), "Static quantization only supports PerTensor granularity"
 
     def apply_float8_static_activation_quant(weight: torch.Tensor):
+        if not _fp8_mm_compat(weight):
+            return weight
         block_size = get_block_size(weight.shape, weight_granularity)
         quantized_weight = to_affine_quantized_floatx(
            input_float=weight,
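For reference, a minimal standalone sketch of the shape gate this patch adds. The helper name and example shapes below are illustrative, not part of the patch; per the check in `_fp8_mm_compat`, both weight dimensions must be multiples of 16 for `_scaled_mm`, and incompatible weights are returned unquantized.

```python
import torch

# Illustrative mirror of the _fp8_mm_compat gate above (names are for demonstration only).
def fp8_shape_ok(weight: torch.Tensor) -> bool:
    out_dim, in_dim = weight.shape
    # Both dimensions must be multiples of 16 to be usable with _scaled_mm.
    return (in_dim % 16 == 0) and (out_dim % 16 == 0)

print(fp8_shape_ok(torch.randn(4096, 4096)))  # True  -> weight would be float8-quantized
print(fp8_shape_ok(torch.randn(1000, 4096)))  # False -> weight is skipped and left in its original dtype
```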