Add missing KV cache functionality #334
Conversation
Thanks for the great work. I also noticed that the 4-bit KV cache can cause some performance degradation, especially for thinking models. In earlier tests of my own implementation, I found that the memory savings only become noticeable once generation reaches around 2,000 to 4,000 tokens. I didn't track speed, though; I'll try to test your implementation once I get a chance.
Libraries/MLXLMCommon/KVCache.swift
Outdated
/// - cache: The model cache state
/// - metadata: Optional metadata to save along with cache state
public func savePromptCache(
    fileName: String,
I wonder if this should be URL? e.g. if you are constructing a path to the caches directory this would normally be a url.
Yes, that makes sense.
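For concreteness, here is a minimal sketch of the URL-based shape discussed here, assuming the MLXLMCommon KVCache type; the parameter list and the metadata parameter are illustrative rather than the final API:

```swift
import Foundation
import MLXLMCommon

// Sketch only: accept a URL instead of a fileName string, so callers can
// build paths into, e.g., the user's caches directory with FileManager.
public func savePromptCache(
    url: URL,                          // destination file for the cache state
    cache: [KVCache],                  // per-layer KV caches to persist
    metadata: [String: String] = [:]   // optional metadata saved alongside
) throws {
    // flatten the cache state and write it, e.g. via MLX's safetensors support
}
```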
Libraries/MLXLMCommon/KVCache.swift
Outdated
flattenedData["__metadata_user_value_\(i)"] = MLXArray(valueBytes.map { Int32($0) }) | ||
} | ||
} | ||
flattenedData["__metadata_user_count"] = MLXArray([metadata.count]) |
Since the metadata comes back as [String: String] below, why not use public func saveToData(arrays: [String: MLXArray], metadata: [String: String] = [:])? It supports a metadata dictionary directly.
I think I've improved this in the commit that I'll add in a moment. Please check it, since I'm not familiar with the usage.
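For reference, a sketch of that suggestion, assuming MLX Swift's safetensors save(arrays:metadata:url:) overload accepts the string dictionary directly:

```swift
import Foundation
import MLX

// Sketch only: pass the user metadata straight through rather than encoding
// each value as an MLXArray of bytes plus a __metadata_user_count entry.
func savePromptCache(
    url: URL, arrays: [String: MLXArray], metadata: [String: String] = [:]
) throws {
    try save(arrays: arrays, metadata: metadata, url: url)
}
```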
Libraries/MLXLMCommon/KVCache.swift
Outdated
/// - Returns: The prompt cache and optionally the metadata
public func loadPromptCache(
    fileName: String,
    returnMetadata: Bool = false
Why would you not want to return this? The return value always includes it (though it might be empty).
This mirrored the behavior in Python, but I think you're right that it makes sense to consistently return it.
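A sketch of the agreed shape; loadArraysAndMetadata is assumed here to be a safetensors loader that returns both arrays and string metadata, and rebuildCache is a hypothetical helper:

```swift
import Foundation
import MLX
import MLXLMCommon

// Sketch only: always return the metadata instead of gating it behind a
// returnMetadata flag; callers that don't need it can ignore that element.
func loadPromptCache(url: URL) throws -> ([KVCache], [String: String]) {
    let (arrays, metadata) = try loadArraysAndMetadata(url: url)
    return (try rebuildCache(from: arrays), metadata)
}

// Hypothetical helper: reconstruct the per-layer caches from the flat arrays.
func rebuildCache(from arrays: [String: MLXArray]) throws -> [KVCache] {
    fatalError("sketch only")
}
```

Callers that do not need the metadata can simply write let (cache, _) = try loadPromptCache(url: url).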
Libraries/MLXLMCommon/KVCache.swift
Outdated
/// - cache: Array of KV caches to potentially quantize
/// - kvBits: Number of bits for quantization (nil = no quantization)
/// - kvGroupSize: Group size for quantization
/// - quantizedKVStart: Step to begin quantizing
I am not familiar with the typical use of this -- the step is the token offset? So if you have a long prompt you will switch to quantized before you start evaluating the response -- that is the intent?
This mirrors maybe_quantize_kv_cache in mlx-lm. The comment was misleading, and in fact quantizedKVStart refers to the token count.
    cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits)
}
// Note: RotatingKVCache.toQuantized() is not implemented yet like in Python
// When implemented, add: else if let rotatingCache = cache[i] as? RotatingKVCache { ... }
Since this may support rotating caches in the future I wonder if the name of the function should be more generic? Less about quantization -- maybe something about a step did complete?
Is this something someone might want to customize? Should this be a protocol or closure passed in to the iterator? I guess we can decide that later if needed.
This mirrors the implementation in mlx-lm.
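For readers following the thread, a hedged sketch of what that mlx-lm logic looks like in Swift, using the names visible in this diff (KVCacheSimple, toQuantized(groupSize:bits:)) and assuming the KVCache protocol exposes offset:

```swift
import MLXLMCommon

// Sketch of the maybe_quantize_kv_cache pattern: once the number of processed
// tokens passes quantizedKVStart, convert any still-unquantized simple caches
// in place. Rotating caches are skipped until toQuantized is implemented.
func maybeQuantizeKVCache(
    cache: inout [KVCache],
    kvBits: Int?,            // nil disables quantization
    kvGroupSize: Int,
    quantizedKVStart: Int    // token count at which quantization begins
) {
    guard let kvBits, !cache.isEmpty, cache[0].offset > quantizedKVStart else {
        return
    }
    for i in 0 ..< cache.count {
        if let simpleCache = cache[i] as? KVCacheSimple {
            cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits)
        }
    }
}
```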
Libraries/MLXLMCommon/KVCache.swift
Outdated
private var keep: Int
private var keys: MLXArray?
private var values: MLXArray?
// TODO: `offset` from the Python implementation is not implemented here. Do we need it?
meaning offset is always 0 from the base?
but I see offset += S below, so maybe I am misunderstanding?
That's right. I was confused.
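To make the resolved point concrete: offset counts every token the cache has ever seen, while the rotating buffer retains at most maxSize entries. A toy illustration (not the library code):

```swift
// Toy stand-in for the bookkeeping discussed above: offset keeps growing
// (mirroring Python's `self.offset += S`) even though the buffer itself is
// trimmed back to maxSize after each update.
final class ToyRotatingCache {
    private let maxSize: Int
    private var buffer: [Int] = []   // stand-in for the cached K/V entries
    private(set) var offset = 0      // total tokens processed so far

    init(maxSize: Int) { self.maxSize = maxSize }

    func update(tokens: [Int]) {
        buffer.append(contentsOf: tokens)
        if buffer.count > maxSize {
            buffer.removeFirst(buffer.count - maxSize)  // rotate out oldest
        }
        offset += tokens.count   // keeps growing past maxSize
    }
}
```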
Looks good, ran integration tests and all was well. Thank you!
@davidkoski, I think we still need to update the other models so that they can use the cache routing.
I've separated the KV cache part out of my Gemma 3 PR. I'm testing the quantized KV cache on Qwen 3. If we go with this approach, we'll need to update all the models to use attentionWithCacheUpdate as Qwen 3 does here. The Python API is more elegant because it allows more dynamic typing; in Swift we need this wrapper function due to the requirements of the KVCache protocol (see the sketch after the example commands below).

You can test inference with and without the quantized KV cache with these arguments in llm-tool:
--model mlx-community/Qwen3-1.7B-4bit --prompt "Explain quantum computing in simple terms" --max-tokens 100 --kv-bits 4
--model mlx-community/Qwen3-1.7B-4bit --prompt "Explain quantum computing in simple terms" --max-tokens 100
For short sequences like this, using the quantized KV cache is actually slower, but it should be more efficient for much longer sequences.
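As a sketch of the call-site change described above (parameter labels follow the Qwen 3 usage in this PR; the mask parameter's type is an assumption here):

```swift
import MLX
import MLXLMCommon

// Sketch: the attention layer hands the projected queries/keys/values and the
// cache to attentionWithCacheUpdate, which updates the cache and dispatches to
// the quantized or regular attention path based on the cache's concrete type.
func attend(
    queries: MLXArray,   // (B, nHeads, L, headDim) after projection and RoPE
    keys: MLXArray,
    values: MLXArray,
    cache: KVCache?,
    scale: Float,
    mask: MLXArray? = nil
) -> MLXArray {
    attentionWithCacheUpdate(
        queries: queries, keys: keys, values: values,
        cache: cache, scale: scale, mask: mask)
}
```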
Some preliminary test results: Qwen3-1.7B-4bit using 4-bit quantization for the KV cache results in the model getting stuck in repetitive loops. The same model with 8-bit quantization for the KV cache works well, but even at 1200 tokens it's still about 8% slower than when using the non-quantized KV cache. Maybe someone who's interested in using a quantized KV cache can do some more testing on other models with much longer sequence lengths (cc @mzbac).