From d20ea7f3544cbd9181af740ea324b114c51cebd6 Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Fri, 13 Dec 2024 18:27:44 -0800
Subject: [PATCH] [python] Update vllm rolling batch to return only deltas

---
 .../rolling_batch/rolling_batch_vllm_utils.py | 37 +++----------
 .../rolling_batch/vllm_rolling_batch.py       |  4 +-
 .../djl_python/tests/test_rb_vllm_utils.py    | 11 +++---
 3 files changed, 13 insertions(+), 39 deletions(-)

diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
index bad3cc8eb6..ee7be75b7c 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -91,47 +91,26 @@ def update_request_cache_with_output(request_cache: OrderedDict,
 
 def update_multiple_sequences(cache, request_output, vllm_request_output):
     for completion_output in vllm_request_output.outputs:
-
         sequence_index = completion_output.index
-        if f"sequence_index_{sequence_index}" not in cache:
-            cache[f"sequence_index_{sequence_index}"] = {
-                "curr_length": 0,
-                "num_generated_tokens": 0
-            }
 
         if sequence_index not in request_output.sequences:
             request_output.sequences[sequence_index] = Sequence()
 
-        # set token of the sequence
-        # previous length of token ids generated
-        prev_len = cache[f"sequence_index_{sequence_index}"][
-            'num_generated_tokens']
-        # curr length of the token ids generated so far
-        cur_len = len(completion_output.token_ids)
-        cache[f"sequence_index_{sequence_index}"][
-            "num_generated_tokens"] = cur_len
-
-        # get the newly generated token_ids
-        new_token_ids = completion_output.token_ids[
-            prev_len:
-            cur_len] if prev_len < cur_len else completion_output.token_ids
+        new_token_ids = completion_output.token_ids
 
         # get the newly generated token texts for speculative decoding
         output_token_texts = []
         if hasattr(completion_output, "output_token_texts"):
-            output_token_texts = completion_output.output_token_texts[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.output_token_texts
+            output_token_texts = completion_output.output_token_texts
 
         top_tokens = []
         token_texts = []
         # calculate log probs and token_texts
         if completion_output.logprobs:
-            new_logprobs_list = completion_output.logprobs[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.logprobs
             new_logprobs = []
-            for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            for token_id, logprobs in zip(new_token_ids,
+                                          completion_output.logprobs):
                 new_logprobs.append(logprobs[token_id].logprob)
                 decoded_token = logprobs[token_id].decoded_token if logprobs[
                     token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
                         Token(id=token_id_key,
                               text=logprob.decoded_token,
                               log_prob=logprob.logprob))
-
         elif new_token_ids:
             # TODO: Test and remove this. logprobs is always set 1. This case should never happen.
             new_logprobs = [None] * len(new_token_ids)
-            curr_length = cache[f"sequence_index_{sequence_index}"][
-                "curr_length"]
-            token_texts.append(completion_output.text[curr_length:])
+            token_texts.append(completion_output.text)
 
         if not output_token_texts:
             if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
         request_output.sequences[sequence_index].set_next_top_tokens(
             top_tokens)
 
-        cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
-            completion_output.text)
-
 
 def get_speculative_decoding_metrics_record(
         completion_output: CompletionOutput,
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 66abbf811e..6ce21a9b9d 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -13,6 +13,7 @@
 from collections import OrderedDict, defaultdict
 
 from vllm import LLMEngine, SamplingParams
+from vllm.sampling_params import RequestOutputKind
 from vllm.utils import random_uuid, AtomicCounter
 
 from djl_python.request import Request
@@ -128,7 +129,8 @@ def inference(self, new_requests: List[Request]) -> List:
             request_id = random_uuid()
             prompt_inputs = get_prompt_inputs(request)
             params = self.translate_vllm_params(request.parameters)
-            sampling_params = SamplingParams(**params)
+            sampling_params = SamplingParams(
+                output_kind=RequestOutputKind.DELTA, **params)
             request_params = dict()
             if request.adapter is not None:
                 adapter_name = request.adapter.get_property("name")
diff --git a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
index 41486fb20c..24959627f3 100644
--- a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
+++ b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
@@ -1,6 +1,5 @@
 import sys
 import unittest
-import uuid
 from dataclasses import dataclass
 from typing import List, Optional, Dict, Union
 from collections import OrderedDict
@@ -12,7 +11,7 @@
 import djl_python
 from djl_python.output_formatter import _json_output_formatter
 from djl_python.request import Request
-from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
+from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
 '''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''
@@ -148,7 +147,7 @@ def __init__(
             ],
             outputs=[
                 MockCompletionOutput(index=1,
-                                     text=' member of',
+                                     text=' of',
                                      token_ids=[4292, 302],
                                      cumulative_logprob=-4.3041129764169455,
                                      logprobs=[{
@@ -181,7 +180,7 @@ def __init__(
                                      finish_reason=None,
                                      stop_reason=None),
                 MockCompletionOutput(index=0,
-                                     text=' consolidated',
+                                     text='ated',
                                      token_ids=[22968, 601],
                                      cumulative_logprob=-13.402491569519043,
                                      logprobs=[{
@@ -235,7 +234,7 @@ def __init__(
             ],
             outputs=[
                 MockCompletionOutput(index=1,
-                                     text=' member of the',
+                                     text=' the',
                                      token_ids=[4292, 302, 272],
                                      cumulative_logprob=-4.815703457221389,
                                      logprobs=[{
@@ -282,7 +281,7 @@ def __init__(
                                      finish_reason='length',
                                      stop_reason=None),
                 MockCompletionOutput(index=0,
-                                     text=' consolidated or',
+                                     text=' or',
                                      token_ids=[22968, 601, 442],
                                      cumulative_logprob=-20.4010648727417,
                                      logprobs=[{
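
Note: with output_kind=RequestOutputKind.DELTA, vLLM itself hands back only the tokens and text generated since the previous engine step, which is what makes the prev_len/cur_len and curr_length bookkeeping removed above unnecessary. The standalone sketch below illustrates the contract this patch relies on. It is illustrative only, not part of the patch: the model name, prompt, and sampling values are assumptions, and it presumes a vLLM version that ships RequestOutputKind (roughly >= 0.6).

# delta_output_sketch.py -- illustrative sketch, not part of the patch.
# Assumes vLLM >= 0.6, where SamplingParams accepts output_kind and
# RequestOutputKind lives in vllm.sampling_params.
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# With the default CUMULATIVE kind, completion_output.text/token_ids grow on
# every step, so the caller must slice off what it already emitted (the
# per-sequence cache this patch deletes). With DELTA, each step carries only
# the newly generated increment.
sampling_params = SamplingParams(max_tokens=16,
                                 output_kind=RequestOutputKind.DELTA)
engine.add_request("req-0", "Memory is defined as", sampling_params)

while engine.has_unfinished_requests():
    for request_output in engine.step():
        for completion_output in request_output.outputs:
            # Only the delta since the previous step: safe to append directly,
            # with no per-sequence length bookkeeping.
            print(completion_output.index, repr(completion_output.text),
                  list(completion_output.token_ids))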