Commit a6de35a

Make the kernel fail for sm75 + bfloat16 inputs
1 parent 379bd5e · commit a6de35a

File tree

1 file changed: +4 -5 lines changed

torchao/csrc/cuda/fp6_llm/fp6_linear.cu

Lines changed: 4 additions & 5 deletions
@@ -109,12 +109,11 @@ void fpx_linear_kernel(cudaStream_t stream,
     CHECK_CUDA(cudaGetDevice(&device));
     CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
     CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-
-    if ((major < 7) || (major == 7 && minor < 5)) {
-        TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");
-    }
-
     const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
+        TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75");
+    if ((major < 7) || (major == 7 && minor < 5))
+        TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");

     if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
         // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
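Note: the new guard pairs a runtime compute-capability query with a compile-time type test. Because std::is_same<InputDataType, __nv_bfloat16>::value is a compile-time constant per template instantiation, the compiler can drop the whole bfloat16 branch when the kernel is instantiated for half inputs. Below is a minimal, self-contained sketch of the same pattern as a stand-alone program, not the torchao code itself: MY_CHECK and check_arch_support are hypothetical names standing in for TORCH_CHECK and the checks inside fpx_linear_kernel.

// Sketch: reject bfloat16 on SM75 and reject anything older than SM75,
// mirroring the guard added in this commit. Helper names are hypothetical.
#include <cuda_bf16.h>     // __nv_bfloat16
#include <cuda_fp16.h>     // half
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <type_traits>

// Stand-in for TORCH_CHECK: print a message and abort when cond is false.
#define MY_CHECK(cond, msg)                        \
    do {                                           \
        if (!(cond)) {                             \
            std::fprintf(stderr, "%s\n", msg);     \
            std::exit(EXIT_FAILURE);               \
        }                                          \
    } while (0)

template <typename InputDataType>
void check_arch_support() {
    int device = 0, major = 0, minor = 0;
    MY_CHECK(cudaGetDevice(&device) == cudaSuccess, "cudaGetDevice failed");
    MY_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device) == cudaSuccess,
             "compute capability (major) query failed");
    MY_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device) == cudaSuccess,
             "compute capability (minor) query failed");

    // Exactly SM75 (Turing): the kernel's bfloat16 path is unsupported there.
    const bool is_sm75_gpu = (major == 7) && (minor == 5);
    constexpr bool is_bf16 = std::is_same<InputDataType, __nv_bfloat16>::value;
    MY_CHECK(!(is_sm75_gpu && is_bf16),
             "Bfloat16 inputs are not supported for SM75");

    // Anything older than SM75 is rejected outright.
    MY_CHECK(!((major < 7) || (major == 7 && minor < 5)),
             "FP6LLM requires GPU with SM75 or higher!");
}

int main() {
    check_arch_support<half>();           // passes on any SM75+ GPU
    check_arch_support<__nv_bfloat16>();  // aborts on SM75, passes on SM80+
    std::puts("arch checks passed");
    return 0;
}

Built with, e.g., nvcc -o arch_check arch_check.cu, this aborts with the bfloat16 message on a Turing (SM75) device and passes on SM80 and newer, matching the behavior the commit introduces.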
