Skip to content

Commit ff25836

Browse files
committed
Add mx_fp4 path
stack-info: PR: #2201, branch: drisspg/stack/54
1 parent 0607aa1 commit ff25836

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch.nn as nn
1212

1313
from torchao.prototype.mx_formats.config import (
14+
MXGemmKernelChoice,
1415
MXInferenceLinearConfig,
1516
MXLinearConfig,
1617
MXLinearRecipeName,
@@ -380,7 +381,7 @@ def test_inference_print_str():
380381
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
381382
)
382383
@pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
383-
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
384+
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
384385
@pytest.mark.parametrize("bias", [True, False])
385386
@pytest.mark.parametrize("compile", [True, False])
386387
@torch.no_grad()
@@ -394,7 +395,16 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
394395

395396
m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
396397
m_mx = copy.deepcopy(m)
397-
config = MXFPInferenceConfig()
398+
kernel_choice = (
399+
MXGemmKernelChoice.CUTLASS
400+
if elem_dtype == DTYPE_FP4
401+
else MXGemmKernelChoice.CUBLAS
402+
)
403+
config = MXFPInferenceConfig(
404+
activation_dtype=elem_dtype,
405+
weight_dtype=elem_dtype,
406+
gemm_kernel_choice=kernel_choice,
407+
)
398408
quantize_(m_mx, config=config)
399409
if compile:
400410
m_mx = torch.compile(m_mx, fullgraph=True)
@@ -403,4 +413,7 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
403413
y_ref = m(x)
404414
y_mx = m_mx(x)
405415
sqnr = compute_error(y_ref, y_mx)
406-
assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
416+
SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
417+
assert sqnr >= SQNR_THRESHOLD, (
418+
f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
419+
)

torchao/prototype/mx_formats/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66
import torch
77

8+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
9+
810
# This is conceptually an enum of non-core dtypes
911
# TODO(future PR): change to a cleaner way to represent this without
1012
# regressing torch.compile and while keeping things readable.
11-
DTYPE_FP4 = "fp4_e2m1"
13+
DTYPE_FP4 = torch.float4_e2m1fn_x2 if TORCH_VERSION_AT_LEAST_2_8 else "fp4_e2m1"
1214
DTYPE_FP6_E3M2 = "fp6_e3m2"
1315
DTYPE_FP6_E2M3 = "fp6_e2m3"
1416

0 commit comments

Comments
 (0)