
Commit d12f4a1

Author: Iwan Kawrakow (committed)

Improve DeepSeek batched processing speed

1 parent 5a4855e

File tree

2 files changed: +34 -20 lines


ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 33 additions & 19 deletions
@@ -17103,23 +17103,25 @@ template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
 inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
                              const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
 
-    if (nk1 >= 256) { //4096) {
-        if (nq1 >= 64) {
-            FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-            return;
-        }
-        if (nq1 >= 32) {
-            FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-            return;
-        }
-        if (nq1 >= 16) {
-            FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
-            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
-            return;
-        }
-    }
+    // Not sure if this actually helps.
+    // So, let's reduce compilation time by commenting it out for now.
+    //if (nk1 >= 256) { //4096) {
+    //    if (nq1 >= 64) {
+    //        FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
+    //        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+    //        return;
+    //    }
+    //    if (nq1 >= 32) {
+    //        FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
+    //        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+    //        return;
+    //    }
+    //    if (nq1 >= 16) {
+    //        FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
+    //        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
+    //        return;
+    //    }
+    //}
     if (nq1 >= 8) {
         FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
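A note on what this hunk changes: the helper maps a runtime batch size nq1 onto kernels whose row tile is a compile-time template parameter, so each tier is a separate FlashAttn instantiation. Commenting out the 64/32/16 tiers trades some possible large-batch performance (the author is unsure it helped at all) for less template instantiation work at compile time. Below is a minimal, self-contained sketch of the dispatch pattern; run_kernel and dispatch are hypothetical stand-ins for the real FlashAttn machinery, not code from this repository.

#include <cstdio>

// Stand-in for FlashAttn<Dk, Dv, q_step, k_step>: q_step is a compile-time
// tile size, so each tier below instantiates a separate kernel body.
template <int q_step>
void run_kernel(int nq1) {
    std::printf("%d rows handled by the %d-row kernel\n", nq1, q_step);
}

// Runtime-to-compile-time dispatch, as in iqk_flash_helper: pick the
// largest tile not exceeding the batch. Dropping the 64/32/16 tiers
// removes three instantiations per (Dk, Dv, k_step) combination.
void dispatch(int nq1) {
    if (nq1 >= 8) { run_kernel<8>(nq1); return; }
    run_kernel<1>(nq1);
}

int main() {
    dispatch(17); // 8-row kernel (any remainder is assumed handled inside)
    dispatch(3);  // falls through to the 1-row kernel
    return 0;
}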
@@ -17265,13 +17267,25 @@ template <int step_k, typename KHelper, typename VHelper>
 inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
         int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
         const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
-    if (nq1 % 8 == 0) {
+    if (nq1 >= 8) {
         FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
-    } else {
+    }
+    else if (nq1 >= 4) {
+        FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+    }
+    else {
         FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
         fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
     }
+    //if (nq1 % 8 == 0) {
+    //    FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+    //    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+    //} else {
+    //    FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+    //    fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
+    //}
 }
 
 template <int step_k>
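This hunk is the batched-processing speedup named in the commit message. Before, the 8-row DeepSeek kernel (the 576/512 template arguments appear to be the K/V head sizes of the MLA path) was selected only when nq1 was an exact multiple of 8; any other batch size, however large, dropped to the 1-row kernel. Now any batch of 8 or more rows takes the 8-row path, and a new 4-row tier covers 4 <= nq1 < 8. A self-contained sketch comparing the two selection rules follows; old_tile and new_tile are illustrative helpers, and it is assumed the real compute() handles remainder rows internally.

#include <cstdio>

// Illustrative comparison of the old and new tier selection in
// iqk_deepseek_helper. The real code instantiates
// FlashAttn<576, 512, tile, step_k> and calls compute() once.
int old_tile(int nq1) { return nq1 % 8 == 0 ? 8 : 1; }
int new_tile(int nq1) { return nq1 >= 8 ? 8 : nq1 >= 4 ? 4 : 1; }

int main() {
    const int sizes[] = {1, 4, 7, 9, 12, 511};
    for (int nq1 : sizes) {
        std::printf("nq1=%3d  old tile=%d  new tile=%d\n",
                    nq1, old_tile(nq1), new_tile(nq1));
    }
    // e.g. nq1 = 12 or 511: the old rule used the 1-row kernel because the
    // batch is not a multiple of 8; the new rule keeps the 8-row kernel.
    return 0;
}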

src/llama.cpp

Lines changed: 1 addition & 1 deletion
@@ -13896,7 +13896,7 @@ struct llm_build_context {
 
     // whether to use n_tokens as the matrix dimension during multiplication or n_head
    // n_tokens is higher during prompt processing, this allows to optimize for this case
-    bool pp_opt = n_tokens > n_head;
+    bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow related to n_head? original: n_tokens > n_head;
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
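The llama.cpp side replaces the head-count comparison with a fixed threshold, and the in-line comment leaves open whether 128 is a true constant or a proxy for n_head. If n_head is 128, as in the DeepSeek models this commit targets, the two predicates disagree only at exactly n_tokens == 128 (>= versus >). A throwaway check under that assumption:

#include <cstdio>

int main() {
    const int n_head = 128; // assumption: DeepSeek attention head count
    const int sizes[] = {64, 127, 128, 129, 512};
    for (int n_tokens : sizes) {
        bool old_opt = n_tokens > n_head;  // original predicate
        bool new_opt = n_tokens >= 128;    // this commit
        std::printf("n_tokens=%3d  old=%d  new=%d\n",
                    n_tokens, old_opt, new_opt);
    }
    return 0;
}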
