@@ -126,6 +126,7 @@ class SamplingInputs:
     min_top_p: float = 1.0
     response_formats: Tuple[str] = ()
     logits_processors: List[List[LogitsProcessor]] = None
+    max_num_logprobs: Optional[int] = None

     @classmethod
     def from_sampling_params(cls, seqs: List[SchedulerSequence]):
@@ -142,6 +143,7 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence]):
         random_offsets = [None] * batch_size
         response_formats = [None] * batch_size
         logits_processors = [None] * batch_size
+        num_logprobs = [None] * batch_size

         def __gather_params():
             """Gather params."""
@@ -164,6 +166,7 @@ def __gather_params():
             bad_words[idx] = bw
             stop_words[idx] = sw
             logits_processors[idx] = param.logits_processors
+            num_logprobs[idx] = param.num_logprobs

         def __get_topp(top_p):
             """Get topp."""
@@ -232,6 +235,8 @@ def __get_bad_words(bad_words):
         random_seeds = torch.tensor(random_seeds)
         random_offsets = torch.tensor(random_offsets)

+        max_num_logprobs = max(num_logprobs)
+
         sampling_input = cls(
             temperature=temperature,
             bad_words=bad_words,
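
How the batch-level value is derived: a minimal sketch, assuming each request's `num_logprobs` is `-1` when logprobs were not requested (consistent with the `num_logprobs < 0` check added to `__call__` later in this diff).

```python
# Minimal sketch of the aggregation above. Assumes -1 marks "no logprobs
# requested", matching the `num_logprobs < 0` check later in this diff.
num_logprobs = [-1, 5, 2]            # gathered per request in __gather_params
max_num_logprobs = max(num_logprobs)
assert max_num_logprobs == 5  # one interested request enables capture for the whole batch
```
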
@@ -248,6 +253,7 @@ def __get_bad_words(bad_words):
             max_top_k=max_top_k,
             min_top_p=min_top_p,
             logits_processors=logits_processors,
+            max_num_logprobs=max_num_logprobs,
         )
         return sampling_input
@@ -280,11 +286,13 @@ def __init__(self,
                  sampling_inputs: SamplingInputs,
                  ignore_eos: torch.Tensor,
                  tokenizer: Optional[Tokenizer] = None,
-                 sampling_vocab_size: Optional[int] = None):
+                 sampling_vocab_size: Optional[int] = None,
+                 logprobs_mode: Optional[str] = None):
         self.sampling_inputs: SamplingInputs = sampling_inputs
         self.ignore_eos = ignore_eos
         self.tokenizer = tokenizer
         self.sampling_vocab_size = sampling_vocab_size
+        self.logprobs_mode = logprobs_mode

     async def _wait_stream_once(self):
         """Wait stream once."""
@@ -309,6 +317,19 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
             torch.FloatTensor: The processed prediction scores.
         """
+
+        num_logprobs = self.sampling_inputs.max_num_logprobs
+        # Capture raw logprobs before the sampling pipeline modifies scores.
+        if num_logprobs < 0:
+            logprobs = None
+        else:
+            if self.logprobs_mode == 'raw_logits':
+                logprobs = scores.clone()
+            elif self.logprobs_mode == 'raw_logprobs':
+                logprobs = scores.log_softmax(dim=-1)
+            else:
+                logprobs = None
+
         sampling_inputs = self.sampling_inputs

         custom_logits_processors = self.sampling_inputs.logits_processors
@@ -338,7 +359,7 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
         if guided_input_ids is not None:
             await self._wait_stream_once()
             scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer)
-        return scores
+        return scores, logprobs

     @torch.inference_mode()
     def sampling(self, logits: torch.Tensor):
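
The two modes differ only in whether the snapshot is log-normalized; both are taken before temperature, bad-words, top-k/top-p, or guided decoding touch `scores`. A self-contained toy comparison:

```python
import torch

# Toy comparison of the two capture modes. 'raw_logits' keeps an unmodified
# copy of the scores; 'raw_logprobs' log-normalizes them first.
scores = torch.tensor([[2.0, 1.0, 0.5]])
raw_logits = scores.clone()                # logprobs_mode == 'raw_logits'
raw_logprobs = scores.log_softmax(dim=-1)  # logprobs_mode == 'raw_logprobs'
assert torch.allclose(raw_logprobs.exp().sum(-1), torch.tensor(1.0))
```
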
@@ -384,3 +405,19 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
         else:
             scores, indices = logits.topk(max_topk, dim=1)
         return __random_sampling(scores, indices)
+
+    @torch.inference_mode()
+    def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
+        """Compute logprobs."""
+        if raw_logprobs is None:
+            return None
+
+        indices = token_ids.unsqueeze(-1)
+        logprobs = raw_logprobs.gather(-1, indices)
+        num_logprobs = self.sampling_inputs.max_num_logprobs
+        if num_logprobs > 0:
+            topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
+            logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
+            indices = torch.cat([indices, topk_indices], dim=-1)
+
+        return logprobs, indices.to(torch.int32)
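
The returned tensors are column-aligned: column 0 holds the sampled token's logprob and id, followed by the top `max_num_logprobs` candidates. A standalone sketch of the same gather/topk arithmetic, assuming a batch of 2, a vocab of 4, and `max_num_logprobs == 2`:

```python
import torch

# Standalone sketch of compute_logprobs' gather/topk arithmetic.
raw_logprobs = torch.randn(2, 4).log_softmax(dim=-1)
token_ids = torch.tensor([3, 1])                # sampled token per sequence
indices = token_ids.unsqueeze(-1)               # (2, 1)
logprobs = raw_logprobs.gather(-1, indices)     # sampled-token logprobs
topk_logprobs, topk_indices = raw_logprobs.topk(2, dim=-1)
logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)  # (2, 3)
indices = torch.cat([indices, topk_indices], dim=-1)     # (2, 3) token ids
```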