
Commit 379010f

better check for mxfp8 cuda kernel presence (#2933)
Summary:

Short-term fix for #2932. If torchao was built without CUDA capability 10.0 support (such as in our CI), this change ensures that:

a. only callsites which actually use the mxfp8 dim1 kernel see the error message; using NVFP4 no longer hits this error.
b. the error message points to the GitHub issue for more info on the workaround (for now, build from source).

Test Plan:

1. Hardcode the mxfp8 kernel to be excluded from the build: https://github.com/pytorch/ao/blob/85557135c93d3429320a4a360c0ee9cb49f84a00/setup.py#L641
2. Build torchao from source and verify `torchao/prototype` does not contain any `.so` files.
3. Run the nvfp4 tests and verify they now pass: `pytest test/prototype/mx_formats/test_nvfp4_tensor.py -s -x`
4. Run the mxfp8 linear tests and verify the new error message is displayed for the dim1 kernel tests: `pytest test/prototype/mx_formats/test_mx_linear.py -s -x -k test_linear_eager_vs_hp`
5. Undo the change from (1), rebuild torchao, and verify all mx tests pass: `pytest test/prototype/mx_formats/ -s -x`

Reviewers:

Subscribers:

Tasks:

Tags:
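For context, the fix follows the standard optional-extension pattern: attempt the import, record availability in a flag, and gate the kernel registration on that flag. Below is a consolidated, self-contained sketch of the resulting module-level logic; the `torchao.utils` import path for `is_sm_at_least_100` is an assumption here, and the sketch uses `logger.debug` where the actual diff calls `logging.debug`.

    import logging

    from torchao.utils import is_sm_at_least_100  # import path assumed

    logger = logging.getLogger(__name__)

    mxfp8_cuda_extension_available = False
    if is_sm_at_least_100():
        try:
            # The MXFP8 CUDA kernel is only built on SM100+, and CI runners
            # are not SM100+, so a missing extension is expected unless
            # torchao was built from source on such a machine.
            from torchao.prototype import mxfp8_cuda

            mxfp8_cuda_extension_available = True
        except ImportError:
            logger.debug("Skipping import of torchao.prototype.mxfp8_cuda")

With this flag in place, only the custom-op registrations guarded by `if mxfp8_cuda_extension_available:` depend on the compiled kernel; NVFP4 code paths never import it, which is behavior (a) above.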
Parent: 1eb5902

File tree: 1 file changed (+19, −4)

torchao/prototype/mx_formats/kernels.py

Lines changed: 19 additions & 4 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional, Tuple
 
 import numpy as np
@@ -35,6 +36,8 @@
     F32_EXP_BIAS,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def get_bits(x: torch.Tensor) -> str:
     bits_per_byte = 8
@@ -1476,10 +1479,20 @@ def triton_quantize_nvfp4(
         raise AssertionError("needs torch version 2.8+ and triton")
 
 
-# MXFP8 CUDA kernel is only built on SM100+
+mxfp8_cuda_extension_available = False
 if is_sm_at_least_100():
-    from torchao.prototype import mxfp8_cuda
-
+    try:
+        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
+        # currently our CI runners are not SM100+, so the user needs to build
+        # from source.
+        # TODO(#2932): improve this
+        from torchao.prototype import mxfp8_cuda
+
+        mxfp8_cuda_extension_available = True
+    except ImportError:
+        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
     # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
     # Currently we have to use an arbitrary string because custom ops don't support enum
     # params.
@@ -1599,4 +1612,6 @@ def mxfp8_quantize_cuda(
         colwise: bool = True,
         scaling_mode: str = "floor",
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        raise NotImplementedError("needs torch version 2.8+ and sm100")
+        raise NotImplementedError(
+            "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
+        )
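To make the caller-facing behavior concrete, here is a minimal hypothetical usage sketch (illustration only, not an actual torchao callsite; the tensor shape and dtype are arbitrary):

    import torch

    from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

    # Requires a CUDA device. On a build without the extension, the call
    # below raises the new NotImplementedError pointing at
    # https://github.com/pytorch/ao/issues/2932.
    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
    try:
        out = mxfp8_quantize_cuda(x, colwise=True)
    except NotImplementedError as e:
        print(e)

Only code that actually invokes the dim1 kernel reaches this error; the NVFP4 tests pass without the extension, per step 3 of the test plan.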
