
Commit 699c9cb

ikawrakow and Iwan Kawrakow authored

Faster MoE token generation on CUDA (#248)

* This gives us ~20% TG speedup for DeepSeek on CUDA
* Slightly better
* Also do it for plain (not fused) mul_mat_id
* Guard against numerical precision issues for MLA on CUDA

---------

Co-authored-by: Iwan Kawrakow <[email protected]>

1 parent b096a5d commit 699c9cb
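For orientation, here is a rough, self-contained sketch (not the actual CUDA implementation; all types and helpers are illustrative stand-ins) of the "fast TG path" idea this commit adds to mul_mat_id: when src1 holds a single row (one generated token), the activation is prepared once (quantized to q8_1 in the real code) and one indexed mat-vec product is run per selected expert, instead of going through the generic batched mul_mat_id machinery.

// Hypothetical sketch of the single-token MoE mat-vec dispatch.
// Expert weights are kept as plain float matrices here; the CUDA code keeps
// them quantized and quantizes the activation once before the per-expert GEMVs.
#include <cstdint>
#include <vector>

struct ExpertWeights {               // one expert's [n_out x n_in] row-major matrix
    int64_t n_out, n_in;
    std::vector<float> w;
};

// dst must hold ids.size() * n_out floats: one output row per selected expert.
void moe_gemv_single_token(const std::vector<ExpertWeights> & experts,
                           const std::vector<int32_t> & ids,   // top-k expert indices
                           const std::vector<float> & x,       // single activation row
                           std::vector<float> & dst) {
    // In the CUDA version, x would be quantized once here (q8_1) and reused
    // for every selected expert.
    for (size_t k = 0; k < ids.size(); ++k) {
        const ExpertWeights & e = experts[ids[k]];
        for (int64_t r = 0; r < e.n_out; ++r) {
            float acc = 0.f;
            for (int64_t c = 0; c < e.n_in; ++c) {
                acc += e.w[r*e.n_in + c] * x[c];
            }
            dst[k*e.n_out + r] = acc;
        }
    }
}

The diff below implements the same "prepare once, then one GEMV per selected expert" pattern on the GPU via quantize_row_q8_1_cuda and ggml_cuda_op_mul_mat_vec_q_id.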

File tree

6 files changed: +487 -208 lines changed


ggml/src/ggml-cuda.cu

Lines changed: 278 additions & 32 deletions
@@ -1765,6 +1765,93 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 
+/*
+static void ggml_cuda_op_gemv_id(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src0_ids, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+    quantize_cuda_t quantize_src1) {
+
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+    GGML_ASSERT(src0_ids->ne[1] == 1);
+    GGML_ASSERT(src0_ids->ne[0] <= dst->ne[2]);
+    GGML_ASSERT(dst->ne[1] == 1);
+    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
+
+    ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context;
+    ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+    ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+
+    int device_id = ctx.device;
+    GGML_ASSERT(src0_ctx->device == device_id);
+    GGML_ASSERT(src1_ctx->device == device_id);
+    GGML_ASSERT(dst_ctx->device == device_id);
+
+    const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
+    GGML_ASSERT(!split);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t nrows1 = 1;
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne2 = dst->ne[2];
+
+    const int64_t nb2 = dst->nb[2];
+
+    // Why?
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+    const size_t src0_rs = ggml_row_size(src0->type, ne00);
+    const size_t q8_1_ts = sizeof(block_q8_1);
+    const size_t q8_1_bs = QK8_1;
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+    ggml_cuda_pool_alloc<char> src0_dd_alloc;
+    ggml_cuda_pool_alloc<float> src1_ddf_alloc;
+    ggml_cuda_pool_alloc<char> src1_ddq_alloc;
+    ggml_cuda_pool_alloc<float> dst_dd_alloc;
+
+    char * src0_dd = nullptr;
+    float * src1_ddf = (float *)src1->data;
+    char * src1_ddq = nullptr; // q8_1
+    float * dst_dd = (float *)dst->data;
+
+    bool quantization_done = false;
+
+    const bool src1_on_device = device_id == src1_ctx->device;
+    const bool dst_on_device = device_id == dst_ctx->device;
+
+    ggml_cuda_set_device(device_id);
+    cudaStream_t stream = ctx.stream(device_id, 0);
+
+    src0_dd = (char *) src0->data;
+
+    if (quantize_src1) {
+        size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+        src1_ddq = src1_ddq_alloc.alloc(ctx.pool(device_id), src_1_ddq_size);
+        quantize_src1(src1_ddf, src1_ddq, ne10, 1, 1, src1_padded_col_size, src0->type, stream);
+    }
+
+    ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, src1, src0_ids, dst,
+            (const char *)src0->data, (const float *)src1->data, src1_ddq, (float *)dst->data,
+            0, ne01, 1, src1_padded_col_size, stream);
+    CUDA_CHECK(cudaGetLastError());
+
+}
+*/
+
 static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -2090,6 +2177,52 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
+    if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
+        ggml_is_quantized(src0->type) &&
+        ggml_backend_buffer_is_cuda(src0->buffer) &&
+        ggml_backend_buffer_is_cuda(src1->buffer) &&
+        ggml_backend_buffer_is_cuda(dst->buffer) &&
+        !ggml_backend_buffer_is_cuda_split(src0->buffer) &&
+        src1->type == GGML_TYPE_F32) {
+        int device_id = ctx.device;
+        ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context;
+        ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+        if (src0_ctx->device == device_id &&
+            src1_ctx->device == device_id &&
+            dst_ctx->device == device_id) {
+            GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
+            // Fast TG path
+            const int64_t n_ids = ids->ne[0];
+            auto stream = ctx.stream(device_id, 0);
+
+            auto local_dst = *dst;
+            local_dst.ne[2] = n_ids;
+            local_dst.ne[1] = local_dst.ne[3] = 1;
+            local_dst.nb[2] = local_dst.nb[1];
+
+            auto local_src1 = *src1;
+            local_src1.nb[2] = local_src1.nb[3] = 0;
+
+            const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+            ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+            auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
+            local_src1.data = src1_quantized.alloc(src_1_ddq_size);
+            quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size,
+                    src0->type, stream);
+            CUDA_CHECK(cudaGetLastError());
+
+            local_src1.nb[1] = src_1_ddq_size;
+
+            ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst,
+                    (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
+                    0, src0->ne[1], 1, src1_padded_col_size, stream);
+            CUDA_CHECK(cudaGetLastError());
+
+            return;
+        }
+    }
+
     GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
@@ -2232,6 +2365,121 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
     const ggml_tensor * src1 = dst->src[2];
     const ggml_tensor * ids = dst->src[3];
 
+    if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
+        ggml_is_quantized(src0_1->type) &&
+        ggml_is_quantized(src0_2->type) &&
+        ggml_backend_buffer_is_cuda(src0_1->buffer) &&
+        ggml_backend_buffer_is_cuda(src0_2->buffer) &&
+        ggml_backend_buffer_is_cuda(src1->buffer) &&
+        ggml_backend_buffer_is_cuda(dst->buffer) &&
+        !ggml_backend_buffer_is_cuda_split(src0_1->buffer) &&
+        !ggml_backend_buffer_is_cuda_split(src0_2->buffer) &&
+        src1->type == GGML_TYPE_F32) {
+        int device_id = ctx.device;
+        ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context;
+        ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context;
+        ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+        if (src0_1_ctx->device == device_id &&
+            src0_2_ctx->device == device_id &&
+            src1_ctx->device == device_id &&
+            dst_ctx->device == device_id) {
+            // Fast TG path
+            const int64_t n_ids = ids->ne[0];
+            auto stream = ctx.stream(device_id, 0);
+            ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids);
+            ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids);
+
+            auto local_dst = *dst;
+            local_dst.ne[2] = n_ids;
+            local_dst.ne[1] = local_dst.ne[3] = 1;
+            local_dst.nb[1] = local_dst.nb[2] = local_dst.nb[3] = local_dst.ne[0]*sizeof(float);
+
+            auto local_src1 = *src1;
+            local_src1.nb[2] = local_src1.nb[3] = 0;
+
+            const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+            ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+            if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) {
+                GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
+                auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
+                local_src1.data = src1_quantized.alloc(src_1_ddq_size);
+                // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda.
+                // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type
+                quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size,
+                        src0_1->type, stream);
+                CUDA_CHECK(cudaGetLastError());
+
+                local_src1.nb[1] = src_1_ddq_size;
+            }
+
+            local_dst.data = dst_up_contiguous.get();
+            ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
+                    (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(),
+                    0, src0_1->ne[1], 1, src1_padded_col_size, stream);
+            CUDA_CHECK(cudaGetLastError());
+
+            local_dst.data = dst_gate_contiguous.get();
+            ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
+                    (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
+                    0, src0_2->ne[1], 1, src1_padded_col_size, stream);
+            CUDA_CHECK(cudaGetLastError());
+
+            if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
+                ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
+                !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
+                ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id &&
+                ggml_backend_buffer_is_cuda(next->buffer) &&
+                ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
+
+                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
+                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+                CUDA_CHECK(cudaGetLastError());
+
+                const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
+                GGML_ASSERT(dst->ne[0] % QK8_1 == 0);
+                auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1;
+                auto dst_ddq_size = n_ids*dst_row_size;
+                ggml_cuda_pool_alloc<char> dst_quantized(ctx.pool(), dst_ddq_size);
+                quantize_row_q8_1_cuda((const float *)dst_gate_contiguous.get(), (void *)dst_quantized.get(), dst->ne[0], n_ids, 1,
+                        dst_padded_col_size, next->src[0]->type, stream);
+                CUDA_CHECK(cudaGetLastError());
+
+                std::vector<char> ids_host(ggml_nbytes(ids));
+                const char * ids_dev = (const char *) ids->data;
+                CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+                CUDA_CHECK(cudaStreamSynchronize(stream));
+
+                local_dst.ne[2] = 1;
+
+                auto local_next = *next;
+                local_next.ne[2] = local_next.ne[1];
+                local_next.ne[1] = local_next.ne[3] = 1;
+                local_next.nb[2] = local_next.nb[1];
+
+                local_src1 = *next->src[1];
+                local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1;
+                local_src1.nb[1] = local_src1.nb[2] = local_src1.nb[3] = dst_row_size;
+
+                auto local_src0 = *next->src[0];
+                local_src0.ne[2] = local_src0.ne[3] = 1;
+
+                ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next,
+                        (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
+                        0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
+                CUDA_CHECK(cudaGetLastError());
+
+                return true;
+            } else {
+                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
+                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
+                CUDA_CHECK(cudaGetLastError());
+                return false;
+            }
+        }
+    }
+
+
     GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");
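The fast path added in the hunk above also covers the fused up/gate (and optionally down) projection for a single token. A rough sketch of the data flow, under simplifying assumptions (plain float weights instead of quantized ones, SiLU written out only as an example of the unary op taken from dst->op_params[0], and stand-in types that are not ggml's):

// Hypothetical single-token MoE FFN: per selected expert, compute up and gate
// projections from the same activation, apply the gating unary, then feed the
// result straight into that expert's down projection ("fuse_down").
#include <cmath>
#include <cstdint>
#include <vector>

static void gemv(const std::vector<float> & w, int64_t n_out, int64_t n_in,
                 const std::vector<float> & x, std::vector<float> & y) {
    for (int64_t r = 0; r < n_out; ++r) {
        float acc = 0.f;
        for (int64_t c = 0; c < n_in; ++c) acc += w[r*n_in + c]*x[c];
        y[r] = acc;
    }
}

struct Expert { std::vector<float> w_up, w_gate, w_down; int64_t d_ff, d_model; };

// out must hold ids.size() * d_model floats: one output row per selected expert.
void moe_ffn_single_token(const std::vector<Expert> & experts,
                          const std::vector<int32_t> & ids,   // top-k expert indices
                          const std::vector<float> & x,       // single token, d_model
                          std::vector<float> & out) {
    for (size_t k = 0; k < ids.size(); ++k) {
        const Expert & e = experts[ids[k]];
        std::vector<float> up(e.d_ff), gate(e.d_ff), h(e.d_ff), y(e.d_model);
        gemv(e.w_up,   e.d_ff, e.d_model, x, up);
        gemv(e.w_gate, e.d_ff, e.d_model, x, gate);
        for (int64_t i = 0; i < e.d_ff; ++i) {
            float s = gate[i]/(1.f + std::exp(-gate[i]));   // SiLU, as an example
            h[i] = s*up[i];                                  // fused mul + unary
        }
        // "fuse_down": the gated intermediate is consumed immediately by the
        // down projection instead of being written back as a separate tensor.
        gemv(e.w_down, e.d_model, e.d_ff, h, y);
        for (int64_t i = 0; i < e.d_model; ++i) out[k*e.d_model + i] = y[i];
    }
}

In the CUDA code above, the activation is quantized to q8_1 once and reused for both the up and gate mat-vecs, and when the next op is the experts' down projection (a following MUL_MAT_ID), the gated intermediate is re-quantized with quantize_row_q8_1_cuda and fed directly into ggml_cuda_op_mul_mat_vec_q_id.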
@@ -2299,49 +2547,47 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         if (fuse_down) {
             final_dst.src[1] = &dst_row;
         }
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]);
 
-                if (i02 < 0 || i02 >= n_as) continue;
-                //GGML_ASSERT(i02 >= 0 && i02 < n_as);
+            if (i02 < 0 || i02 >= n_as) continue;
+            //GGML_ASSERT(i02 >= 0 && i02 < n_as);
 
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
+            const int64_t i11 = id % ne11;
+            const int64_t i12 = 0;
 
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
+            const int64_t i1 = id;
+            const int64_t i2 = i12;
 
-                src0_1_row.data = src0_1_original + i02*nb02;
-                src0_2_row.data = src0_2_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
-                //dst_row.data = dst_original + i1*nb1 + i2*nb2;
+            src0_1_row.data = src0_1_original + i02*nb02;
+            src0_2_row.data = src0_2_original + i02*nb02;
+            src1_row.data = src1_original + i11*nb11 + i12*nb12;
+            //dst_row.data = dst_original + i1*nb1 + i2*nb2;
 
-                dst_row.data = dst_up_contiguous.get();
-                ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
-                CUDA_CHECK(cudaGetLastError());
+            dst_row.data = dst_up_contiguous.get();
+            ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+            CUDA_CHECK(cudaGetLastError());
 
-                dst_row.data = dst_gate_contiguous.get();
-                ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
-                CUDA_CHECK(cudaGetLastError());
+            dst_row.data = dst_gate_contiguous.get();
+            ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+            CUDA_CHECK(cudaGetLastError());
 
-                if (fuse_down) {
-                    ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
-                            (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
-                    CUDA_CHECK(cudaGetLastError());
+            if (fuse_down) {
+                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
+                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+                CUDA_CHECK(cudaGetLastError());
 
-                    final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
-                    final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
-                    ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
-                    CUDA_CHECK(cudaGetLastError());
+                final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
+                final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
+                ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
+                CUDA_CHECK(cudaGetLastError());
 
-                } else {
+            } else {
 
-                    ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
-                            (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
-                    CUDA_CHECK(cudaGetLastError());
+                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
+                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
+                CUDA_CHECK(cudaGetLastError());
 
-                }
             }
         }
     } else {
