File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed
torchao/csrc/cuda/fp6_llm Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -109,12 +109,11 @@ void fpx_linear_kernel(cudaStream_t stream,
109
109
CHECK_CUDA (cudaGetDevice (&device));
110
110
CHECK_CUDA (cudaDeviceGetAttribute (&major, cudaDevAttrComputeCapabilityMajor, device));
111
111
CHECK_CUDA (cudaDeviceGetAttribute (&minor, cudaDevAttrComputeCapabilityMinor, device));
112
-
113
- if ((major < 7 ) || (major == 7 && minor < 5 )) {
114
- TORCH_CHECK (false , " FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n " );
115
- }
116
-
117
112
const bool is_sm75_gpu = (major == 7 ) && (minor == 5 );
113
+ if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
114
+ TORCH_CHECK (false , " Bfloat16 inputs are not supported for SM75" );
115
+ if ((major < 7 ) || (major == 7 && minor < 5 ))
116
+ TORCH_CHECK (false , " FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n " );
118
117
119
118
if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0 )) {
120
119
// For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
You can’t perform that action at this time.
0 commit comments