
Commit 6533d5d

Add support for the Chat input to the MLXLM model
1 parent 849a969 commit 6533d5d

4 files changed (+150, -16 lines)
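At a glance, the user-visible change, sketched from the documentation and code below (model name as used in the docs; the printed output is illustrative). Before this commit, passing a `Chat` to an `MLXLM` model raised `NotImplementedError`; it is now rendered through the tokenizer's chat template:

```python
import outlines
import mlx_lm
from outlines.inputs import Chat

# Load an MLX model; mlx_lm.load returns a (model, tokenizer) pair
model = outlines.from_mlxlm(
    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
)

# A Chat prompt can now be passed directly to the model
prompt = Chat([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of Latvia?"},
])
print(model(prompt, max_tokens=50))  # 'Riga.'
```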

docs/features/models/llamacpp.md

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ prompt = Chat([
 
 # Call the model to generate a response
 response = model(prompt, max_tokens=50)
-print(response) # 'This is a picture of a black dog.'
+print(response) # 'Riga.'
 ```
 
 #### Streaming

docs/features/models/mlxlm.md

Lines changed: 42 additions & 13 deletions
@@ -29,7 +29,7 @@ import mlx_lm
 
 # Create the model
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 ```
 
@@ -45,14 +45,43 @@ import mlx_lm
 
 # Load the model
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 # Call it to generate text
 result = model("What's the capital of Latvia?", max_tokens=20)
 print(result) # 'Riga'
 ```
 
+#### Chat
+
+You can use chat inputs with the `MLXLM` model. To do so, call the model with a `Chat` instance.
+
+For instance:
+
+```python
+import outlines
+import mlx_lm
+from outlines.inputs import Chat
+
+# Load the model
+model = outlines.from_mlxlm(
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
+)
+
+# Create the chat prompt
+prompt = Chat([
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What's the capital of Latvia?"},
+])
+
+# Call the model to generate a response
+response = model(prompt, max_tokens=50)
+print(response) # 'Riga.'
+```
+
+#### Streaming
+
 The `MLXLM` model also supports streaming. For instance:
 
 ```python
@@ -61,7 +90,7 @@ import mlx_lm
 
 # Load the model
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 # Stream text
@@ -73,7 +102,7 @@ for chunk in model.stream("Write a short story about a cat.", max_tokens=100):
 
 As a local model, `MLXLM` supports all forms of structured generation available in Outlines.
 
-### Basic Type
+#### Basic Type
 
 ```python
 import outlines
@@ -82,14 +111,14 @@ import mlx_lm
 output_type = int
 
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 result = model("How many countries are there in the world?", output_type)
 print(result) # '200'
 ```
 
-### JSON Schema
+#### JSON Schema
 
 ```python
 from pydantic import BaseModel
@@ -103,15 +132,15 @@ class Character(BaseModel):
     skills: List[str]
 
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 result = model("Create a character.", output_type=Character)
 print(result) # '{"name": "Evelyn", "age": 34, "skills": ["archery", "stealth", "alchemy"]}'
 print(Character.model_validate_json(result)) # name=Evelyn, age=34, skills=['archery', 'stealth', 'alchemy']
 ```
 
-### Multiple Choice
+#### Multiple Choice
 
 ```python
 from typing import Literal
@@ -121,14 +150,14 @@ import mlx_lm
 output_type = Literal["Paris", "London", "Rome", "Berlin"]
 
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 result = model("What is the capital of France?", output_type)
 print(result) # 'Paris'
 ```
 
-### Regex
+#### Regex
 
 ```python
 from outlines.types import Regex
@@ -138,14 +167,14 @@ import mlx_lm
 output_type = Regex(r"\d{3}-\d{2}-\d{4}")
 
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 result = model("Generate a fake social security number.", output_type)
 print(result) # '782-32-3789'
 ```
 
-### Context-Free Grammar
+#### Context-Free Grammar
 
 ```python
 from outlines.types import CFG
@@ -175,7 +204,7 @@ arithmetic_grammar = """
 output_type = CFG(arithmetic_grammar)
 
 model = outlines.from_mlxlm(
-    *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
+    *mlx_lm.load("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit")
 )
 
 result = model("Write an addition.", output_type, max_tokens=20)
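Since a `Chat` is just a wrapper around a list of messages (the type adapter reads `model_input.messages`), a conversation can be continued by rebuilding the prompt with the model's reply appended. A minimal sketch building on the Chat example above; the follow-up question and the model's answers are illustrative:

```python
# Continue the conversation from the Chat example above: append the
# model's reply and a new user turn, then call the model again.
response = model(prompt, max_tokens=50)

prompt = Chat(prompt.messages + [
    {"role": "assistant", "content": response},
    {"role": "user", "content": "And what is its population?"},
])
print(model(prompt, max_tokens=50))
```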

outlines/models/mlxlm.py

Lines changed: 23 additions & 2 deletions
@@ -3,6 +3,7 @@
 from functools import singledispatchmethod
 from typing import TYPE_CHECKING, Iterator, List, Optional
 
+from outlines.inputs import Chat
 from outlines.models.base import Model, ModelTypeAdapter
 from outlines.models.transformers import TransformerTokenizer
 from outlines.processors import OutlinesLogitsProcessor
@@ -17,6 +18,9 @@
 class MLXLMTypeAdapter(ModelTypeAdapter):
     """Type adapter for the `MLXLM` model."""
 
+    def __init__(self, **kwargs):
+        self.tokenizer = kwargs.get("tokenizer")
+
     @singledispatchmethod
     def format_input(self, model_input):
         """Generate the prompt argument to pass to the model.
3438
"""
3539
raise NotImplementedError(
3640
f"The input type {input} is not available with mlx-lm. "
37-
"The only available type is `str`."
41+
"The available types are `str` and `Chat`."
3842
)
3943

4044
@format_input.register(str)
4145
def format_str_input(self, model_input: str):
4246
return model_input
4347

48+
@format_input.register(Chat)
49+
def format_chat_input(self, model_input: Chat) -> str:
50+
if not all(
51+
isinstance(message["content"], str)
52+
for message in model_input.messages
53+
):
54+
raise ValueError(
55+
"mlx-lm does not support multi-modal messages."
56+
+ "The content of each message must be a string."
57+
)
58+
59+
return self.tokenizer.apply_chat_template(
60+
model_input.messages,
61+
tokenize=False,
62+
add_generation_prompt=True,
63+
)
64+
4465
def format_output_type(
4566
self, output_type: Optional[OutlinesLogitsProcessor] = None,
4667
) -> Optional[List[OutlinesLogitsProcessor]]:
@@ -92,7 +113,7 @@ def __init__(
         self.mlx_tokenizer = tokenizer
         # self.tokenizer is used by the logits processor
         self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)
-        self.type_adapter = MLXLMTypeAdapter()
+        self.type_adapter = MLXLMTypeAdapter(tokenizer=tokenizer)
 
     def generate(
         self,
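A design note on the change above: `format_chat_input` does not hard-code a prompt format; it defers to the tokenizer's Hugging Face-style chat template, so each model's own special tokens are used. A minimal sketch of that call in isolation, assuming Apple Silicon and the SmolLM tokenizer used in the tests below (the rendered markers vary by model; the expected string mirrors the test's assertion):

```python
import mlx_lm

# mlx_lm.load returns a (model, tokenizer) pair; the tokenizer wraps a
# Hugging Face tokenizer and exposes apply_chat_template.
_, tokenizer = mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")

messages = [{"role": "user", "content": "Hello, world!"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the rendered string, not token ids
    add_generation_prompt=True,  # end with an open assistant turn
)
# ChatML-style output for this model:
# '<|im_start|>user\nHello, world!<|im_end|>\n<|im_start|>assistant\n'
```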
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import pytest
+import io
+
+from outlines_core import Index, Vocabulary
+from PIL import Image as PILImage
+
+from outlines.backends.outlines_core import OutlinesCoreLogitsProcessor
+from outlines.inputs import Chat, Image
+from outlines.models.mlxlm import MLXLMTypeAdapter
+
+try:
+    import mlx_lm
+    import mlx.core as mx
+
+    HAS_MLX = mx.metal.is_available()
+except ImportError:
+    HAS_MLX = False
+
+
+MODEL_NAME = "mlx-community/SmolLM-135M-Instruct-4bit"
+
+
+@pytest.fixture
+def adapter():
+    _, tokenizer = mlx_lm.load(MODEL_NAME)
+    return MLXLMTypeAdapter(tokenizer=tokenizer)
+
+
+@pytest.fixture
+def logits_processor():
+    vocabulary = Vocabulary.from_pretrained(MODEL_NAME)
+    index = Index(r"[0-9]{3}", vocabulary)
+    return OutlinesCoreLogitsProcessor(index, "mlx")
+
+
+@pytest.fixture
+def image():
+    width, height = 1, 1
+    white_background = (255, 255, 255)
+    image = PILImage.new("RGB", (width, height), white_background)
+    buffer = io.BytesIO()
+    image.save(buffer, format="PNG")
+    buffer.seek(0)
+    image = PILImage.open(buffer)
+
+    return image
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_type_adapter_format_input(adapter, image):
+    # Anything other than a string/Chat (invalid)
+    with pytest.raises(NotImplementedError):
+        adapter.format_input(["Hello, world!"])
+
+    # String
+    assert adapter.format_input("Hello, world!") == "Hello, world!"
+
+    # Chat
+    messages = [
+        {"role": "user", "content": "Hello, world!"},
+        {"role": "assistant", "content": "Hello, world!"},
+    ]
+    expected = (
+        "<|im_start|>user\nHello, world!<|im_end|>\n<|im_start|>assistant\n"
+        + "Hello, world!<|im_end|>\n<|im_start|>assistant\n"
+    )
+    assert adapter.format_input(Chat(messages=messages)) == expected
+
+    # Multi-modal (invalid)
+    with pytest.raises(
+        ValueError,
+        match="mlx-lm does not support multi-modal messages."
+    ):
+        adapter.format_input(Chat(messages=[
+            {"role": "user", "content": ["prompt", Image(image)]},
+        ]))
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_type_adapter_format_output_type(adapter, logits_processor):
+    formatted = adapter.format_output_type(logits_processor)
+    assert isinstance(formatted, list)
+    assert len(formatted) == 1
+    assert isinstance(formatted[0], OutlinesCoreLogitsProcessor)
