
Commit d76c29c

output logprobs (#3852)
* support logprobs
* packed output
* fix
* expose args
1 parent 75971e9 commit d76c29c

File tree

10 files changed: +203 -43 lines changed


lmdeploy/cli/serve.py

Lines changed: 2 additions & 0 deletions

@@ -91,6 +91,7 @@ def add_parser_api_server():
         ArgumentHelper.device(pt_group)
         ArgumentHelper.eager_mode(pt_group)
         ArgumentHelper.disable_vision_encoder(pt_group)
+        ArgumentHelper.logprobs_mode(pt_group)

         # common engine args
         dtype_act = ArgumentHelper.dtype(pt_group)
@@ -217,6 +218,7 @@ def api_server(args):
            model_format=args.model_format,
            hf_overrides=args.hf_overrides,
            disable_vision_encoder=args.disable_vision_encoder,
+            logprobs_mode=args.logprobs_mode,
        )
    else:
        from lmdeploy.messages import TurbomindEngineConfig

lmdeploy/cli/utils.py

Lines changed: 9 additions & 0 deletions

@@ -601,6 +601,15 @@ def disable_vision_encoder(parser):
                            default=False,
                            help='enable metrics system')

+    @staticmethod
+    def logprobs_mode(parser):
+        """The mode of logprobs."""
+        parser.add_argument('--logprobs-mode',
+                            type=str,
+                            default=None,
+                            choices=[None, 'raw_logits', 'raw_logprobs'],
+                            help='The mode of logprobs.')
+

# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
class FlexibleArgumentParser(argparse.ArgumentParser):
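The new helper is an ordinary argparse option. As a quick illustration, here is a standalone sketch with a throwaway parser (not lmdeploy's real CLI wiring) showing how the flag parses:

# Standalone sketch of the new flag; the real option is registered on the
# PyTorch-engine argument group by ArgumentHelper.logprobs_mode(pt_group).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--logprobs-mode',
                    type=str,
                    default=None,
                    choices=[None, 'raw_logits', 'raw_logprobs'],
                    help='The mode of logprobs.')

args = parser.parse_args(['--logprobs-mode', 'raw_logprobs'])
print(args.logprobs_mode)  # -> 'raw_logprobs'; omitting the flag keeps the default None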

lmdeploy/messages.py

Lines changed: 2 additions & 0 deletions

@@ -333,6 +333,7 @@ class PytorchEngineConfig:
            It can be used to override the default config of the model,
        disable_vision_encoder (bool): Whether to disable loading vision
            encoder. Default to False.
+        logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs']
    """
    dtype: str = 'auto'
    tp: int = 1
@@ -366,6 +367,7 @@
    enable_metrics: bool = False
    hf_overrides: Optional[Dict[str, Any]] = None
    disable_vision_encoder: bool = False
+    logprobs_mode: str = None

    role: EngineRole = EngineRole.Hybrid
    migration_backend: MigrationBackend = MigrationBackend.DLSlime
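A minimal usage sketch, assuming lmdeploy's usual pipeline/GenerationConfig entry points (not part of this diff) and a placeholder model path: the engine-level logprobs_mode selects what gets recorded (raw logits vs. raw log-softmax values), while the per-request logprobs count selects how many entries come back per token.

# Hedged sketch: the model path is a placeholder and the pipeline/GenerationConfig
# usage is assumed from lmdeploy's public API, not shown in this commit.
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(logprobs_mode='raw_logprobs')
pipe = pipeline('some/model-path', backend_config=backend_config)

# Request the sampled token's logprob plus the top-5 alternatives per step.
gen_config = GenerationConfig(logprobs=5, max_new_tokens=32)
outputs = pipe(['Hello, world'], gen_config=gen_config)
print(outputs[0].logprobs)  # expected: a list of {token_id: logprob} dicts, one per token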

lmdeploy/pytorch/config.py

Lines changed: 3 additions & 1 deletion

@@ -293,6 +293,7 @@ class MiscConfig:
    model_format: str = None
    hf_overrides: Dict[str, Any] = None
    disable_vision_encoder: bool = False
+    logprobs_mode: str = None

    @classmethod
    def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -302,5 +303,6 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
                          prefill_interval=engine_config.prefill_interval,
                          model_format=engine_config.model_format,
                          hf_overrides=engine_config.hf_overrides,
-                          disable_vision_encoder=engine_config.disable_vision_encoder)
+                          disable_vision_encoder=engine_config.disable_vision_encoder,
+                          logprobs_mode=engine_config.logprobs_mode)
        return misc_config

lmdeploy/pytorch/engine/engine.py

Lines changed: 35 additions & 6 deletions

@@ -27,6 +27,7 @@
from .engine_checker import EngineChecker
from .executor import build_executor
from .logits_process import SamplingInputs
+from .model_agent import BatchedOutputs
from .request import Request, RequestManager, RequestType, Response

logger = get_logger('lmdeploy')
@@ -46,6 +47,7 @@ class InferOutput:
    meta: Any = None
    finish: bool = False
    logits: torch.Tensor = None
+    logprobs: torch.Tensor = None

    # send cache blocks back for migration in Disaggregated LLM Serving
    # when Prefill Engine is Done.
@@ -816,9 +818,18 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,
            msg.update_token_ids(update_token, model_meta=model_meta)
            msg.status = MessageStatus.STOPPED

-    def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch.LongTensor, running: SeqList,
-                            logits: torch.Tensor, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]):
+    def _make_infer_outputs(
+        self,
+        batched_outputs: BatchedOutputs,
+        running: SeqList,
+    ):
        """Make infer output."""
+        new_token_timestamp = batched_outputs.new_token_timestamp
+        next_token_ids = batched_outputs.next_token_ids
+        logits = batched_outputs.logits
+        stopped = batched_outputs.stopped
+        model_metas = batched_outputs.model_metas
+        logprobs = batched_outputs.logprobs

        seq_length = [seq.num_token_ids for seq in running]
        is_run = [seq.status == MessageStatus.LOCKED for seq in running]
@@ -839,13 +850,21 @@ def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch.
                cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist()
            else:
                cache_block_ids = None
+
+            # logprobs
+            num_logprobs = msg.sampling_param.num_logprobs
+            cur_logprobs = None
+            if num_logprobs >= 0:
+                cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1])
+
            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
            out = InferOutput(session_id=session_id,
                              resp=msg.resp,
                              finish=finish,
                              token_ids=token_ids,
                              cache_block_ids=cache_block_ids,
-                              req_metrics=req_metrics)
+                              req_metrics=req_metrics,
+                              logprobs=cur_logprobs)
            outputs[session_id] = out

            if msg.return_logits:
@@ -977,12 +996,22 @@ def __log_resps(outputs: List[InferOutput]):
        def __send_resp(out: InferOutput):
            """Send response."""
            resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
+            cur_logprobs = out.logprobs
+            logprobs = None
+            if cur_logprobs is not None:
+                # logprobs to dict
+                vals = cur_logprobs[0].tolist()
+                indices = cur_logprobs[1].tolist()
+                cur_logprobs = dict(zip(indices, vals))
+                logprobs = [] if out.resp.data is None else out.resp.data.get('logprobs', [])
+                logprobs = logprobs + [cur_logprobs]
            self._response(out.resp,
                           resp_type,
                           data=dict(token_ids=out.token_ids,
                                     logits=out.logits,
                                     cache_block_ids=out.cache_block_ids,
-                                     req_metrics=out.req_metrics))
+                                     req_metrics=out.req_metrics,
+                                     logprobs=logprobs))

        def __send_resps(step_outputs: List[InferOutput]):
            """Send response callback."""
@@ -1118,8 +1147,8 @@ async def _async_loop_main(

                # send output
                out = await self.executor.get_output_async()
-                if len(out) > 0:
-                    step_outputs = self._make_infer_outputs(**out, running=running)
+                if out is not None:
+                    step_outputs = self._make_infer_outputs(out, running=running)
                    resp_que.put_nowait(step_outputs)

                # lock forward event
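Two details worth noting in the hunks above: _make_infer_outputs slices one (vals, indices) pair per sequence, with the sampled token packed at position 0 and the requested top-k behind it, and __send_resp then flattens that pair into a {token_id: logprob} dict which it appends to the running list stored on the response. A small sketch of that per-step conversion on dummy tensors:

# Hedged sketch of the conversion done in __send_resp, on dummy data.
import torch

# One sequence, num_logprobs=2: slot 0 is the sampled token, slots 1..k the top-k.
vals = torch.tensor([-0.105, -0.105, -2.302])
indices = torch.tensor([42, 42, 7], dtype=torch.int32)

cur_logprobs = dict(zip(indices.tolist(), vals.tolist()))
print(cur_logprobs)  # {42: -0.105..., 7: -2.302...}
# Note: when the sampled token also appears in the top-k (as here), zipping into a
# dict collapses the duplicate key, so the dict can hold fewer than k + 1 entries.

# Across decode steps the engine accumulates one such dict per generated token:
logprobs = []
logprobs = logprobs + [cur_logprobs]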

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 5 additions & 2 deletions

@@ -150,6 +150,7 @@ async def async_stream_infer(self,

            cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
            req_metrics = resp.data.get('req_metrics', None) if resp.data else None
+            logprobs = resp.data.get('logprobs', None) if resp.data else None
            if resp.type == ResponseType.SUCCESS:
                token_ids = resp.data['token_ids'].tolist()
                num_ids = len(token_ids)
@@ -158,7 +159,8 @@
                                   token_ids,
                                   num_ids,
                                   cache_block_ids=cache_block_ids,
-                                   req_metrics=req_metrics)
+                                   req_metrics=req_metrics,
+                                   logprobs=logprobs)
            elif resp.type == ResponseType.FINISH:
                resp_data = resp.data
                token_ids = resp_data['token_ids'].tolist()
@@ -170,7 +172,8 @@
                                   num_ids,
                                   logits=logits,
                                   cache_block_ids=cache_block_ids,
-                                   req_metrics=req_metrics)
+                                   req_metrics=req_metrics,
+                                   logprobs=logprobs)
                break
            else:
                logger.debug(f'session[{session_id}] failed.')
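Since __send_resp keeps appending each step's dict to the list already stored in resp.data, the logprobs attached to the streamed output is cumulative: one {token_id: logprob} dict per generated token, in generation order, with the sampled token inserted first in each dict. A hedged consumer-side sketch (field names taken from this diff; the output object is whatever async_stream_infer yields):

# Hedged sketch: `logprobs` is the list of per-token {token_id: logprob} dicts
# carried by the streamed output. Because the sampled token's value is packed
# first, it is the first key of each dict (Python dicts preserve insertion order).
def sampled_token_logprobs(logprobs):
    if not logprobs:
        return []
    return [next(iter(step.items())) for step in logprobs]

print(sampled_token_logprobs([{42: -0.105, 7: -2.302}, {7: -0.01}]))
# -> [(42, -0.105), (7, -0.01)]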

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 2 additions & 12 deletions

@@ -5,7 +5,6 @@
import os
from typing import Any, Dict, List, Optional, Tuple

-import numpy as np
import ray
import ray.exceptions
import torch
@@ -221,13 +220,7 @@ def warmup_dist(self):

    def pack_output(self, output: Dict):
        """Pack output."""
-        for k, v in output.items():
-            if isinstance(v, torch.Tensor):
-                # fix numpy do not have BFloat16 type
-                if v.dtype is torch.bfloat16:
-                    v = v.to(torch.float16)
-                output[k] = v.numpy()
-        return output
+        return output.to_numpy()

    def remote_log_start(self, msg: str):
        """Remote log start."""
@@ -385,10 +378,7 @@ async def _prefetch_outputs(self):
                outs = await self.workers[0].get_outputs.remote()
                logger.debug(f'Receive {len(outs)} outputs from worker[0].')
                for out in outs:
-                    # pack pytorch
-                    for k, v in out.items():
-                        if isinstance(v, np.ndarray):
-                            out[k] = torch.from_numpy(v)
+                    out = out.to_tensor()
                    self.remote_outs.put_nowait(out)

    def _prefetch_task_callback(self, task: asyncio.Task):
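pack_output and _prefetch_outputs now delegate the tensor/ndarray conversion to the packed output object itself (the "packed output" item in the commit message). The actual BatchedOutputs methods live in model_agent.py, which is not part of this excerpt; a purely hypothetical sketch of what such a to_numpy/to_tensor pair could look like, keeping the old bfloat16 workaround, is:

# Hypothetical sketch only: the real to_numpy/to_tensor belong to BatchedOutputs
# in model_agent.py, which this excerpt does not show. The field-wise conversion
# mirrors the deleted pack_output code, including the bfloat16 -> float16 cast
# (numpy has no bfloat16 dtype).
from dataclasses import dataclass, fields
from typing import Any

import numpy as np
import torch


@dataclass
class PackedOutputsSketch:
    next_token_ids: Any = None
    logits: Any = None
    logprobs: Any = None

    def to_numpy(self):
        for f in fields(self):
            v = getattr(self, f.name)
            if isinstance(v, torch.Tensor):
                if v.dtype is torch.bfloat16:
                    v = v.to(torch.float16)
                setattr(self, f.name, v.cpu().numpy())
        return self

    def to_tensor(self):
        for f in fields(self):
            v = getattr(self, f.name)
            if isinstance(v, np.ndarray):
                setattr(self, f.name, torch.from_numpy(v))
        return self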

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 39 additions & 2 deletions

@@ -126,6 +126,7 @@ class SamplingInputs:
    min_top_p: float = 1.0
    response_formats: Tuple[str] = ()
    logits_processors: List[List[LogitsProcessor]] = None
+    max_num_logprobs: Optional[int] = None

    @classmethod
    def from_sampling_params(cls, seqs: List[SchedulerSequence]):
@@ -142,6 +143,7 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence]):
        random_offsets = [None] * batch_size
        response_formats = [None] * batch_size
        logits_processors = [None] * batch_size
+        num_logprobs = [None] * batch_size

        def __gather_params():
            """Gather params."""
@@ -164,6 +166,7 @@ def __gather_params():
            bad_words[idx] = bw
            stop_words[idx] = sw
            logits_processors[idx] = param.logits_processors
+            num_logprobs[idx] = param.num_logprobs

        def __get_topp(top_p):
            """Get topp."""
@@ -232,6 +235,8 @@ def __get_bad_words(bad_words):
        random_seeds = torch.tensor(random_seeds)
        random_offsets = torch.tensor(random_offsets)

+        max_num_logprobs = max(num_logprobs)
+
        sampling_input = cls(
            temperature=temperature,
            bad_words=bad_words,
@@ -248,6 +253,7 @@ def __get_bad_words(bad_words):
            max_top_k=max_top_k,
            min_top_p=min_top_p,
            logits_processors=logits_processors,
+            max_num_logprobs=max_num_logprobs,
        )
        return sampling_input

@@ -280,11 +286,13 @@ def __init__(self,
                 sampling_inputs: SamplingInputs,
                 ignore_eos: torch.Tensor,
                 tokenizer: Optional[Tokenizer] = None,
-                 sampling_vocab_size: Optional[int] = None):
+                 sampling_vocab_size: Optional[int] = None,
+                 logprobs_mode: Optional[str] = None):
        self.sampling_inputs: SamplingInputs = sampling_inputs
        self.ignore_eos = ignore_eos
        self.tokenizer = tokenizer
        self.sampling_vocab_size = sampling_vocab_size
+        self.logprobs_mode = logprobs_mode

    async def _wait_stream_once(self):
        """Wait stream once."""
@@ -309,6 +317,19 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
            torch.FloatTensor: The processed prediction scores.

        """
+
+        num_logprobs = self.sampling_inputs.max_num_logprobs
+        # get raw logprobs
+        if num_logprobs < 0:
+            logprobs = None
+        else:
+            if self.logprobs_mode == 'raw_logits':
+                logprobs = scores.clone()
+            elif self.logprobs_mode == 'raw_logprobs':
+                logprobs = scores.log_softmax(dim=-1)
+            else:
+                logprobs = None
+
        sampling_inputs = self.sampling_inputs

        custom_logits_processors = self.sampling_inputs.logits_processors
@@ -338,7 +359,7 @@
        if guided_input_ids is not None:
            await self._wait_stream_once()
            scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer)
-        return scores
+        return scores, logprobs

    @torch.inference_mode()
    def sampling(self, logits: torch.Tensor):
@@ -384,3 +405,19 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
        else:
            scores, indices = logits.topk(max_topk, dim=1)
        return __random_sampling(scores, indices)
+
+    @torch.inference_mode()
+    def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
+        """Compute logprobs."""
+        if raw_logprobs is None:
+            return None
+
+        indices = token_ids.unsqueeze(-1)
+        logprobs = raw_logprobs.gather(-1, indices)
+        num_logprobs = self.sampling_inputs.max_num_logprobs
+        if num_logprobs > 0:
+            topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
+            logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
+            indices = torch.cat([indices, topk_indices], dim=-1)
+
+        return logprobs, indices.to(torch.int32)
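compute_logprobs gathers the sampled token's value from the raw logprobs and, when top-k logprobs were requested, concatenates the top-k values and indices behind it; this produces exactly the [batch, 1 + num_logprobs] layout that _make_infer_outputs slices per sequence. A self-contained shape sketch on dummy tensors:

# Hedged sketch of the gather/topk/cat layout on random data; shapes only.
import torch

batch, vocab, num_logprobs = 2, 8, 3
raw_logprobs = torch.randn(batch, vocab).log_softmax(dim=-1)
token_ids = torch.tensor([5, 1])

indices = token_ids.unsqueeze(-1)                        # [batch, 1]
logprobs = raw_logprobs.gather(-1, indices)              # sampled-token logprob first
topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)  # [batch, 1 + num_logprobs]
indices = torch.cat([indices, topk_indices], dim=-1)     # sampled id stays in column 0

print(logprobs.shape, indices.shape)  # torch.Size([2, 4]) torch.Size([2, 4])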
