Batch inference client / doc (#424)
* batch inference client / doc

* fix

* fixes
yunfeng-scale authored Jan 26, 2024
1 parent d4be9b9 commit 69f8bcb
Showing 6 changed files with 242 additions and 1 deletion.
8 changes: 8 additions & 0 deletions clients/python/llmengine/__init__.py
@@ -25,6 +25,10 @@
    CompletionStreamOutput,
    CompletionStreamResponse,
    CompletionSyncResponse,
    CreateBatchCompletionsModelConfig,
    CreateBatchCompletionsRequest,
    CreateBatchCompletionsRequestContent,
    CreateBatchCompletionsResponse,
    CreateFineTuneRequest,
    CreateFineTuneResponse,
    DeleteFileResponse,
@@ -51,6 +55,10 @@
"CompletionStreamOutput",
"CompletionStreamResponse",
"CompletionSyncResponse",
"CreateBatchCompletionsModelConfig",
"CreateBatchCompletionsRequest",
"CreateBatchCompletionsRequestContent",
"CreateBatchCompletionsResponse",
"CreateFineTuneRequest",
"CreateFineTuneResponse",
"DeleteFileResponse",
98 changes: 98 additions & 0 deletions clients/python/llmengine/completion.py
@@ -6,9 +6,14 @@
    CompletionStreamV1Request,
    CompletionSyncResponse,
    CompletionSyncV1Request,
    CreateBatchCompletionsModelConfig,
    CreateBatchCompletionsRequest,
    CreateBatchCompletionsRequestContent,
    CreateBatchCompletionsResponse,
)

COMPLETION_TIMEOUT = 300
HTTP_TIMEOUT = 60


class Completion(APIEngine):
@@ -397,3 +402,96 @@ def _create_stream(**kwargs):
            timeout=timeout,
        )
        return CompletionSyncResponse.parse_obj(response)

    @classmethod
    def batch_create(
        cls,
        output_data_path: str,
        model_config: CreateBatchCompletionsModelConfig,
        content: Optional[CreateBatchCompletionsRequestContent] = None,
        input_data_path: Optional[str] = None,
        data_parallelism: int = 1,
        max_runtime_sec: int = 24 * 3600,
    ) -> CreateBatchCompletionsResponse:
        """
        Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint.

        Prompts can be passed in from an input file, or as part of the request.

        Args:
            output_data_path (str):
                The path to the output file. The output file will be a JSON file containing the completions.
            model_config (CreateBatchCompletionsModelConfig):
                The model configuration to use for the batch completion.
            content (Optional[CreateBatchCompletionsRequestContent]):
                The content to use for the batch completion. Either `content` or `input_data_path` must be provided.
            input_data_path (Optional[str]):
                The path to the input file. The input file should be a JSON file with data of type `CreateBatchCompletionsRequestContent`. Either `content` or `input_data_path` must be provided.
            data_parallelism (int):
                The number of parallel jobs to run. Data will be evenly distributed across the jobs. Defaults to 1.
            max_runtime_sec (int):
                The maximum runtime of the batch completion in seconds. Defaults to 24 hours.

        Returns:
            response (CreateBatchCompletionsResponse): The response containing the job id.

        === "Batch completions with prompts in the request"
            ```python
            from llmengine import Completion
            from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent

            response = Completion.batch_create(
                output_data_path="s3://my-path",
                model_config=CreateBatchCompletionsModelConfig(
                    model="llama-2-7b",
                    checkpoint_path="s3://checkpoint-path",
                    labels={"team":"my-team", "product":"my-product"}
                ),
                content=CreateBatchCompletionsRequestContent(
                    prompts=["What is deep learning", "What is a neural network"],
                    max_new_tokens=10,
                    temperature=0.0
                )
            )
            print(response.json())
            ```

        === "Batch completions with prompts in a file and with 2 parallel jobs"
            ```python
            from llmengine import Completion
            from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent

            # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path"

            response = Completion.batch_create(
                input_data_path="s3://my-input-path",
                output_data_path="s3://my-output-path",
                model_config=CreateBatchCompletionsModelConfig(
                    model="llama-2-7b",
                    checkpoint_path="s3://checkpoint-path",
                    labels={"team":"my-team", "product":"my-product"}
                ),
                data_parallelism=2
            )
            print(response.json())
            ```
        """
        data = CreateBatchCompletionsRequest(
            model_config=model_config,
            content=content,
            input_data_path=input_data_path,
            output_data_path=output_data_path,
            data_parallelism=data_parallelism,
            max_runtime_sec=max_runtime_sec,
        ).dict()
        response = cls.post_sync(
            resource_name="v1/llm/batch-completions",
            data=data,
            timeout=HTTP_TIMEOUT,
        )
        return CreateBatchCompletionsResponse.parse_obj(response)
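
For reference, here is a minimal sketch (not part of this commit) of wiring up the `input_data_path` variant end to end: the input file is simply the JSON serialization of a `CreateBatchCompletionsRequestContent` object uploaded to the path you pass in. The bucket and key names are placeholders, and uploading with boto3 is just one option.

```python
import boto3

from llmengine import Completion
from llmengine.data_types import (
    CreateBatchCompletionsModelConfig,
    CreateBatchCompletionsRequestContent,
)

# The batch job reads prompts and sampling parameters from this object.
content = CreateBatchCompletionsRequestContent(
    prompts=["What is deep learning", "What is a neural network"],
    max_new_tokens=10,
    temperature=0.0,
)

# Upload the serialized content to the S3 location used as input_data_path.
# (Placeholder bucket/key; any storage mechanism that produces the same file works.)
s3 = boto3.client("s3")
s3.put_object(
    Bucket="my-bucket",
    Key="my-input-path",
    Body=content.json().encode("utf-8"),
)

response = Completion.batch_create(
    input_data_path="s3://my-bucket/my-input-path",
    output_data_path="s3://my-bucket/my-output-path",
    model_config=CreateBatchCompletionsModelConfig(
        model="llama-2-7b",
        checkpoint_path="s3://checkpoint-path",
        labels={"team": "my-team", "product": "my-product"},
    ),
    data_parallelism=2,
)
print(response.json())  # {"job_id": "..."}
```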
96 changes: 96 additions & 0 deletions clients/python/llmengine/data_types.py
@@ -591,3 +591,99 @@ class GetFileContentResponse(BaseModel):

    content: str = Field(..., description="File content.")
    """File content."""


class CreateBatchCompletionsRequestContent(BaseModel):
    prompts: List[str]
    max_new_tokens: int
    temperature: float = Field(ge=0.0, le=1.0)
    """
    Temperature of the sampling. Setting this to 0 is equivalent to greedy sampling.
    """
    stop_sequences: Optional[List[str]] = None
    """
    List of sequences to stop the completion at.
    """
    return_token_log_probs: Optional[bool] = False
    """
    Whether to return the log probabilities of the tokens.
    """
    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    """
    Only supported in vllm and lightllm.
    Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty.
    """
    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    """
    Only supported in vllm and lightllm.
    Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty.
    """
    top_k: Optional[int] = Field(default=None, ge=-1)
    """
    Controls the number of top tokens to consider. -1 means consider all tokens.
    """
    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
    """
    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
    """


class CreateBatchCompletionsModelConfig(BaseModel):
    model: str
    checkpoint_path: Optional[str] = None
    """
    Path to the checkpoint to load the model from.
    """
    labels: Dict[str, str]
    """
    Labels to attach to the batch inference job.
    """
    num_shards: Optional[int] = 1
    """
    Suggested number of shards to distribute the model across. When not specified, the number of shards is inferred from the model config.
    The system may use a different number than the given value.
    """
    quantize: Optional[Quantization] = None
    """
    Whether to quantize the model.
    """
    seed: Optional[int] = None
    """
    Random seed for the model.
    """


class CreateBatchCompletionsRequest(BaseModel):
    """
    Request object for batch completions.
    """

    input_data_path: Optional[str]
    output_data_path: str
    """
    Path to the output file. The output file will be a JSON file of type List[CompletionOutput].
    """
    content: Optional[CreateBatchCompletionsRequestContent] = None
    """
    Either `input_data_path` or `content` needs to be provided.
    When `input_data_path` is provided, the input file should be a JSON file of type CreateBatchCompletionsRequestContent.
    """
    model_config: CreateBatchCompletionsModelConfig
    """
    Model configuration for the batch inference. Hardware configurations are inferred.
    """
    data_parallelism: Optional[int] = Field(default=1, ge=1, le=64)
    """
    Number of replicas to run the batch inference. More replicas are slower to schedule but faster at inference.
    """
    max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600)
    """
    Maximum runtime of the batch inference in seconds. Defaults to one day.
    """


class CreateBatchCompletionsResponse(BaseModel):
    job_id: str
    """
    The ID of the batch completions job.
    """
38 changes: 38 additions & 0 deletions docs/api/data_types.md
@@ -110,3 +110,41 @@
    options:
        members:
            - deleted

::: llmengine.CreateBatchCompletionsRequestContent
    options:
        members:
            - prompts
            - max_new_tokens
            - temperature
            - stop_sequences
            - return_token_log_probs
            - presence_penalty
            - frequency_penalty
            - top_k
            - top_p

::: llmengine.CreateBatchCompletionsModelConfig
    options:
        members:
            - model
            - checkpoint_path
            - labels
            - num_shards
            - quantize
            - seed

::: llmengine.CreateBatchCompletionsRequest
    options:
        members:
            - input_data_path
            - output_data_path
            - content
            - model_config
            - data_parallelism
            - max_runtime_sec

::: llmengine.CreateBatchCompletionsResponse
    options:
        members:
            - job_id
1 change: 1 addition & 0 deletions docs/api/python_client.md
@@ -5,6 +5,7 @@
        members:
            - create
            - acreate
            - batch_create

::: llmengine.FineTune
    options:
2 changes: 1 addition & 1 deletion docs/contributing.md
@@ -21,7 +21,7 @@ pip install -r requirements-docs.txt
Our Python client API reference is autogenerated from our client. You can install the client in editable mode with

```
-pip install -r clients/python
+pip install -e clients/python
```

### Step 4: Run Locally
