From 79dc5fa7426fe02f1b3dfe1e2e4eb26f9667e546 Mon Sep 17 00:00:00 2001 From: Yunfeng Bai <83252681+yunfeng-scale@users.noreply.github.com> Date: Wed, 20 Mar 2024 21:47:02 -0700 Subject: [PATCH] Guided decoding (#476) * Guided decoding * endpoints * fix * update client * unit tests * fix test * coverage * coverage * fix * try to bump coverage * more tests! * lint --- clients/python/llmengine/__init__.py | 2 +- clients/python/llmengine/completion.py | 42 ++++- clients/python/llmengine/data_types.py | 8 + clients/python/pyproject.toml | 2 +- clients/python/setup.py | 2 +- docs/guides/completions.md | 53 +++++++ .../model_engine_server/common/dtos/llms.py | 24 +++ .../use_cases/llm_model_endpoint_use_cases.py | 32 ++++ .../inference/vllm/requirements.txt | 3 +- .../inference/vllm/vllm_server.py | 32 +++- model-engine/tests/unit/conftest.py | 132 ++++++++++++++++ .../tests/unit/domain/test_llm_use_cases.py | 149 ++++++++++++++++++ 12 files changed, 474 insertions(+), 7 deletions(-) diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 17dacfa9..dfae78cf 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0b27" +__version__ = "0.0.0b28" import os from typing import Sequence diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 43d0813c..0181b733 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -1,4 +1,4 @@ -from typing import AsyncIterable, Iterator, List, Optional, Union +from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union from llmengine.api_engine import APIEngine from llmengine.data_types import ( @@ -43,6 +43,10 @@ async def acreate( frequency_penalty: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + include_stop_str_in_output: Optional[bool] = False, + guided_json: Optional[Dict[str, Any]] = None, + guided_regex: Optional[str] = None, + guided_choice: Optional[List[str]] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: @@ -102,6 +106,18 @@ async def acreate( Float that controls the cumulative probability of the top tokens to consider. Range: (0.0, 1.0]. 1.0 means consider all tokens. + include_stop_str_in_output (Optional[bool]): + Whether to include the stop sequence in the output. Default to False. + + guided_json (Optional[Dict[str, Any]]): + If specified, the output will follow the JSON schema. For examples see https://json-schema.org/learn/miscellaneous-examples. + + guided_regex (Optional[str]): + If specified, the output will follow the regex pattern. + + guided_choice (Optional[List[str]]): + If specified, the output will be exactly one of the choices. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -198,6 +214,10 @@ async def _acreate_stream( frequency_penalty=frequency_penalty, top_k=top_k, top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, timeout=timeout, ) @@ -237,6 +257,10 @@ def create( frequency_penalty: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, + include_stop_str_in_output: Optional[bool] = False, + guided_json: Optional[Dict[str, Any]] = None, + guided_regex: Optional[str] = None, + guided_choice: Optional[List[str]] = None, timeout: int = COMPLETION_TIMEOUT, stream: bool = False, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: @@ -297,6 +321,18 @@ def create( Float that controls the cumulative probability of the top tokens to consider. Range: (0.0, 1.0]. 1.0 means consider all tokens. + include_stop_str_in_output (Optional[bool]): + Whether to include the stop sequence in the output. Default to False. + + guided_json (Optional[Dict[str, Any]]): + If specified, the output will follow the JSON schema. + + guided_regex (Optional[str]): + If specified, the output will follow the regex pattern. + + guided_choice (Optional[List[str]]): + If specified, the output will be exactly one of the choices. + timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -396,6 +432,10 @@ def _create_stream(**kwargs): frequency_penalty=frequency_penalty, top_k=top_k, top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, ).dict() response = cls.post_sync( resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py index 70abd6cb..a3ed3209 100644 --- a/clients/python/llmengine/data_types.py +++ b/clients/python/llmengine/data_types.py @@ -279,6 +279,10 @@ class CompletionSyncV1Request(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=-1) top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + include_stop_str_in_output: Optional[bool] = Field(default=False) + guided_json: Optional[Dict[str, Any]] = Field(default=None) + guided_regex: Optional[str] = Field(default=None) + guided_choice: Optional[List[str]] = Field(default=None) class TokenOutput(BaseModel): @@ -349,6 +353,10 @@ class CompletionStreamV1Request(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=-1) top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + include_stop_str_in_output: Optional[bool] = Field(default=False) + guided_json: Optional[Dict[str, Any]] = Field(default=None) + guided_regex: Optional[str] = Field(default=None) + guided_choice: Optional[List[str]] = Field(default=None) class CompletionStreamOutput(BaseModel): diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2563b814..8ddec08f 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta27" +version = "0.0.0.beta28" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] diff --git a/clients/python/setup.py b/clients/python/setup.py index 257516fc..a33e6a03 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -3,6 +3,6 @@ setup( name="scale-llm-engine", python_requires=">=3.7", - version="0.0.0.beta27", + version="0.0.0.beta28", packages=find_packages(), ) diff --git a/docs/guides/completions.md b/docs/guides/completions.md index 69dfe1bd..86bb9f0b 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -193,6 +193,59 @@ response = Completion.batch_create( print(response.json()) ``` +## Guided decoding + +Guided decoding is supported by vLLM and backed by [Outlines](https://github.com/outlines-dev/outlines). +It enforces certain token generation patterns by tinkering with the sampling logits. + +=== "Guided decoding with regex" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_regex="Sean.*", +) + +print(response.json()) +# {"request_id":"c19f0fae-317e-4f69-8e06-c04189299b9c","output":{"text":"Sean. I'm a 2","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} +``` + +=== "Guided decoding with choice" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_choice=["Sean", "Brian", "Tim"], +) + +print(response.json()) +# {"request_id":"641e2af3-a3e3-4493-98b9-d38115ba0d22","output":{"text":"Sean","num_prompt_tokens":6,"num_completion_tokens":4,"tokens":null}} +``` + +=== "Guided decoding with JSON schema" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_json={"properties":{"myString":{"type":"string"}},"required":["myString"]}, +) + +print(response.json()) +# {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} +``` + ## Which model should I use? See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions. diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 8d335d8d..9fb8ed1d 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -184,6 +184,18 @@ class CompletionSyncV1Request(BaseModel): """ Whether to include the stop strings in output text. """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. + """ class TokenOutput(BaseModel): @@ -248,6 +260,18 @@ class CompletionStreamV1Request(BaseModel): """ Whether to include the stop strings in output text. """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ class CompletionStreamOutput(BaseModel): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index b458343c..65973ced 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1365,6 +1365,26 @@ def validate_and_update_completion_params( "include_stop_str_in_output is only supported in vllm." ) + guided_count = 0 + if request.guided_choice is not None: + guided_count += 1 + if request.guided_json is not None: + guided_count += 1 + if request.guided_regex is not None: + guided_count += 1 + + if guided_count > 1: + raise ObjectHasInvalidValueException( + "Only one of guided_json, guided_choice, guided_regex can be enabled." + ) + + if ( + request.guided_choice is not None + or request.guided_regex is not None + or request.guided_json is not None + ) and not inference_framework == LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.") + return request @@ -1656,6 +1676,12 @@ async def execute( vllm_args["logprobs"] = 1 if request.include_stop_str_in_output is not None: vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output + if request.guided_choice is not None: + vllm_args["guided_choice"] = request.guided_choice + if request.guided_regex is not None: + vllm_args["guided_regex"] = request.guided_regex + if request.guided_json is not None: + vllm_args["guided_json"] = request.guided_json inference_request = SyncEndpointPredictV1Request( args=vllm_args, @@ -1918,6 +1944,12 @@ async def execute( args["logprobs"] = 1 if request.include_stop_str_in_output is not None: args["include_stop_str_in_output"] = request.include_stop_str_in_output + if request.guided_choice is not None: + args["guided_choice"] = request.guided_choice + if request.guided_regex is not None: + args["guided_regex"] = request.guided_regex + if request.guided_json is not None: + args["guided_json"] = request.guided_json args["stream"] = True elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: args = { diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt index 78e033bb..3c1cf851 100644 --- a/model-engine/model_engine_server/inference/vllm/requirements.txt +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -1,3 +1,2 @@ -ray>=2.9 -vllm==0.3.2 +vllm==0.3.3 pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py index 5bd3f6e4..c4dd0eed 100644 --- a/model-engine/model_engine_server/inference/vllm/vllm_server.py +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -7,10 +7,12 @@ from typing import AsyncGenerator import uvicorn -from fastapi import BackgroundTasks, FastAPI, Request +from fastapi import BackgroundTasks, FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest +from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -38,7 +40,35 @@ async def generate(request: Request) -> Response: request_dict = await request.json() prompt = request_dict.pop("prompt") stream = request_dict.pop("stream", False) + guided_json = request_dict.pop("guided_json", None) + guided_regex = request_dict.pop("guided_regex", None) + guided_choice = request_dict.pop("guided_choice", None) sampling_params = SamplingParams(**request_dict) + + # Dummy request to get guided decode logit processor + try: + partial_openai_request = OpenAICompletionRequest.model_validate( + { + "model": "", + "prompt": "", + "guided_json": guided_json, + "guided_regex": guided_regex, + "guided_choice": guided_choice, + } + ) + except Exception: + raise HTTPException( + status_code=400, detail="Bad request: failed to parse guided decoding parameters." + ) + + guided_decode_logit_processor = await get_guided_decoding_logits_processor( + partial_openai_request, engine.get_tokenizer() + ) + if guided_decode_logit_processor is not None: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(guided_decode_logit_processor) + request_id = random_uuid() results_generator = engine.generate(prompt, sampling_params, request_id) diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 61473b37..4b57afa1 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -3725,6 +3725,138 @@ def llm_model_endpoint_sync( return model_endpoint, model_endpoint_json +@pytest.fixture +def llm_model_endpoint_stream( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.STREAMING, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + __root__=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "streaming", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": "1", + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + @pytest.fixture def llm_model_endpoint_sync_tgi( test_api_key: str, model_bundle_1: ModelBundle diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 10b37c7d..b2496ff9 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -51,6 +51,7 @@ UpdateLLMModelEndpointV1UseCase, _include_safetensors_bin_or_pt, infer_hardware_from_model_name, + validate_and_update_completion_params, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -602,6 +603,8 @@ async def test_completion_sync_use_case_success( llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): + completion_sync_request.include_stop_str_in_output = True + completion_sync_request.guided_json = {} fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( @@ -987,6 +990,42 @@ async def test_completion_sync_use_case_not_sync_endpoint_raises( ) +@pytest.mark.asyncio +async def test_validate_and_update_completion_params(): + completion_sync_request = CompletionSyncV1Request( + prompt="What is machine learning?", + max_new_tokens=10, + temperature=0.5, + return_token_log_probs=True, + ) + + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + + completion_sync_request.include_stop_str_in_output = True + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + completion_sync_request.include_stop_str_in_output = None + + completion_sync_request.guided_regex = "" + completion_sync_request.guided_json = {} + completion_sync_request.guided_choice = [""] + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + + completion_sync_request.guided_regex = None + completion_sync_request.guided_choice = None + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + + @pytest.mark.asyncio @mock.patch( "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", @@ -1079,6 +1118,116 @@ async def test_completion_stream_use_case_success( i += 1 +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) +async def test_completion_stream_vllm_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_stream: Tuple[ModelEndpoint, Any], + completion_stream_request: CompletionStreamV1Request, +): + completion_stream_request.guided_json = {} + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_stream[0]) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": "I", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 1, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " am", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 2, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " a", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 3, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " new", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 4, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": "bie", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 5, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": ".", + "finished": True, + "count_prompt_tokens": 7, + "count_output_tokens": 6, + } + }, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_stream[0].record.name, + request=completion_stream_request, + ) + output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] + i = 0 + async for message in response_1: + assert message.dict()["output"]["text"] == output_texts[i] + if i == 5: + assert message.dict()["output"]["num_prompt_tokens"] == 7 + assert message.dict()["output"]["num_completion_tokens"] == 6 + i += 1 + + @pytest.mark.asyncio @mock.patch( "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens",