Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions lmdeploy/pytorch/engine/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def model_forward(
context=context,
)
output = model(**input_dict)
return dict(hidden_states=output, model_metas=model_metas)
seq_length = ctx_mgr.current_context().q_seqlens

return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length)


@record_function('stopping_criteria')
Expand Down Expand Up @@ -502,7 +504,10 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
if not is_long_context:
ret = await __forward(inputs)
if not return_logits and not inputs.is_decoding:
last_token_loc = inputs.seq_length.cumsum(0) - 1
# fetch seq_length from the returned context, since models may change it (e.g. InternVL-Flash)
seq_length = ret.get('seq_length', None)
last_token_loc = seq_length.cumsum(0) - 1

ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
else:
ret = await __long_context_single_forward(inputs, max_seqlen)
Expand Down
Loading