@@ -17103,23 +17103,25 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
17103
17103
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
17104
17104
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
17105
17105
17106
- if (nk1 >= 256) { //4096) {
17107
- if (nq1 >= 64) {
17108
- FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
17109
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17110
- return;
17111
- }
17112
- if (nq1 >= 32) {
17113
- FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
17114
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17115
- return;
17116
- }
17117
- if (nq1 >= 16) {
17118
- FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
17119
- fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17120
- return;
17121
- }
17122
- }
17106
+ // Not sure if this actually helps.
17107
+ // So, let's reduce compilation time by commenting it out for now.
17108
+ //if (nk1 >= 256) { //4096) {
17109
+ // if (nq1 >= 64) {
17110
+ // FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
17111
+ // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17112
+ // return;
17113
+ // }
17114
+ // if (nq1 >= 32) {
17115
+ // FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
17116
+ // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17117
+ // return;
17118
+ // }
17119
+ // if (nq1 >= 16) {
17120
+ // FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
17121
+ // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
17122
+ // return;
17123
+ // }
17124
+ //}
17123
17125
if (nq1 >= 8) {
17124
17126
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
17125
17127
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
@@ -17265,13 +17267,25 @@ template <int step_k, typename KHelper, typename VHelper>
17265
17267
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
17266
17268
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
17267
17269
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
17268
- if (nq1 % 8 == 0 ) {
17270
+ if (nq1 >= 8 ) {
17269
17271
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
17270
17272
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
17271
- } else {
17273
+ }
17274
+ else if (nq1 >= 4) {
17275
+ FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
17276
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
17277
+ }
17278
+ else {
17272
17279
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
17273
17280
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
17274
17281
}
17282
+ //if (nq1 % 8 == 0) {
17283
+ // FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
17284
+ // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
17285
+ //} else {
17286
+ // FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
17287
+ // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
17288
+ //}
17275
17289
}
17276
17290
17277
17291
template <int step_k>
0 commit comments