Option to use MLA without a transposed cache #235
The `-mla` command line option turns into an int from a bool:

* `mla = 0`: use standard attention
* `mla = 1`: use MLA with transposed cache
* `mla > 1`: use MLA without transposed cache
Hey, thank you for your work on this. Trying to run with `-mla 2`, but still getting an 8900 MB allocation per card, I believe for the compute buffer. I'm not sure if this is correct, or if I'm doing something wrong with my run commands (I'm aware the layers are poorly balanced atm); just wondering if this is as expected. Command:
Log:
Would really appreciate your help to see if I'm doing something wrong. Thank you!
Was able to run this with 20K ctx, but not sure if this amount of compute buffer is still correct:
So, when I wrote the PR description I had forgotten that it is not yet possible to transpose a quantized cache, which would be needed if we wanted to use

I haven't looked in detail into the compute buffers on CUDA. I wouldn't have expected 5.7 GiB per GPU; this seems way too much. But I also don't have access to a multi-GPU box, so have never played with that. It looks like each GPU is allocating the same compute buffer as if the computation was running on a single GPU.

Oh, I'm working on transposed matrix multiplications. Hope to be done in a day or two. It will then be possible to fully quantize the cache with
Incredible, that makes sense. The cache using fp16 isn't a huge problem, to be honest. Also, yes, the 15 GPU build (trying to find a 16th for TP!) has been a lot of pain, so to see the speed increase on this, and longer context, is really promising. So thank you for all of your hard work. For these compute buffers, is there anything I can do to reduce them to the expected amount?
@davidsyoung Have you tried using
I realised that actually, and briefly tried it, but didn't try a long prefill, and now I'm back to trying to figure out this compute buffer first. As the model is so large, it takes quite some time to load it! I will try the fmoe and report back!
I need to look into this. Have you tried
So I tried to change the following: before:
to
and the KV cache size seems right at
However, the compute buffer is now trying to allocate

I believe this is

I will try
Yes, the compute buffer size is proportional to the micro batch (ub) size. Typically performance first increases with increasing
That makes sense, will try playing with micro batch size and batch size. Also waiting to see if split mode row makes a difference. It seems it might more evenly split the layers across the cards, which would be useful, as sometimes the layers don't divide evenly across the GPUs and as a result limit the KV cache (previously at least), and some cards would have a few GB remaining without much you could do, as other GPUs were right at the limit. If this does work, and with the new MLA implementation, do you think there's a sweet spot for the type of quant to use for R1 in terms of quality etc.?
Just tried with DeepSeek-Lite. For a context of 32k tokens the CUDA compute buffer size is 1172 MiB with default batch/u-batch size. If I use
Is that behaving as expected for you when you see that? I can't tell if I should see similar amounts, or if what I'm seeing is correct for the model size.
Do you quantize the model yourself or do you download a quantized model from somewhere? For DeepSeek it seems it is important to use more bits for the attention tensors and the shared experts. As most of the size is in the MoE experts, this does not lead to a very significant increase in model size. After that you go with the highest bpw for the MoE experts that you can fit into VRAM (after deducting KV cache and compute buffers). But all of this is just guesses, as I have never tried DeepSeekV3/R1 myself.
Makes sense. Thank you. I am currently using https://huggingface.co/gghfez/DeepSeek-R1-11446-Q2_K, but now that it seems I'll be able to unlock a good bit of VRAM with your implementation (thank you), I may venture into trying to quantize the model myself with an IQ3_XXS. It really depends on finding a sweet spot with this compute buffer! Thank you for all of your help/work, it's massively appreciated.
Doing some testing with different batch sizes, micro-batch sizes and context. Test 1: At
I see pretty good performance overall. I have seen ~140 prefill before, but I believe that was without MLA.
Test 2:
Test 3:
While the KV cache at max context of 163k is reasonable

The compute buffer goes pretty insane per GPU:

So I'm not too sure what's up with the compute buffer. Maybe this is just the size of it given the size of the model. But allocating 42.8 GB per GPU, across 15 GPUs, would be 642 GB of VRAM just for the compute buffer. Definitely seems a magnitude out, but I'm also really not sure what I'm talking about!
The model size might not go up significantly, but the performance does noticeably go down if you do that strategy, as those weights are always used, unlike the expert weights. This may not matter as much with them being on CUDA, but another user on llama.cpp who was offloading those to CUDA still reported a performance hit. For me, on CPU-only inference, IQ4_K_R4 V2 is slower than V1, with 2.63 t/s for V2 vs 3.22 t/s for V1. Here's a table of early perplexity values I've collected for various quants of Deepseek.
**For these quants in the format A/B/C (also imatrix is Bartowski imatrix for experts only)
For my V1/V2/V3 I employ a strategy similar to the one above but FAR less aggressive, slightly increasing the size of the model, but IMO the performance difference was not worth it (that might change with hybrid/full offload). All tensors for mine were imatrixed with the mradermacher imatrix except for the new split tensor. Also for reference, here are some compute buffer sizes I've seen (with an earlier build and default ubatch size): n_ctx = 128000
I may have to start experimenting with quants myself, this is really useful. For the compute buffers, would you happen to know what batch/micro batch sizes were set to? I'm getting a total of 67GB for 32k context. It would be nice if I could claw some back somehow…
Let me know if you do, as you can tell I'm collecting info on that. Also if you do want to easily benchmark and plot performance across your full context window for both TG and PP you can use the sweep-bench example I recently ported over to ik_llama.cpp
n_batch = 2048
I agree, that would be nice. I'm also curious as to why the split-mode row doesn't work. I've never run a setup with it, but I've seen others report it giving nice performance gains. For now I'm still stuck on CPU only. I did work a bit on porting the RPC updates to support it (and other models, and cache quantization for models that were already supported) so that I can run hybrid CPU+GPU over RPC, but I'm running into issues that I don't really understand.
So, based on this discussion, reducing compute buffer size is by far more important than reducing KV cache size. I'll see if I can do something about that.
Don't think about the fact that there are 15 GPUs. With per-layer model split, each GPU needs to compute a full layer, so each GPU needs the exact same compute buffer as if the entire model was running on it (i.e., if you had a single GPU with enough VRAM to fit the entire model, the compute buffer would still be 42.8 GB and not 15 x 42.8 GB). Why does the compute buffer become 42.8 GB for 160k context? There is the

If you use flash attention (
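The sentence above is cut off, but the reported number is consistent with the fp32 K*Q attention matrix (shape n_ctx x n_ubatch x n_head) dominating the compute buffer when flash attention is off. Below is a rough sketch of that arithmetic under my own assumptions (128 attention heads for DeepSeekV3/R1, default u-batch of 512); it is not taken from the comment itself.

```cpp
// Rough check only: size of an fp32 K*Q attention matrix of shape
// n_ctx x n_ubatch x n_head, assumed here to dominate the compute buffer
// when flash attention is not used.
#include <cstdio>

int main() {
    const double n_ctx    = 163840; // full DeepSeekV3/R1 context
    const double n_ubatch = 512;    // default u-batch size
    const double n_head   = 128;    // DeepSeekV3/R1 attention heads
    const double bytes    = n_ctx * n_ubatch * n_head * sizeof(float);
    // Prints ~42.9 GB (40.0 GiB), in line with the 42.8 GB reported above.
    printf("K*Q: %.1f GB (%.1f GiB)\n", bytes/1e9, bytes/(1024.0*1024*1024));
    // Halving -ub to 256 halves this term to ~21.5 GB, and so on.
    return 0;
}
```

If this reading is right, it also explains why the buffer scales with the micro-batch size discussed earlier in the thread.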
The model was only trained to support a context length of 128k (131,072). The Hugging Face and GitHub pages both list 128k as the context length as well, so I'm not sure why the original Deepseek V3/R1 config.json (that the GGUF pulls its metadata from) says 160k (163,840).
That makes way more sense. Thank you. Would split mode row, if it worked, be a solution/help with this? I tried to look into the assert that came up, but wasn't able to understand how to resolve it myself. I have, however, tested
The performance mirrors the previous share above with tokens (with

So there's a relatively decent drop that does have an impact on usability, but it does unlock ~19k new max tokens. Would there be any other optimisation that I could use that would improve the prefill time? Increasing pipeline parallelism, or anything like that? I don't understand that well enough to know myself. It doesn't seem to be affected by batch size either.
Use

If you change
to
in
I'm attempting to run llama-bench but it's trying to allocate the full model to device zero, even though I've set tensor splits.
Well, not sure why

But I think you will like PR #237 very much. Simply add
to your command line, and the compute buffers should be no more than 3 GiB even for a context of 163k tokens!
Holy shit. Will report back!
The `-mla` (or `--mla-use`) command line option turns from previously a boolean value to an integer:

* `mla = 0`: use standard attention
* `mla = 1`: use MLA with transposed cache - this is the existing MLA implementation
* `mla = 2`: use MLA without transposed cache - this is the option added by this PR

Why do we need this? Apparently many people are interested in using the maximum context length of long context models. For DeepSeekV3/R1, the rage of the day, it is 163k tokens. This requires a lot of RAM/VRAM. Let's take a look:
* Standard attention: KV cache size is `n_layer * (3072 * sizeof(K cache element) + 2048 * sizeof(V cache element))`. For DeepSeekV3/R1 this works out to 610 kB per token when using `fp16` cache. For `Q8_0` K and V cache it is 324 kB per token, but this requires FA, so CPU-only inference (CUDA does not support FA with different K and V head sizes as found in the DeepSeek models). So, for GPU or mixed CPU/GPU inference the best one can do is `Q8_0` for the K cache and `f16` for the V cache, so 438.4 kB per token.
* MLA with transposed cache (`mla = 1`): KV cache size is `n_layer * (576 * sizeof(K cache element) + 512 * sizeof(V cache element))`. For DeepSeekV3/R1 this works out to 129.6 kB per token for `fp16` cache. When using MLA the V cache is transposed, so quantization is not possible at all, so the best one can do is `Q8_0` for the K cache and `fp16` for the V cache. This results in 97.5 kB per token.
* MLA without transposed cache (`mla = 2`): cache size is `n_layer * 576 * sizeof(K cache element)`, so 68.6 kB per token with `fp16` cache and 36.5 kB per token with `Q8_0` cache.

I.e., for GPU-only or hybrid GPU/CPU inference, where VRAM is the limiting factor (unless one keeps the cache on the host and copies it to the GPU as needed, but this would make performance much lower), the new option added by this PR uses 12X less KV cache memory than standard attention and 2.7X less than the existing MLA implementation. For a context of 163k tokens the memory required will be 5.67 GiB.
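A quick numerical check of the per-token figures above (just a sketch, assuming n_layer = 61 for DeepSeekV3/R1 and the usual 34 bytes per 32-element Q8_0 block; this is not code from the PR):

```cpp
// Recompute the per-token KV cache sizes quoted above for DeepSeekV3/R1.
#include <cstdio>

int main() {
    const double n_layer = 61;        // DeepSeekV3/R1 layer count (assumption stated above)
    const double f16     = 2.0;       // bytes per f16 element
    const double q8_0    = 34.0/32.0; // effective bytes per Q8_0 element

    const double std_f16  = n_layer*(3072*f16 + 2048*f16); // standard attention, f16 K+V
    const double mla1_f16 = n_layer*( 576*f16 +  512*f16); // mla = 1, f16 K+V
    const double mla2_f16 = n_layer*  576*f16;             // mla = 2, f16 K
    const double mla2_q8  = n_layer*  576*q8_0;            // mla = 2, Q8_0 K

    printf("standard, f16 : %6.1f kB/token\n", std_f16 /1024); // ~610.0
    printf("mla = 1,  f16 : %6.1f kB/token\n", mla1_f16/1024); // ~129.6
    printf("mla = 2,  f16 : %6.1f kB/token\n", mla2_f16/1024); // ~68.6
    printf("mla = 2,  Q8_0: %6.1f kB/token\n", mla2_q8 /1024); // ~36.5
    // Whole 163840-token context with mla = 2 and a Q8_0 cache:
    printf("mla = 2, Q8_0, 163840 tokens: %.2f GiB\n",
           mla2_q8*163840/(1024.0*1024*1024)); // ~5.7 GiB, in line with the figure above
    return 0;
}
```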
The downside of this is that one has to transpose the K cache during inference (`ggml`, despite representing itself as a general purpose ML library, lacks the ability to perform transposed matrix multiplications, and I haven't come around to adding this ability to my fork). This adds an additional computation and requires an extra compute buffer (to hold the contiguous transposed copy of the entire K cache for one layer). The size of this extra buffer can be computed as `n_token * 512 * sizeof(float) = 318 MiB` for 163k tokens, so this should not be a serious limitation. But the additional operation that copies the transposed K cache into contiguous memory may result in a significant performance penalty, so let's look at that. As I don't have the ability to run DeepSeekV3/R1, I'm using DeepSeek-Lite for the performance comparisons below. DeepSeek-Lite has the same architecture as DeepSeekV3/R1 with fewer parameters (16B, MoE, 64 experts, 6 used experts, exact same attention tensor sizes as DeepSeekV3/R1).
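The same kind of quick check for the extra per-layer buffer (a sketch of the formula quoted above; the small difference from 318 MiB presumably comes down to the exact token count used):

```cpp
// Size of the contiguous transposed copy of one layer's K cache: n_token * 512 floats.
#include <cstdio>

int main() {
    const double n_token = 163840;
    const double bytes   = n_token * 512 * sizeof(float);
    printf("transposed-copy buffer: %.0f MiB\n", bytes/(1024.0*1024)); // ~320 MiB
    return 0;
}
```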
Note: at this point `ggml` does not support transposing quantized data, so for `mla = 2` the K cache must be `fp16` or `bf16`. Hence, the above analysis for quantized cache with `mla = 2` will only apply once I have come around to implementing transposing of a quantized cache.

Hybrid GPU/CPU inference
The GPU is RTX-4080, the CPU is Ryzen-7950X. Experts are kept on the CPU, all other tensors are offloaded to the GPU.

I.e., for prompt processing (a.k.a. "prefill") MLA is very slightly slower than standard attention, but there is no real difference between `mla = 1` and the `mla = 2` added by this PR.

For token generation (TG) I use the `-gp` option in `llama-bench` to evaluate TG performance as a function of the number of tokens in the KV cache. Here are the results:

I.e., for short contexts `mla = 2` is about on par with `mla = 1`. As the context grows it becomes slower due to the added cost of transposing the K cache, but it is still better than standard attention (`mla = 0`) at 8k tokens.

CPU only inference
I.e., when running only on the CPU, MLA is significantly slower than standard attention for prompt processing, but there is no real difference between `mla = 1` and `mla = 2`.

Here `mla = 2` is much slower than `mla = 1` for long contexts, and about on par with standard attention (`mla = 0`). Looking at the code in `ggml_compute_forward_dup_bytes`, which gets invoked to copy the transposed K cache data to contiguous memory, it is pretty much as inefficient as it gets. But I leave this for a follow-up PR.
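For what it's worth, the copy in question is a strided-to-contiguous transpose, and the usual remedy is loop tiling so that both reads and writes stay cache-friendly. A generic, hypothetical sketch of that idea (plain C++, not ggml code, and not necessarily what a follow-up PR would do):

```cpp
// Illustrative only: a cache-blocked copy of a transposed row-major view into
// contiguous memory. This is NOT ggml's code; it just shows the kind of loop
// tiling that a follow-up optimization of the transposed-copy step could use.
#include <cstddef>
#include <vector>

// Copy src (n_rows x n_cols, row stride = row_stride elements) transposed into
// dst (n_cols x n_rows, contiguous), processing BLK x BLK tiles so that reads
// and writes each touch only a few cache lines at a time.
static void transpose_copy_blocked(const float * src, float * dst,
                                   size_t n_rows, size_t n_cols, size_t row_stride) {
    constexpr size_t BLK = 32; // tile size; tune for the target cache
    for (size_t i0 = 0; i0 < n_rows; i0 += BLK) {
        for (size_t j0 = 0; j0 < n_cols; j0 += BLK) {
            const size_t i1 = i0 + BLK < n_rows ? i0 + BLK : n_rows;
            const size_t j1 = j0 + BLK < n_cols ? j0 + BLK : n_cols;
            for (size_t i = i0; i < i1; ++i) {
                for (size_t j = j0; j < j1; ++j) {
                    dst[j*n_rows + i] = src[i*row_stride + j];
                }
            }
        }
    }
}

int main() {
    // Toy sizes standing in for an (n_token x 512) per-layer cache.
    const size_t n_rows = 4096, n_cols = 512;
    std::vector<float> src(n_rows*n_cols, 1.0f), dst(n_cols*n_rows);
    transpose_copy_blocked(src.data(), dst.data(), n_rows, n_cols, n_cols);
    return dst[0] == 1.0f ? 0 : 1;
}
```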