diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py
index cc8f28be..ac05d1bb 100644
--- a/clients/python/llmengine/__init__.py
+++ b/clients/python/llmengine/__init__.py
@@ -25,6 +25,10 @@
     CompletionStreamOutput,
     CompletionStreamResponse,
     CompletionSyncResponse,
+    CreateBatchCompletionsModelConfig,
+    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsRequestContent,
+    CreateBatchCompletionsResponse,
     CreateFineTuneRequest,
     CreateFineTuneResponse,
     DeleteFileResponse,
@@ -51,6 +55,10 @@
     "CompletionStreamOutput",
     "CompletionStreamResponse",
     "CompletionSyncResponse",
+    "CreateBatchCompletionsModelConfig",
+    "CreateBatchCompletionsRequest",
+    "CreateBatchCompletionsRequestContent",
+    "CreateBatchCompletionsResponse",
     "CreateFineTuneRequest",
     "CreateFineTuneResponse",
     "DeleteFileResponse",
diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py
index 507754d8..3a02f04e 100644
--- a/clients/python/llmengine/completion.py
+++ b/clients/python/llmengine/completion.py
@@ -6,9 +6,14 @@
     CompletionStreamV1Request,
     CompletionSyncResponse,
     CompletionSyncV1Request,
+    CreateBatchCompletionsModelConfig,
+    CreateBatchCompletionsRequest,
+    CreateBatchCompletionsRequestContent,
+    CreateBatchCompletionsResponse,
 )
 
 COMPLETION_TIMEOUT = 300
+HTTP_TIMEOUT = 60
 
 
 class Completion(APIEngine):
@@ -397,3 +402,96 @@ def _create_stream(**kwargs):
             timeout=timeout,
         )
         return CompletionSyncResponse.parse_obj(response)
+
+    @classmethod
+    def batch_create(
+        cls,
+        output_data_path: str,
+        model_config: CreateBatchCompletionsModelConfig,
+        content: Optional[CreateBatchCompletionsRequestContent] = None,
+        input_data_path: Optional[str] = None,
+        data_parallelism: int = 1,
+        max_runtime_sec: int = 24 * 3600,
+    ) -> CreateBatchCompletionsResponse:
+        """
+        Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint.
+
+        Prompts can be passed in from an input file, or as part of the request.
+
+        Args:
+            output_data_path (str):
+                The path to the output file. The output file will be a JSON file containing the completions.
+
+            model_config (CreateBatchCompletionsModelConfig):
+                The model configuration to use for the batch completion.
+
+            content (Optional[CreateBatchCompletionsRequestContent]):
+                The content to use for the batch completion. Either `content` or `input_data_path` must be provided.
+
+            input_data_path (Optional[str]):
+                The path to the input file. The input file should be a JSON file with data of type `CreateBatchCompletionsRequestContent`. Either `content` or `input_data_path` must be provided.
+
+            data_parallelism (int):
+                The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1.
+
+            max_runtime_sec (int):
+                The maximum runtime of the batch completion in seconds. Defaults to 24 hours.
+
+        Returns:
+            response (CreateBatchCompletionsResponse): The response containing the job id.
+
+        === "Batch completions with prompts in the request"
+            ```python
+            from llmengine import Completion
+            from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent
+
+            response = Completion.batch_create(
+                output_data_path="s3://my-path",
+                model_config=CreateBatchCompletionsModelConfig(
+                    model="llama-2-7b",
+                    checkpoint_path="s3://checkpoint-path",
+                    labels={"team":"my-team", "product":"my-product"}
+                ),
+                content=CreateBatchCompletionsRequestContent(
+                    prompts=["What is deep learning", "What is a neural network"],
+                    max_new_tokens=10,
+                    temperature=0.0
+                )
+            )
+            print(response.json())
+            ```
+
+        === "Batch completions with prompts in a file and with 2 parallel jobs"
+            ```python
+            from llmengine import Completion
+            from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent
+
+            # Store CreateBatchCompletionsRequestContent data in the input file "s3://my-input-path"
+
+            response = Completion.batch_create(
+                input_data_path="s3://my-input-path",
+                output_data_path="s3://my-output-path",
+                model_config=CreateBatchCompletionsModelConfig(
+                    model="llama-2-7b",
+                    checkpoint_path="s3://checkpoint-path",
+                    labels={"team":"my-team", "product":"my-product"}
+                ),
+                data_parallelism=2
+            )
+            print(response.json())
+            ```
+        """
+        data = CreateBatchCompletionsRequest(
+            model_config=model_config,
+            content=content,
+            input_data_path=input_data_path,
+            output_data_path=output_data_path,
+            data_parallelism=data_parallelism,
+            max_runtime_sec=max_runtime_sec,
+        ).dict()
+        response = cls.post_sync(
+            resource_name="v1/llm/batch-completions",
+            data=data,
+            timeout=HTTP_TIMEOUT,
+        )
+        return CreateBatchCompletionsResponse.parse_obj(response)
diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py
index 2a37e912..64ec45a0 100644
--- a/clients/python/llmengine/data_types.py
+++ b/clients/python/llmengine/data_types.py
@@ -591,3 +591,99 @@ class GetFileContentResponse(BaseModel):
 
     content: str = Field(..., description="File content.")
     """File content."""
+
+
+class CreateBatchCompletionsRequestContent(BaseModel):
+    prompts: List[str]
+    max_new_tokens: int
+    temperature: float = Field(ge=0.0, le=1.0)
+    """
+    Temperature of the sampling. Setting to 0 is equivalent to greedy sampling.
+    """
+    stop_sequences: Optional[List[str]] = None
+    """
+    List of sequences to stop the completion at.
+    """
+    return_token_log_probs: Optional[bool] = False
+    """
+    Whether to return the log probabilities of the tokens.
+    """
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm.
+    Penalizes new tokens based on whether they appear in the text so far. 0.0 means no penalty.
+    """
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm.
+    Penalizes new tokens based on their existing frequency in the text so far. 0.0 means no penalty.
+    """
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    """
+    Controls the number of top tokens to consider. -1 means consider all tokens.
+    """
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    """
+    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
+    """
+
+
+class CreateBatchCompletionsModelConfig(BaseModel):
+    model: str
+    checkpoint_path: Optional[str] = None
+    """
+    Path to the checkpoint to load the model from.
+    """
+    labels: Dict[str, str]
+    """
+    Labels to attach to the batch inference job.
+    """
+    num_shards: Optional[int] = 1
+    """
+    Suggested number of shards to distribute the model. When not specified, the number of shards is inferred from the model config.
+    The system may decide to use a different number than the given value.
+    """
+    quantize: Optional[Quantization] = None
+    """
+    Whether to quantize the model.
+    """
+    seed: Optional[int] = None
+    """
+    Random seed for the model.
+    """
+
+
+class CreateBatchCompletionsRequest(BaseModel):
+    """
+    Request object for batch completions.
+    """
+
+    input_data_path: Optional[str]
+    output_data_path: str
+    """
+    Path to the output file. The output file will be a JSON file of type List[CompletionOutput].
+    """
+    content: Optional[CreateBatchCompletionsRequestContent] = None
+    """
+    Either `input_data_path` or `content` needs to be provided.
+    When `input_data_path` is provided, the input file should be a JSON file of type CreateBatchCompletionsRequestContent.
+    """
+    model_config: CreateBatchCompletionsModelConfig
+    """
+    Model configuration for the batch inference. Hardware configurations are inferred.
+    """
+    data_parallelism: Optional[int] = Field(default=1, ge=1, le=64)
+    """
+    Number of replicas to run the batch inference. More replicas are slower to schedule but faster at inference.
+    """
+    max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600)
+    """
+    Maximum runtime of the batch inference in seconds. Defaults to one day.
+    """
+
+
+class CreateBatchCompletionsResponse(BaseModel):
+    job_id: str
+    """
+    The ID of the batch completions job.
+    """
diff --git a/docs/api/data_types.md b/docs/api/data_types.md
index 44dd3d8f..206c93e6 100644
--- a/docs/api/data_types.md
+++ b/docs/api/data_types.md
@@ -110,3 +110,41 @@
     options:
         members:
             - deleted
+
+::: llmengine.CreateBatchCompletionsRequestContent
+    options:
+        members:
+            - prompts
+            - max_new_tokens
+            - temperature
+            - stop_sequences
+            - return_token_log_probs
+            - presence_penalty
+            - frequency_penalty
+            - top_k
+            - top_p
+
+::: llmengine.CreateBatchCompletionsModelConfig
+    options:
+        members:
+            - model
+            - checkpoint_path
+            - labels
+            - num_shards
+            - quantize
+            - seed
+
+::: llmengine.CreateBatchCompletionsRequest
+    options:
+        members:
+            - input_data_path
+            - output_data_path
+            - content
+            - model_config
+            - data_parallelism
+            - max_runtime_sec
+
+::: llmengine.CreateBatchCompletionsResponse
+    options:
+        members:
+            - job_id
diff --git a/docs/api/python_client.md b/docs/api/python_client.md
index d77d28bc..c9e22723 100644
--- a/docs/api/python_client.md
+++ b/docs/api/python_client.md
@@ -5,6 +5,7 @@
         members:
             - create
             - acreate
+            - batch_create
 
 ::: llmengine.FineTune
     options:
diff --git a/docs/contributing.md b/docs/contributing.md
index 37a6793a..8423c202 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -21,7 +21,7 @@ pip install -r requirements-docs.txt
 Our Python client API reference is autogenerated from our client. You can install the client in editable mode with
 
 ```
-pip install -r clients/python
+pip install -e clients/python
 ```
 
 ### Step 4: Run Locally
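
The second docstring tab above assumes that `CreateBatchCompletionsRequestContent` data has already been written to the file at `input_data_path`. Below is a minimal sketch of one way to produce that file and submit the job; the `boto3` upload and the bucket, key, and checkpoint values are illustrative assumptions, not part of this patch.

```python
# Minimal sketch, not part of the patch: assumes boto3 is installed, AWS credentials
# are configured, and the placeholder bucket/key/checkpoint values are replaced.
import boto3

from llmengine import Completion
from llmengine.data_types import (
    CreateBatchCompletionsModelConfig,
    CreateBatchCompletionsRequestContent,
)

BUCKET = "my-bucket"                              # placeholder
INPUT_KEY = "batch/input.json"                    # placeholder
OUTPUT_PATH = "s3://my-bucket/batch/output.json"  # placeholder

# Serialize the request content to JSON and upload it to the input location.
content = CreateBatchCompletionsRequestContent(
    prompts=["What is deep learning", "What is a neural network"],
    max_new_tokens=10,
    temperature=0.0,
)
boto3.client("s3").put_object(Bucket=BUCKET, Key=INPUT_KEY, Body=content.json())

# Submit the batch job, pointing at the uploaded file instead of passing inline content.
response = Completion.batch_create(
    input_data_path=f"s3://{BUCKET}/{INPUT_KEY}",
    output_data_path=OUTPUT_PATH,
    model_config=CreateBatchCompletionsModelConfig(
        model="llama-2-7b",
        checkpoint_path="s3://checkpoint-path",  # placeholder checkpoint location
        labels={"team": "my-team", "product": "my-product"},
    ),
    data_parallelism=2,
)
print(response.job_id)
```

Once the job finishes, the completions are written to `output_data_path` as a JSON file of type `List[CompletionOutput]`.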