Guided decoding (#476)
* Guided decoding

* endpoints

* fix

* update client

* unit tests

* fix test

* coverage

* coverage

* fix

* try to bump coverage

* more tests!

* lint
yunfeng-scale authored Mar 21, 2024
1 parent 44fe4e8 commit 79dc5fa
Showing 12 changed files with 474 additions and 7 deletions.
2 changes: 1 addition & 1 deletion clients/python/llmengine/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.0b27"
__version__ = "0.0.0b28"

import os
from typing import Sequence
42 changes: 41 additions & 1 deletion clients/python/llmengine/completion.py
@@ -1,4 +1,4 @@
from typing import AsyncIterable, Iterator, List, Optional, Union
from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union

from llmengine.api_engine import APIEngine
from llmengine.data_types import (
@@ -43,6 +43,10 @@ async def acreate(
frequency_penalty: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
include_stop_str_in_output: Optional[bool] = False,
guided_json: Optional[Dict[str, Any]] = None,
guided_regex: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
@@ -102,6 +106,18 @@ async def acreate(
Float that controls the cumulative probability of the top tokens to consider.
Range: (0.0, 1.0]. 1.0 means consider all tokens.
include_stop_str_in_output (Optional[bool]):
Whether to include the stop sequence in the output. Default to False.
guided_json (Optional[Dict[str, Any]]):
If specified, the output will follow the JSON schema. For examples see https://json-schema.org/learn/miscellaneous-examples.
guided_regex (Optional[str]):
If specified, the output will follow the regex pattern.
guided_choice (Optional[List[str]]):
If specified, the output will be exactly one of the choices.
timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
@@ -198,6 +214,10 @@ async def _acreate_stream(
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
include_stop_str_in_output=include_stop_str_in_output,
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
timeout=timeout,
)

@@ -237,6 +257,10 @@ def create(
frequency_penalty: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
include_stop_str_in_output: Optional[bool] = False,
guided_json: Optional[Dict[str, Any]] = None,
guided_regex: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
timeout: int = COMPLETION_TIMEOUT,
stream: bool = False,
) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
@@ -297,6 +321,18 @@ def create(
Float that controls the cumulative probability of the top tokens to consider.
Range: (0.0, 1.0]. 1.0 means consider all tokens.
include_stop_str_in_output (Optional[bool]):
Whether to include the stop sequence in the output. Default to False.
guided_json (Optional[Dict[str, Any]]):
If specified, the output will follow the JSON schema.
guided_regex (Optional[str]):
If specified, the output will follow the regex pattern.
guided_choice (Optional[List[str]]):
If specified, the output will be exactly one of the choices.
timeout (int):
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
@@ -396,6 +432,10 @@ def _create_stream(**kwargs):
frequency_penalty=frequency_penalty,
top_k=top_k,
top_p=top_p,
include_stop_str_in_output=include_stop_str_in_output,
guided_json=guided_json,
guided_regex=guided_regex,
guided_choice=guided_choice,
).dict()
response = cls.post_sync(
resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",
8 changes: 8 additions & 0 deletions clients/python/llmengine/data_types.py
@@ -279,6 +279,10 @@ class CompletionSyncV1Request(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=-1)
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
include_stop_str_in_output: Optional[bool] = Field(default=False)
guided_json: Optional[Dict[str, Any]] = Field(default=None)
guided_regex: Optional[str] = Field(default=None)
guided_choice: Optional[List[str]] = Field(default=None)


class TokenOutput(BaseModel):
@@ -349,6 +353,10 @@ class CompletionStreamV1Request(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_k: Optional[int] = Field(default=None, ge=-1)
top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
include_stop_str_in_output: Optional[bool] = Field(default=False)
guided_json: Optional[Dict[str, Any]] = Field(default=None)
guided_regex: Optional[str] = Field(default=None)
guided_choice: Optional[List[str]] = Field(default=None)


class CompletionStreamOutput(BaseModel):
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scale-llm-engine"
version = "0.0.0.beta27"
version = "0.0.0.beta28"
description = "Scale LLM Engine Python client"
license = "Apache-2.0"
authors = ["Phil Chen <[email protected]>"]
2 changes: 1 addition & 1 deletion clients/python/setup.py
@@ -3,6 +3,6 @@
setup(
name="scale-llm-engine",
python_requires=">=3.7",
version="0.0.0.beta27",
version="0.0.0.beta28",
packages=find_packages(),
)
53 changes: 53 additions & 0 deletions docs/guides/completions.md
@@ -193,6 +193,59 @@ response = Completion.batch_create(
print(response.json())
```

## Guided decoding

Guided decoding is supported by vLLM and backed by [Outlines](https://github.com/outlines-dev/outlines).
It enforces certain token generation patterns by tinkering with the sampling logits.

=== "Guided decoding with regex"
```python
from llmengine import Completion

response = Completion.create(
model="llama-2-7b",
prompt="Hello, my name is",
max_new_tokens=10,
temperature=0.2,
guided_regex="Sean.*",
)

print(response.json())
# {"request_id":"c19f0fae-317e-4f69-8e06-c04189299b9c","output":{"text":"Sean. I'm a 2","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
```

=== "Guided decoding with choice"
```python
from llmengine import Completion

response = Completion.create(
model="llama-2-7b",
prompt="Hello, my name is",
max_new_tokens=10,
temperature=0.2,
guided_choice=["Sean", "Brian", "Tim"],
)

print(response.json())
# {"request_id":"641e2af3-a3e3-4493-98b9-d38115ba0d22","output":{"text":"Sean","num_prompt_tokens":6,"num_completion_tokens":4,"tokens":null}}
```

=== "Guided decoding with JSON schema"
```python
from llmengine import Completion

response = Completion.create(
model="llama-2-7b",
prompt="Hello, my name is",
max_new_tokens=10,
temperature=0.2,
guided_json={"properties":{"myString":{"type":"string"}},"required":["myString"]},
)

print(response.json())
# {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
```
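
As an illustrative aside (not part of this diff): the validation added in `validate_and_update_completion_params` below allows at most one of `guided_json`, `guided_regex`, or `guided_choice` per request, and only for vLLM-backed endpoints. Also, a JSON schema need not be written by hand — here is a minimal sketch that derives one from a Pydantic model; the `Person` class is hypothetical:

```python
from pydantic import BaseModel

from llmengine import Completion


class Person(BaseModel):
    myString: str


response = Completion.create(
    model="llama-2-7b",
    prompt="Hello, my name is",
    max_new_tokens=10,
    temperature=0.2,
    # Person.schema() yields a dict equivalent to the hand-written schema above.
    guided_json=Person.schema(),
)
print(response.json())
```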

## Which model should I use?

See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions.
24 changes: 24 additions & 0 deletions model-engine/model_engine_server/common/dtos/llms.py
@@ -184,6 +184,18 @@ class CompletionSyncV1Request(BaseModel):
"""
Whether to include the stop strings in output text.
"""
guided_json: Optional[Dict[str, Any]] = None
"""
JSON schema for guided decoding.
"""
guided_regex: Optional[str] = None
"""
Regex for guided decoding.
"""
guided_choice: Optional[List[str]] = None
"""
Choices for guided decoding.
"""


class TokenOutput(BaseModel):
@@ -248,6 +260,18 @@ class CompletionStreamV1Request(BaseModel):
"""
Whether to include the stop strings in output text.
"""
guided_json: Optional[Dict[str, Any]] = None
"""
JSON schema for guided decoding. Only supported in vllm.
"""
guided_regex: Optional[str] = None
"""
Regex for guided decoding. Only supported in vllm.
"""
guided_choice: Optional[List[str]] = None
"""
Choices for guided decoding. Only supported in vllm.
"""


class CompletionStreamOutput(BaseModel):
@@ -1365,6 +1365,26 @@ def validate_and_update_completion_params(
"include_stop_str_in_output is only supported in vllm."
)

guided_count = 0
if request.guided_choice is not None:
guided_count += 1
if request.guided_json is not None:
guided_count += 1
if request.guided_regex is not None:
guided_count += 1

if guided_count > 1:
raise ObjectHasInvalidValueException(
"Only one of guided_json, guided_choice, guided_regex can be enabled."
)

if (
request.guided_choice is not None
or request.guided_regex is not None
or request.guided_json is not None
) and not inference_framework == LLMInferenceFramework.VLLM:
raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.")

return request


@@ -1656,6 +1676,12 @@ async def execute(
vllm_args["logprobs"] = 1
if request.include_stop_str_in_output is not None:
vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output
if request.guided_choice is not None:
vllm_args["guided_choice"] = request.guided_choice
if request.guided_regex is not None:
vllm_args["guided_regex"] = request.guided_regex
if request.guided_json is not None:
vllm_args["guided_json"] = request.guided_json

inference_request = SyncEndpointPredictV1Request(
args=vllm_args,
@@ -1918,6 +1944,12 @@ async def execute(
args["logprobs"] = 1
if request.include_stop_str_in_output is not None:
args["include_stop_str_in_output"] = request.include_stop_str_in_output
if request.guided_choice is not None:
args["guided_choice"] = request.guided_choice
if request.guided_regex is not None:
args["guided_regex"] = request.guided_regex
if request.guided_json is not None:
args["guided_json"] = request.guided_json
args["stream"] = True
elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
args = {
@@ -1,3 +1,2 @@
ray>=2.9
vllm==0.3.2
vllm==0.3.3
pydantic>=2.0
32 changes: 31 additions & 1 deletion model-engine/model_engine_server/inference/vllm/vllm_server.py
@@ -7,10 +7,12 @@
from typing import AsyncGenerator

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

@@ -38,7 +40,35 @@ async def generate(request: Request) -> Response:
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
guided_json = request_dict.pop("guided_json", None)
guided_regex = request_dict.pop("guided_regex", None)
guided_choice = request_dict.pop("guided_choice", None)
sampling_params = SamplingParams(**request_dict)

# Dummy request to get guided decode logit processor
try:
partial_openai_request = OpenAICompletionRequest.model_validate(
{
"model": "",
"prompt": "",
"guided_json": guided_json,
"guided_regex": guided_regex,
"guided_choice": guided_choice,
}
)
except Exception:
raise HTTPException(
status_code=400, detail="Bad request: failed to parse guided decoding parameters."
)

guided_decode_logit_processor = await get_guided_decoding_logits_processor(
partial_openai_request, engine.get_tokenizer()
)
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(guided_decode_logit_processor)

request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)

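For orientation (not part of this diff), a sketch of a raw request against this server once the new fields are accepted. The host, port, and route are assumptions — the route decorator sits outside the hunk shown above; the `guided_*` keys are popped off before the remaining body is passed to `vllm.SamplingParams`:

```python
import requests

payload = {
    "prompt": "Hello, my name is",
    "max_tokens": 10,          # forwarded to vllm.SamplingParams
    "temperature": 0.2,        # forwarded to vllm.SamplingParams
    "guided_regex": "Sean.*",  # popped and converted into an Outlines logits processor
    "stream": False,
}
# URL is an assumption for illustration; the actual route and port are defined elsewhere.
response = requests.post("http://localhost:5005/predict", json=payload)
print(response.json())
```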