 import torch.nn as nn

 from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
@@ -380,7 +381,7 @@ def test_inference_print_str():
     not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
-@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
 @torch.no_grad()
@@ -394,7 +395,16 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):

     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
-    config = MXFPInferenceConfig()
+    kernel_choice = (
+        MXGemmKernelChoice.CUTLASS
+        if elem_dtype == DTYPE_FP4
+        else MXGemmKernelChoice.CUBLAS
+    )
+    config = MXFPInferenceConfig(
+        activation_dtype=elem_dtype,
+        weight_dtype=elem_dtype,
+        gemm_kernel_choice=kernel_choice,
+    )
     quantize_(m_mx, config=config)
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
@@ -403,4 +413,7 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
-    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
+    SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
+    )
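
Taken together, the hunks above amount to roughly the test body sketched below. This is a condensed, hedged reconstruction of the post-patch flow, not the verbatim file: the input tensor x, the DTYPE_FP4 alias, and the import paths for MXFPInferenceConfig and compute_error are not visible in this diff, so those parts are assumptions.

import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig  # assumed path
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error  # assumed path for the SQNR helper

DTYPE_FP4 = torch.float4_e2m1fn_x2  # stand-in for the test file's DTYPE_FP4 alias


def run_inference_subclass_check(elem_dtype, bias: bool) -> None:
    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
    m_mx = copy.deepcopy(m)

    # fp4 routes through the CUTLASS kernel; fp8 keeps using cuBLAS.
    kernel_choice = (
        MXGemmKernelChoice.CUTLASS
        if elem_dtype == DTYPE_FP4
        else MXGemmKernelChoice.CUBLAS
    )
    config = MXFPInferenceConfig(
        activation_dtype=elem_dtype,
        weight_dtype=elem_dtype,
        gemm_kernel_choice=kernel_choice,
    )
    quantize_(m_mx, config=config)

    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)  # assumed input shape
    sqnr = compute_error(m(x), m_mx(x))
    # fp4 carries fewer mantissa bits, so its SQNR bar is relaxed from 25 dB to 15 dB.
    threshold = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
    assert sqnr >= threshold, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"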