
Commit 4a25965

ENH: Continuous batching supports vision model ability (#1724)
1 parent 48fb744 commit 4a25965

11 files changed, +859 −161 lines changed

xinference/core/model.py

Lines changed: 21 additions & 4 deletions
@@ -65,6 +65,9 @@ class _OutOfMemoryError(Exception):
 OutOfMemoryError = _OutOfMemoryError
 
 
+XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]
+
+
 def request_limit(fn):
     """
     Used by ModelActor.
@@ -268,11 +271,25 @@ def allow_batching(self) -> bool:
 
         model_ability = self._model_description.get("model_ability", [])
 
-        return (
-            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
-            and isinstance(self._model, PytorchModel)
-            and "vision" not in model_ability
+        condition = XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and isinstance(
+            self._model, PytorchModel
         )
+        if condition and "vision" in model_ability:
+            if (
+                self._model.model_family.model_name
+                in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
+                or self._model.model_family.model_family
+                in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
+            ):
+                return True
+            else:
+                logger.warning(
+                    f"Currently for multimodal models, "
+                    f"xinference only supports {', '.join(XINFERENCE_BATCHING_ALLOWED_VISION_MODELS)} for batching. "
+                    f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
+                )
+                return False
+        return condition
 
     async def load(self):
         self._model.load()
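
Note: the gate above can be read in isolation. Below is a minimal, standalone sketch of the new check, using plain strings in place of the real model objects; the helper signature and the model names in the asserts are invented for illustration, not xinference's actual classes.

```python
# Standalone sketch of the whitelist logic added to ModelActor.allow_batching.
# The real method reads these values from self._model_description and
# self._model.model_family; here they are passed in as plain arguments.
ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]

def allow_batching(batching_enabled: bool, is_pytorch_model: bool,
                   model_ability: list, model_name: str, model_family: str) -> bool:
    condition = batching_enabled and is_pytorch_model
    if condition and "vision" in model_ability:
        # Vision models are only batched if the name or the family is whitelisted.
        return (model_name in ALLOWED_VISION_MODELS
                or model_family in ALLOWED_VISION_MODELS)
    return condition

assert allow_batching(True, True, ["chat", "vision"], "glm-4v", "glm-4v")
assert not allow_batching(True, True, ["chat", "vision"], "some-vl-chat", "some-vl")
assert allow_batching(True, True, ["chat"], "some-text-chat", "some-text")
```

With batching enabled, text-only PyTorch models keep the old behaviour, while vision models now pass only if their name or family is whitelisted (otherwise a warning is logged and batching is refused).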

xinference/core/scheduler.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
         # Record error message when this request has error.
         # Must set stopped=True when this field is set.
         self.error_msg: Optional[str] = None
+        # For compatibility. Record some extra parameters for some special cases.
+        self.extra_kwargs = {}
 
         # check the integrity of args passed upstream
         self._check_args()
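
The new `extra_kwargs` dict is per-request scratch space; later in this commit the CogVLM2 batching code stores `attention_mask_seq_len` at prefill time and a running `max_position_id` that is advanced on every decode step. A toy illustration (the `ToyRequest` class is made up for the example):

```python
# Toy stand-in for InferenceRequest, showing how extra_kwargs is used as
# model-specific scratch space (the field names are the ones stored for CogVLM2).
class ToyRequest:
    def __init__(self):
        self.extra_kwargs = {}

req = ToyRequest()
req.extra_kwargs["attention_mask_seq_len"] = 12  # recorded once at prefill
req.extra_kwargs["max_position_id"] = 11         # last position id of the prompt
req.extra_kwargs["max_position_id"] += 1         # advanced by one per decode step
print(req.extra_kwargs)  # {'attention_mask_seq_len': 12, 'max_position_id': 12}
```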

xinference/core/tests/test_continuous_batching.py

Lines changed: 4 additions & 3 deletions
@@ -71,9 +71,10 @@ def run_internal(self):
         assert isinstance(res, dict)
         choices = res["choices"]
         assert isinstance(choices, list)
-        choice = choices[0]["text"]
-        assert isinstance(choice, str)
-        assert len(choice) > 0
+        choice = choices[0]["message"]
+        assert isinstance(choice, dict)
+        content = choice["content"]
+        assert len(content) > 0
 
 
 class InferenceThreadWithError(InferenceThread):
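
The test now asserts a chat-completion shaped payload (`choices[0]["message"]["content"]`) rather than a completion-shaped one (`choices[0]["text"]`). A minimal sketch of the shape being checked, with made-up field values:

```python
# OpenAI-style chat completion shape now asserted by the test (values invented).
res = {
    "choices": [
        {"index": 0, "message": {"role": "assistant", "content": "Hello!"}}
    ]
}
choices = res["choices"]
assert isinstance(choices, list)
choice = choices[0]["message"]
assert isinstance(choice, dict)
content = choice["content"]
assert len(content) > 0
```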

xinference/model/llm/llm_family.json

Lines changed: 11 additions & 1 deletion
@@ -944,7 +944,7 @@
                     "none"
                 ],
                 "model_id": "THUDM/glm-4v-9b",
-                "model_revision": "e8b84fefc07e58a90c8489337675573fda95e289"
+                "model_revision": "6c2e4732db8443f64a48d5af04b74425a7d169c4"
             }
         ],
         "prompt_style": {
@@ -5913,6 +5913,16 @@
             "roles": [
                 "user",
                 "assistant"
+            ],
+            "stop_token_ids": [
+                151643,
+                151644,
+                151645
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>"
             ]
         }
     },
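
The added `stop_token_ids` pair with the `stop` strings in the same prompt-style block: decoding typically halts when either the raw token id or the decoded text hits one of them. A generic sketch of that kind of check, not xinference's actual implementation:

```python
# Generic illustration of how stop_token_ids / stop strings are usually applied
# during decoding; the real stopping logic lives elsewhere in the codebase.
stop_token_ids = {151643, 151644, 151645}
stop_strings = ("<|endoftext|>", "<|im_start|>", "<|im_end|>")

def should_stop(new_token_id: int, decoded_text: str) -> bool:
    return new_token_id in stop_token_ids or any(s in decoded_text for s in stop_strings)

print(should_stop(151645, "Hello"))      # True: a stop token id was produced
print(should_stop(42, "Hi<|im_end|>"))   # True: a stop string appeared in the text
print(should_stop(42, "Hi"))             # False: keep generating
```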

xinference/model/llm/llm_family_modelscope.json

Lines changed: 10 additions & 0 deletions
@@ -3510,6 +3510,16 @@
             "roles": [
                 "user",
                 "assistant"
+            ],
+            "stop_token_ids": [
+                151643,
+                151644,
+                151645
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>"
             ]
         }
     },

xinference/model/llm/pytorch/chatglm.py

Lines changed: 0 additions & 8 deletions
@@ -336,14 +336,6 @@ def _stream_generator():
             ),
         )
 
-    @staticmethod
-    def require_attention_mask():
-        """
-        GLM4 needs to use attention mask and position ids during inference.
-        Otherwise, the inference result would be not available.
-        """
-        return True
-
     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """
         Set temperature and top_p to 0.8 by default

xinference/model/llm/pytorch/cogvlm2.py

Lines changed: 206 additions & 21 deletions
@@ -23,6 +23,7 @@
 import torch
 from PIL import Image
 
+from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
 from ....types import (
     ChatCompletion,
@@ -35,11 +36,30 @@
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import get_max_src_len
 
 logger = logging.getLogger(__name__)
 
-IMAGENET_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_STD = (0.229, 0.224, 0.225)
+
+LANGUAGE_TOKEN_TYPE = 0
+VISION_TOKEN_TYPE = 1
+
+
+def recur_move_to(item, tgt, criterion_func):
+    """
+    This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
+    """
+    if criterion_func(item):
+        device_copy = item.to(tgt)
+        return device_copy
+    elif isinstance(item, list):
+        return [recur_move_to(v, tgt, criterion_func) for v in item]
+    elif isinstance(item, tuple):
+        return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
+    elif isinstance(item, dict):
+        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
+    else:
+        return item
 
 
 class CogVLM2Model(PytorchChatModel):
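
`recur_move_to` walks arbitrarily nested lists, tuples and dicts and applies `.to(tgt)` only to the leaves matching the criterion. A small runnable sketch mirroring how `build_prefill_kwargs` later uses it; the function body is copied from the hunk above and the batch contents are invented:

```python
import torch

def recur_move_to(item, tgt, criterion_func):
    # Copied from the hunk above: recursively move matching leaves to `tgt`.
    if criterion_func(item):
        return item.to(tgt)
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple(recur_move_to(v, tgt, criterion_func) for v in item)
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item

batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "images": [[torch.rand(3, 4, 4)]],  # nested list of image tensors
}
# First move every tensor to the target device, then cast only the
# floating-point tensors to the model dtype (token ids stay integer).
batch = recur_move_to(batch, "cpu", lambda x: isinstance(x, torch.Tensor))
batch = recur_move_to(
    batch, torch.bfloat16,
    lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
)
print(batch["input_ids"].dtype, batch["images"][0][0].dtype)
# torch.int64 torch.bfloat16
```
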
@@ -176,11 +196,33 @@ def _image_to_piexl_values(image):
                         content["image_url"]["url"]
                     )
                 assistant = chat_history[i + 1]["content"]
-                query = query + f" USER: {user} ASSISTANT:"
-                history.append((query, assistant))
-                query = query + f" {assistant}"
+                history.append((user, assistant))
+                query = assistant  # type: ignore
         return query, history, [pixel_values]
 
+    def get_query_and_history(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+    ):
+        content, image = self._message_content_to_cogvlm2(prompt)
+
+        history = []
+        history_image = None
+        if chat_history:
+            query, history, history_image = self._history_content_to_cogvlm2(
+                system_prompt, chat_history  # type: ignore
+            )
+
+        if image and history_image:
+            history = []
+            query = content
+        else:
+            image = image if image else history_image
+            query = content
+        return query, image, history
+
     def chat(
         self,
         prompt: Union[str, List[Dict]],
@@ -198,22 +240,9 @@ def chat(
             else 512,
         }
 
-        content, image = self._message_content_to_cogvlm2(prompt)
-
-        history = []
-        query = ""
-        history_image = None
-        if chat_history:
-            query, history, history_image = self._history_content_to_cogvlm2(
-                system_prompt, chat_history
-            )
-
-        if image and history_image:
-            history = []
-            query = system_prompt + f" USER: {content} ASSISTANT:"
-        else:
-            image = image if image else history_image
-            query = query + f" USER: {content} ASSISTANT:"
+        query, image, history = self.get_query_and_history(
+            prompt, system_prompt=system_prompt, chat_history=chat_history
+        )
 
         input_by_model = self._model.build_conversation_input_ids(
             self._tokenizer,
@@ -319,3 +348,159 @@ def _streaming_chat_response(
             ),
         )
         yield chunk
+
+    @staticmethod
+    def build_position_ids(x, attention_mask=None):
+        """
+        Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
+        """
+        # Fix: follow the official open-source implementation
+        if attention_mask is not None:
+            tmp = x.clone()
+            tmp[~(attention_mask.bool())] = -1
+        else:
+            tmp = x.clone()
+        # image boi eoi token as LANGUAGE_TOKEN_TYPE
+        is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+        is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
+            tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
+        )
+        is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
+        is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
+            tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+        )
+        is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
+        tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+        # final position ids
+        y = torch.zeros_like(x, dtype=torch.long)
+        y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+        )
+        y = y.cumsum(dim=-1)
+        return y
+
+    def get_dtype(self):
+        return self._torch_type
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        query, image, history = self.get_query_and_history(
+            prompt, system_prompt=system_prompt, chat_history=chat_history
+        )
+
+        input_by_model: dict = self._model.build_conversation_input_ids(
+            self._tokenizer,
+            query=query,
+            history=history,
+            images=image,
+            template_version="chat",
+        )
+        return {
+            "input_ids": input_by_model["input_ids"],  # seq_len
+            "token_type_ids": input_by_model["token_type_ids"],  # seq_len
+            "attention_mask": input_by_model["attention_mask"],  # seq_len
+            "images": input_by_model["images"],
+        }
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.6
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.9
+        return raw_config
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        context_len = self.get_context_len()
+        assert isinstance(prompts[0], dict)
+        images = []
+        max_length = float("-inf")
+        for i, feature in enumerate(prompts):
+            req = req_list[i]
+            if "images" in feature:
+                images.append(feature.pop("images", None))
+            max_src_len = get_max_src_len(context_len, req)
+            input_ids = feature["input_ids"][-max_src_len:]
+            req.prompt_tokens = input_ids.tolist()
+            feature["input_ids"] = input_ids
+            feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
+            feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
+            req.extra_kwargs["attention_mask_seq_len"] = feature[
+                "attention_mask"
+            ].shape[0]
+            max_length = max(len(input_ids), max_length)
+
+        def pad_to_max_length_internal(feature, max_len, idx):
+            padding_length = max_len - len(feature["input_ids"])
+            req_list[idx].padding_len = padding_length
+            feature["input_ids"] = torch.cat(
+                [torch.full((padding_length,), 0), feature["input_ids"]]
+            )
+            feature["token_type_ids"] = torch.cat(
+                [
+                    torch.zeros(padding_length, dtype=torch.long),
+                    feature["token_type_ids"],
+                ]
+            )
+            feature["attention_mask"] = torch.cat(
+                [
+                    torch.zeros(padding_length, dtype=torch.long),
+                    feature["attention_mask"],
+                ]
+            )
+            return feature
+
+        features = [
+            pad_to_max_length_internal(feature, max_length, i)
+            for i, feature in enumerate(prompts)
+        ]
+        batch = {
+            key: torch.stack([feature[key] for feature in features])
+            for key in features[0].keys()
+        }
+
+        position_ids = self.build_position_ids(batch["token_type_ids"])
+        batch["position_ids"] = position_ids
+
+        for i in range(len(prompts)):
+            req = req_list[i]
+            req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
+
+        if images:
+            batch["images"] = images
+
+        batch = recur_move_to(
+            batch, self._device, lambda x: isinstance(x, torch.Tensor)
+        )
+        dtype = self.get_dtype()
+        if dtype:
+            batch = recur_move_to(
+                batch,
+                dtype,
+                lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
+            )
+        return batch
+
+    def build_decode_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        token_type_ids = torch.full(
+            (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
+        )
+        return token_type_ids
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        tmp = []
+        for r in reqs:
+            r.extra_kwargs["max_position_id"] += 1
+            tmp.append(r.extra_kwargs["max_position_id"])
+        position_ids = torch.as_tensor(
+            tmp, device=self._device, dtype=torch.long
+        ).unsqueeze(1)
+        return position_ids
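
`build_position_ids` derives position ids from the token-type sequence: the vision block's boi/eoi markers get their own positions, while the interior image tokens all share a single position. A quick, hand-checkable example; it assumes xinference at this commit and its torch dependency are installed so the module imports cleanly:

```python
import torch
from xinference.model.llm.pytorch.cogvlm2 import CogVLM2Model

# 0 = LANGUAGE_TOKEN_TYPE, 1 = VISION_TOKEN_TYPE: one text token, a vision block
# (boi, two image tokens, eoi), then two more text tokens.
token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])
print(CogVLM2Model.build_position_ids(token_type_ids))
# tensor([[0, 1, 2, 2, 3, 4, 5]])  -> the two interior image tokens share position 2
```

During decode only the next position id per request is needed, which is why `build_prefill_kwargs` stores the last prompt position in `req.extra_kwargs["max_position_id"]` and `build_decode_position_ids` simply increments it each step.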
