
Conversation

@DePasqualeOrg (Contributor) commented Jun 13, 2025

I've separated the KV cache part out of my Gemma 3 PR. I'm testing the quantized KV cache on Qwen 3. If we go with this approach, we'll need to update all the models to use attentionWithCacheUpdate as Qwen 3 does here.

The Python API is more elegant because it allows more dynamic typing. In Swift we need this wrapper function due to the requirements of the KVCache protocol.
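Roughly, the wrapper dispatches on the concrete cache type. A minimal sketch (the names QuantizedKVCache, updateQuantized, and quantizedScaledDotProductAttention are assumptions here, not necessarily the exact merged API):

```swift
import MLX
import MLXFast

// Sketch: route attention through the quantized or standard path depending
// on the concrete cache type. Assumed names, not the exact merged API.
func attentionWithCacheUpdate(
    queries: MLXArray,
    keys: MLXArray,
    values: MLXArray,
    cache: KVCache?,
    scale: Float,
    mask: MLXArray? = nil
) -> MLXArray {
    if let quantized = cache as? QuantizedKVCache {
        // Quantized path: keys/values stay in their packed representation.
        let (quantizedKeys, quantizedValues) = quantized.updateQuantized(
            keys: keys, values: values)
        return quantizedScaledDotProductAttention(
            queries: queries,
            quantizedKeys: quantizedKeys,
            quantizedValues: quantizedValues,
            scale: scale,
            mask: mask,
            groupSize: quantized.groupSize,
            bits: quantized.bits)
    } else if let cache {
        // Standard path: the cache concatenates and returns full-precision tensors.
        let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
        return MLXFast.scaledDotProductAttention(
            queries: queries, keys: cachedKeys, values: cachedValues,
            scale: scale, mask: mask)
    } else {
        return MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask)
    }
}
```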

You can test inference with and without the quantized KV cache with these arguments in llm-tool:

--model mlx-community/Qwen3-1.7B-4bit --prompt "Explain quantum computing in simple terms" --max-tokens 100 --kv-bits 4
--model mlx-community/Qwen3-1.7B-4bit --prompt "Explain quantum computing in simple terms" --max-tokens 100

For short sequences like this, using the quantized KV cache is actually slower, but it should be more efficient for much longer sequences.


Some preliminary test results: with Qwen3-1.7B-4bit, using 4-bit quantization for the KV cache gets the model stuck in repetitive loops. The same model with 8-bit KV cache quantization works well, but even at 1200 tokens it's still about 8% slower than the non-quantized KV cache. Maybe someone who's interested in using a quantized KV cache can do some more testing on other models with much longer sequence lengths (cc @mzbac).

@mzbac (Contributor) commented Jun 13, 2025

Thanks for the great work. I also noticed that the 4-bit KV cache can cause some performance degradation, especially for thinking models. In earlier tests of my implementation, I found that memory savings become noticeable during token generation at around 2,000 to 4,000 tokens. However, I didn't track the speed; I'll try to test your implementation once I get a chance.

/// - cache: The model cache state
/// - metadata: Optional metadata to save along with cache state
public func savePromptCache(
fileName: String,
Collaborator:

I wonder if this should be a URL? E.g., if you are constructing a path to the caches directory, this would normally be a URL.
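Something like this, for instance (sketch; the parameter names are assumptions):

```swift
import Foundation

// Sketch: take a URL rather than a String file name (assumed signature).
public func savePromptCache(
    url: URL,
    cache: [KVCache],
    metadata: [String: String] = [:]
) throws {
    // ... same implementation, using url instead of a string path ...
}

// A caller building a path into the caches directory can then do:
let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask)[0]
let cacheURL = cachesDir.appendingPathComponent("prompt-cache.safetensors")
```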

Contributor Author:

Yes, that makes sense.

flattenedData["__metadata_user_value_\(i)"] = MLXArray(valueBytes.map { Int32($0) })
}
}
flattenedData["__metadata_user_count"] = MLXArray([metadata.count])
Collaborator:

Since the metadata comes back as [String: String] below, why not use public func saveToData(arrays: [String: MLXArray], metadata: [String: String] = [:])? It supports a metadata dictionary directly.
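For instance (sketch; flattenedCacheArrays is a hypothetical placeholder for the tensors being saved):

```swift
// Sketch: pass user metadata through directly instead of encoding each
// entry into specially named MLXArrays. Uses the saveToData function
// mentioned above, per its stated signature.
let arrays: [String: MLXArray] = flattenedCacheArrays  // hypothetical: the KV cache tensors
let metadata: [String: String] = ["model": "mlx-community/Qwen3-1.7B-4bit"]
let data = try saveToData(arrays: arrays, metadata: metadata)
```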

Contributor Author:

I think I've improved this in the commit that I'll add in a moment. Please check it, since I'm not familiar with the usage.

/// - Returns: The prompt cache and optionally the metadata
public func loadPromptCache(
fileName: String,
returnMetadata: Bool = false
Collaborator:

Why would you not want to return this? The return value always includes it (though it might be empty).

Contributor Author:

This mirrored the behavior in Python, but I think you're right that it makes sense to consistently return it.
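So the signature could become something like this (sketch; names assumed):

```swift
import Foundation

// Sketch: always return metadata alongside the cache; callers can ignore it.
public func loadPromptCache(url: URL) throws -> (cache: [KVCache], metadata: [String: String]) {
    // ... load arrays, reconstruct the caches, read the metadata ...
    let caches: [KVCache] = []            // placeholder for the reconstructed caches
    let metadata: [String: String] = [:]  // may be empty, but never absent
    return (caches, metadata)
}
```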

/// - cache: Array of KV caches to potentially quantize
/// - kvBits: Number of bits for quantization (nil = no quantization)
/// - kvGroupSize: Group size for quantization
/// - quantizedKVStart: Step to begin quantizing
Collaborator:

I am not familiar with the typical use of this -- the step is the token offset? So if you have a long prompt you will switch to quantized before you start evaluating the response -- that is the intent?

Contributor Author (@DePasqualeOrg, Jun 13, 2025):

This mirrors maybe_quantize_kv_cache in mlx-lm. The comment was misleading, and in fact quantizedKVStart refers to the token count.
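So the check is against the number of cached tokens. A sketch of the mlx-lm logic in Swift (names follow the snippet below; KVCacheSimple and the offset property are assumptions):

```swift
// Sketch of mlx-lm's maybe_quantize_kv_cache: once a cache holds more than
// quantizedKVStart tokens, convert it in place to a quantized cache.
func maybeQuantizeKVCache(
    cache: inout [KVCache],
    kvBits: Int?,
    kvGroupSize: Int,
    quantizedKVStart: Int
) {
    guard let kvBits else { return }  // nil = no quantization
    for i in 0 ..< cache.count {
        if let simpleCache = cache[i] as? KVCacheSimple,
            simpleCache.offset > quantizedKVStart
        {
            cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits)
        }
    }
}
```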

cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits)
}
// Note: RotatingKVCache.toQuantized() is not implemented yet (the Python version has it)
// When implemented, add: else if let rotatingCache = cache[i] as? RotatingKVCache { ... }
Collaborator:

Since this may support rotating caches in the future, I wonder if the name of the function should be more generic? Less about quantization -- maybe something about a step having completed?

Is this something someone might want to customize? Should this be a protocol or closure passed in to the iterator? I guess we can decide that later if needed.

Contributor Author:

This mirrors the implementation in mlx-lm.

private var keep: Int
private var keys: MLXArray?
private var values: MLXArray?
// TODO: `offset` from the Python implementation is not implemented here. Do we need it?
Collaborator:

meaning offset is always 0 from the base?

Collaborator:

but I see offset += S below so maybe I am misunderstanding?

Contributor Author:

That's right. I was confused.
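To spell out the bookkeeping (simplified sketch, [B, H, S, D] layout assumed):

```swift
// Simplified sketch: the rotating cache keeps at most maxSize tokens in its
// buffer, but offset keeps counting every token it has ever seen.
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
    let s = keys.dim(2)  // tokens added in this step
    // ... append, rotating/trimming so the buffer never exceeds maxSize ...
    offset += s          // so offset is not always 0; it grows past maxSize
    return (self.keys!, self.values!)
}
```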

@DePasqualeOrg (Contributor Author) commented:

maxSize in RotatingKVCache is optional in the Python implementation, but I'm not sure this makes sense, since a rotating KV cache should always have a maximum size. I've made it required in the Swift implementation. Is there a good reason why this is optional in Python?
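Concretely, the Swift initializer looks like this (sketch; exact names may differ):

```swift
// Sketch: maxSize is required in Swift. A rotating cache without a bound
// would behave like an ordinary growing cache, so Optional buys little.
public class RotatingKVCache /* : KVCache -- conformance elided in this sketch */ {
    private let maxSize: Int  // required, unlike Python's Optional max_size
    private let keep: Int     // leading tokens that are never rotated out
    private var offset = 0    // total tokens seen; may exceed maxSize

    public init(maxSize: Int, keep: Int = 0) {
        self.maxSize = maxSize
        self.keep = keep
    }
}
```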

@davidkoski (Collaborator) left a comment:

Looks good, ran integration tests and all was well. Thank you!

@davidkoski davidkoski merged commit f7da396 into ml-explore:main Jun 24, 2025
4 checks passed
@DePasqualeOrg (Contributor Author) commented Jun 25, 2025

> If we go with this approach, we'll need to update all the models to use attentionWithCacheUpdate as Qwen 3 does here.

@davidkoski, I think we still need to update the other models so that they can use the cache routing.
