|
4 | 4 | # This source code is licensed under the license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import logging |
7 | 8 | from typing import Optional, Tuple
|
8 | 9 |
|
9 | 10 | import numpy as np
|
|
35 | 36 | F32_EXP_BIAS,
|
36 | 37 | )
|
37 | 38 |
|
| 39 | +logger = logging.getLogger(__name__) |
| 40 | + |
38 | 41 |
|
39 | 42 | def get_bits(x: torch.Tensor) -> str:
|
40 | 43 | bits_per_byte = 8
|
@@ -1476,10 +1479,20 @@ def triton_quantize_nvfp4(
|
1476 | 1479 | raise AssertionError("needs torch version 2.8+ and triton")
|
1477 | 1480 |
|
1478 | 1481 |
|
1479 |
| -# MXFP8 CUDA kernel is only built on SM100+ |
# Whether the optional MXFP8 CUDA extension could be imported. The kernel is
# only built on SM100+ hardware, and (currently) only when torchao is built
# from source on such a machine, so this import is allowed to fail softly.
mxfp8_cuda_extension_available = False
if is_sm_at_least_100():
    try:
        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
        # currently our CI runners are not SM100+, so the user needs to build
        # from source.
        # TODO(#2932): improve this
        from torchao.prototype import mxfp8_cuda

        mxfp8_cuda_extension_available = True
    except ImportError:
        # Use the module-level `logger` (defined as logging.getLogger(__name__)
        # above) rather than the root logger via `logging.debug`, so the
        # message respects this module's logging configuration.
        logger.debug("Skipping import of torchao.prototype.mxfp8_cuda")

| 1495 | +if mxfp8_cuda_extension_available: |
1483 | 1496 | # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
|
1484 | 1497 | # Currently we have to use an arbitrary string because custom ops don't support enum
|
1485 | 1498 | # params.
|
@@ -1599,4 +1612,6 @@ def mxfp8_quantize_cuda(
|
1599 | 1612 | colwise: bool = True,
|
1600 | 1613 | scaling_mode: str = "floor",
|
1601 | 1614 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
1602 |
| - raise NotImplementedError("needs torch version 2.8+ and sm100") |
| 1615 | + raise NotImplementedError( |
| 1616 | + "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details." |
| 1617 | + ) |
0 commit comments