diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py index 8b5674ed0..662da2faf 100644 --- a/vllm/entrypoints/grpc/grpc_server.py +++ b/vllm/entrypoints/grpc/grpc_server.py @@ -15,19 +15,12 @@ SamplingParams) from vllm.config import ModelConfig from vllm.entrypoints.grpc.pb import generation_pb2_grpc -from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest, - BatchedGenerationResponse, - BatchedTokenizeRequest, - BatchedTokenizeResponse, - DecodingMethod, - GenerationResponse, - ModelInfoRequest, - ModelInfoResponse, - Parameters, - ResponseOptions, - SingleGenerationRequest, - StopReason, TokenInfo, - TokenizeResponse) +from vllm.entrypoints.grpc.pb.generation_pb2 import ( + BatchedGenerationRequest, BatchedGenerationResponse, + BatchedTokenizeRequest, BatchedTokenizeResponse, DecodingMethod, + GenerationResponse, ModelInfoRequest, ModelInfoResponse, Parameters, + ResponseOptions, SingleGenerationRequest, StopReason, TokenInfo, + TokenizeResponse) from vllm.entrypoints.grpc.validation import validate_input, validate_params from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.logger import init_logger @@ -314,8 +307,7 @@ async def _validate_and_convert_params( if not greedy and 0.0 < sampling.typical_p < 1.0: logits_processors.append( - TypicalLogitsWarperWrapper(mass=sampling.typical_p) - ) + TypicalLogitsWarperWrapper(mass=sampling.typical_p)) if params.decoding.length_penalty is not None: length_penalty = ( params.decoding.length_penalty.start_index, @@ -323,7 +315,7 @@ async def _validate_and_convert_params( ) logits_processors.append( LengthPenaltyWarper(length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.eos_token_id)) time_limit_millis = stopping.time_limit_millis deadline = time.time(