import torch
from PIL import Image

+ from ....core.scheduler import InferenceRequest
from ....model.utils import select_device
from ....types import (
    ChatCompletion,
)
from ..llm_family import LLMFamilyV1, LLMSpecV1
from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import get_max_src_len

logger = logging.getLogger(__name__)

- IMAGENET_MEAN = (0.485, 0.456, 0.406)
- IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ LANGUAGE_TOKEN_TYPE = 0
+ VISION_TOKEN_TYPE = 1
+
+
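+ # Recursively moves every element matching criterion_func (tensors, in practice)
+ # inside nested lists/tuples/dicts onto a target device or dtype; everything else
+ # is returned unchanged.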
+ def recur_move_to(item, tgt, criterion_func):
+     """
+     This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
+     """
+     if criterion_func(item):
+         device_copy = item.to(tgt)
+         return device_copy
+     elif isinstance(item, list):
+         return [recur_move_to(v, tgt, criterion_func) for v in item]
+     elif isinstance(item, tuple):
+         return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
+     elif isinstance(item, dict):
+         return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
+     else:
+         return item


class CogVLM2Model(PytorchChatModel):
@@ -176,11 +196,33 @@ def _image_to_piexl_values(image):
                            content["image_url"]["url"]
                        )
            assistant = chat_history[i + 1]["content"]
-             query = query + f" USER: {user} ASSISTANT:"
-             history.append((query, assistant))
-             query = query + f" {assistant} "
+             history.append((user, assistant))
+             query = assistant  # type: ignore
        return query, history, [pixel_values]

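+     # Shared helper: turns the prompt, system prompt and chat history into the
+     # (query, image, history) triple expected by build_conversation_input_ids.
+     # If both the current prompt and the history carry an image, the history is
+     # dropped and only the new image is kept.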
+     def get_query_and_history(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+     ):
+         content, image = self._message_content_to_cogvlm2(prompt)
+
+         history = []
+         history_image = None
+         if chat_history:
+             query, history, history_image = self._history_content_to_cogvlm2(
+                 system_prompt, chat_history  # type: ignore
+             )
+
+         if image and history_image:
+             history = []
+             query = content
+         else:
+             image = image if image else history_image
+             query = content
+         return query, image, history
+
    def chat(
        self,
        prompt: Union[str, List[Dict]],
@@ -198,22 +240,9 @@ def chat(
            else 512,
        }

-         content, image = self._message_content_to_cogvlm2(prompt)
-
-         history = []
-         query = ""
-         history_image = None
-         if chat_history:
-             query, history, history_image = self._history_content_to_cogvlm2(
-                 system_prompt, chat_history
-             )
-
-         if image and history_image:
-             history = []
-             query = system_prompt + f" USER: {content} ASSISTANT:"
-         else:
-             image = image if image else history_image
-             query = query + f" USER: {content} ASSISTANT:"
+         query, image, history = self.get_query_and_history(
+             prompt, system_prompt=system_prompt, chat_history=chat_history
+         )

        input_by_model = self._model.build_conversation_input_ids(
            self._tokenizer,
@@ -319,3 +348,159 @@ def _streaming_chat_response(
                ),
            )
            yield chunk
+
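+     # CogVLM-style position ids: every language token advances the position by
+     # one, while a run of vision tokens shares a single position (only the first
+     # vision token of the run increments the counter).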
+     @staticmethod
+     def build_position_ids(x, attention_mask=None):
+         """
+         Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
+         """
+         # Fix: follow the official open-source implementation
+         if attention_mask is not None:
+             tmp = x.clone()
+             tmp[~(attention_mask.bool())] = -1
+         else:
+             tmp = x.clone()
+         # treat image boi/eoi tokens as LANGUAGE_TOKEN_TYPE
+         is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+         is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
+             tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
+         )
+         is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
+         is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
+             tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+         )
+         is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
+         tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+         # final position ids
+         y = torch.zeros_like(x, dtype=torch.long)
+         y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+             (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+         )
+         y = y.cumsum(dim=-1)
+         return y
+
+     def get_dtype(self):
+         return self._torch_type
+
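+     # Build the per-request prefill features (input_ids, token_type_ids,
+     # attention_mask, images); these dicts are consumed by build_prefill_kwargs
+     # below.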
+     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+         query, image, history = self.get_query_and_history(
+             prompt, system_prompt=system_prompt, chat_history=chat_history
+         )
+
+         input_by_model: dict = self._model.build_conversation_input_ids(
+             self._tokenizer,
+             query=query,
+             history=history,
+             images=image,
+             template_version="chat",
+         )
+         return {
+             "input_ids": input_by_model["input_ids"],  # seq_len
+             "token_type_ids": input_by_model["token_type_ids"],  # seq_len
+             "attention_mask": input_by_model["attention_mask"],  # seq_len
+             "images": input_by_model["images"],
+         }
+
+     def prepare_sanitize_generate_config(self, req: InferenceRequest):
+         """
+         See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
+         """
+         raw_config = req.inference_kwargs.get("raw_params", {})
+         temperature = raw_config.get("temperature", None)
+         if temperature is None:
+             raw_config["temperature"] = 0.6
+         top_p = raw_config.get("top_p", None)
+         if top_p is None:
+             raw_config["top_p"] = 0.9
+         return raw_config
+
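+     # Batch prefill: truncate each prompt to its allowed source length, left-pad
+     # every sequence to the longest prompt in the batch, stack the features into
+     # batched tensors, attach position ids, and move everything onto the model's
+     # device and dtype.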
+     def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+         context_len = self.get_context_len()
+         assert isinstance(prompts[0], dict)
+         images = []
+         max_length = float("-inf")
+         for i, feature in enumerate(prompts):
+             req = req_list[i]
+             if "images" in feature:
+                 images.append(feature.pop("images", None))
+             max_src_len = get_max_src_len(context_len, req)
+             input_ids = feature["input_ids"][-max_src_len:]
+             req.prompt_tokens = input_ids.tolist()
+             feature["input_ids"] = input_ids
+             feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
+             feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
+             req.extra_kwargs["attention_mask_seq_len"] = feature[
+                 "attention_mask"
+             ].shape[0]
+             max_length = max(len(input_ids), max_length)
+
+         def pad_to_max_length_internal(feature, max_len, idx):
+             padding_length = max_len - len(feature["input_ids"])
+             req_list[idx].padding_len = padding_length
+             feature["input_ids"] = torch.cat(
+                 [torch.full((padding_length,), 0), feature["input_ids"]]
+             )
+             feature["token_type_ids"] = torch.cat(
+                 [
+                     torch.zeros(padding_length, dtype=torch.long),
+                     feature["token_type_ids"],
+                 ]
+             )
+             feature["attention_mask"] = torch.cat(
+                 [
+                     torch.zeros(padding_length, dtype=torch.long),
+                     feature["attention_mask"],
+                 ]
+             )
+             return feature
+
+         features = [
+             pad_to_max_length_internal(feature, max_length, i)
+             for i, feature in enumerate(prompts)
+         ]
+         batch = {
+             key: torch.stack([feature[key] for feature in features])
+             for key in features[0].keys()
+         }
+
+         position_ids = self.build_position_ids(batch["token_type_ids"])
+         batch["position_ids"] = position_ids
+
+         for i in range(len(prompts)):
+             req = req_list[i]
+             req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
+
+         if images:
+             batch["images"] = images
+
+         batch = recur_move_to(
+             batch, self._device, lambda x: isinstance(x, torch.Tensor)
+         )
+         dtype = self.get_dtype()
+         if dtype:
+             batch = recur_move_to(
+                 batch,
+                 dtype,
+                 lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
+             )
+         return batch
+
+     def build_decode_token_type_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         token_type_ids = torch.full(
+             (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
+         )
+         return token_type_ids
+
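+     # Decode step: advance each request's cached max_position_id by one and use
+     # it as the position id of the token generated in this step.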
+     def build_decode_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         tmp = []
+         for r in reqs:
+             r.extra_kwargs["max_position_id"] += 1
+             tmp.append(r.extra_kwargs["max_position_id"])
+         position_ids = torch.as_tensor(
+             tmp, device=self._device, dtype=torch.long
+         ).unsqueeze(1)
+         return position_ids