[python] Update vllm rolling batch to return only deltas
xyang16 committed Dec 14, 2024
1 parent 96a0efd commit d20ea7f
Showing 3 changed files with 13 additions and 39 deletions.
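
The heart of the change: recent vLLM builds expose RequestOutputKind on SamplingParams, and with output_kind=RequestOutputKind.DELTA each RequestOutput carries only the tokens generated since the previous engine step instead of the full cumulative text and token_ids, so the rolling-batch handler no longer has to slice outputs itself. A minimal sketch of that behavior, assuming a vLLM version that supports RequestOutputKind; the model name, request id, and prompt below are illustrative only, not part of this commit:

# Sketch only: not code from this commit. Names follow vLLM's public API in
# recent releases; model, request id, and prompt are placeholders.
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# CUMULATIVE (previous behavior): every step repeats all text/token_ids
# generated so far, so the caller must remember how much it already emitted.
# DELTA (this commit): every step returns only the newly generated fragment.
sampling_params = SamplingParams(max_tokens=16,
                                 output_kind=RequestOutputKind.DELTA)
engine.add_request("req-0", "Deep learning is", sampling_params)

while engine.has_unfinished_requests():
    for request_output in engine.step():
        for completion_output in request_output.outputs:
            # With DELTA, token_ids and text are just the new piece.
            print(completion_output.token_ids, repr(completion_output.text))
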
@@ -91,47 +91,26 @@ def update_request_cache_with_output(request_cache: OrderedDict,
 
 def update_multiple_sequences(cache, request_output, vllm_request_output):
     for completion_output in vllm_request_output.outputs:
 
         sequence_index = completion_output.index
-        if f"sequence_index_{sequence_index}" not in cache:
-            cache[f"sequence_index_{sequence_index}"] = {
-                "curr_length": 0,
-                "num_generated_tokens": 0
-            }
-
         if sequence_index not in request_output.sequences:
             request_output.sequences[sequence_index] = Sequence()
 
-        # set token of the sequence
-        # previous length of token ids generated
-        prev_len = cache[f"sequence_index_{sequence_index}"][
-            'num_generated_tokens']
-        # curr length of the token ids generated so far
-        cur_len = len(completion_output.token_ids)
-        cache[f"sequence_index_{sequence_index}"][
-            "num_generated_tokens"] = cur_len
-
         # get the newly generated token_ids
-        new_token_ids = completion_output.token_ids[
-            prev_len:
-            cur_len] if prev_len < cur_len else completion_output.token_ids
+        new_token_ids = completion_output.token_ids
 
         # get the newly generated token texts for speculative decoding
         output_token_texts = []
         if hasattr(completion_output, "output_token_texts"):
-            output_token_texts = completion_output.output_token_texts[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.output_token_texts
+            output_token_texts = completion_output.output_token_texts
 
         top_tokens = []
         token_texts = []
         # calculate log probs and token_texts
         if completion_output.logprobs:
-            new_logprobs_list = completion_output.logprobs[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.logprobs
             new_logprobs = []
-            for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            for token_id, logprobs in zip(new_token_ids,
+                                          completion_output.logprobs):
                 new_logprobs.append(logprobs[token_id].logprob)
                 decoded_token = logprobs[token_id].decoded_token if logprobs[
                     token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
                         Token(id=token_id_key,
                               text=logprob.decoded_token,
                               log_prob=logprob.logprob))
 
         elif new_token_ids:
             # TODO: Test and remove this. logprobs is always set 1. This case should never happen.
             new_logprobs = [None] * len(new_token_ids)
-            curr_length = cache[f"sequence_index_{sequence_index}"][
-                "curr_length"]
-            token_texts.append(completion_output.text[curr_length:])
+            token_texts.append(completion_output.text)
 
         if not output_token_texts:
             if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
         request_output.sequences[sequence_index].set_next_top_tokens(
             top_tokens)
-
-        cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
-            completion_output.text)
 
 
 def get_speculative_decoding_metrics_record(
         completion_output: CompletionOutput,
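
For reference, the cache bookkeeping deleted above existed only to turn cumulative vLLM outputs into deltas by hand. A standalone sketch of that superseded logic, using a hypothetical helper name (cumulative_to_delta) and a cache layout mirroring the removed code; none of it is needed once the engine itself emits deltas:

# Sketch of the old approach: derive deltas from cumulative outputs.
# The helper and its signature are illustrative, not code from the repository.
def cumulative_to_delta(cache: dict, seq_key: str, token_ids: list, text: str):
    """Return (new_token_ids, new_text) for one cumulative CompletionOutput."""
    entry = cache.setdefault(seq_key, {
        "curr_length": 0,
        "num_generated_tokens": 0
    })
    new_token_ids = token_ids[entry["num_generated_tokens"]:]
    new_text = text[entry["curr_length"]:]
    entry["num_generated_tokens"] = len(token_ids)
    entry["curr_length"] = len(text)
    return new_token_ids, new_text

# With RequestOutputKind.DELTA the engine already returns only new tokens, so
# completion_output.token_ids and completion_output.text can be used directly.
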
@@ -13,6 +13,7 @@
 from collections import OrderedDict, defaultdict
 
 from vllm import LLMEngine, SamplingParams
+from vllm.sampling_params import RequestOutputKind
 from vllm.utils import random_uuid, AtomicCounter
 
 from djl_python.request import Request
@@ -128,7 +129,8 @@ def inference(self, new_requests: List[Request]) -> List:
             request_id = random_uuid()
             prompt_inputs = get_prompt_inputs(request)
             params = self.translate_vllm_params(request.parameters)
-            sampling_params = SamplingParams(**params)
+            sampling_params = SamplingParams(
+                output_kind=RequestOutputKind.DELTA, **params)
             request_params = dict()
             if request.adapter is not None:
                 adapter_name = request.adapter.get_property("name")
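
With the new import in place, forcing delta outputs is a one-line change where the SamplingParams are built. A small usage sketch, assuming output_kind is accepted by the installed vLLM version; the params dict stands in for whatever translate_vllm_params returns, and its keys are examples rather than the real mapping:

from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

params = {"temperature": 0.7, "max_tokens": 64}  # illustrative values
sampling_params = SamplingParams(output_kind=RequestOutputKind.DELTA, **params)
assert sampling_params.output_kind == RequestOutputKind.DELTA
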
11 changes: 5 additions & 6 deletions engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
@@ -1,6 +1,5 @@
 import sys
 import unittest
-import uuid
 from dataclasses import dataclass
 from typing import List, Optional, Dict, Union
 from collections import OrderedDict
@@ -12,7 +11,7 @@
 import djl_python
 from djl_python.output_formatter import _json_output_formatter
 from djl_python.request import Request
-from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
+from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
 '''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''


@@ -148,7 +147,7 @@ def __init__(
             ],
             outputs=[
                 MockCompletionOutput(index=1,
-                                     text=' member of',
+                                     text=' of',
                                      token_ids=[4292, 302],
                                      cumulative_logprob=-4.3041129764169455,
                                      logprobs=[{
@@ -181,7 +180,7 @@ def __init__(
                                      finish_reason=None,
                                      stop_reason=None),
                 MockCompletionOutput(index=0,
-                                     text=' consolidated',
+                                     text='ated',
                                      token_ids=[22968, 601],
                                      cumulative_logprob=-13.402491569519043,
                                      logprobs=[{
@@ -235,7 +234,7 @@ def __init__(
             ],
             outputs=[
                 MockCompletionOutput(index=1,
-                                     text=' member of the',
+                                     text=' the',
                                      token_ids=[4292, 302,
                                                 272],
                                      cumulative_logprob=-4.815703457221389,
@@ -282,7 +281,7 @@ def __init__(
                                      finish_reason='length',
                                      stop_reason=None),
                 MockCompletionOutput(index=0,
-                                     text=' consolidated or',
+                                     text=' or',
                                      token_ids=[22968, 601, 442],
                                      cumulative_logprob=-20.4010648727417,
                                      logprobs=[{
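
The test fixtures change accordingly: the mock CompletionOutput texts above are now per-step deltas rather than cumulative strings, which is why ' member of' becomes ' of' and ' member of the' becomes ' the'. A tiny illustration of the equivalence; the first-step text ' member' is assumed for the example and is not part of the visible diff:

# Cumulative outputs repeat the whole generation so far; delta outputs do not.
cumulative_steps = [' member', ' member of', ' member of the']  # old-style mocks
delta_steps = [' member', ' of', ' the']  # new-style mocks

# A client reconstructs the same text by concatenating the deltas.
assert ''.join(delta_steps) == cumulative_steps[-1]
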
