[python] Add tool calling support #2675

Draft · wants to merge 2 commits into base: master

@@ -18,6 +18,7 @@
     apply_hf_chat_template,
     apply_mistral_chat_template,
     parse_chat_messages)
+from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls


 def is_chat_completions_request(inputs: Dict) -> bool:
@@ -30,7 +31,6 @@ def parse_chat_completions_request_vllm(
         rolling_batch,
         tokenizer,
         chat_template: Optional[str] = None,
-        image_token: Optional[str] = None,
         configs: Properties = None,
         is_mistral_tokenizer: bool = False,
 ):
@@ -47,12 +47,36 @@ def parse_chat_completions_request_vllm(
             f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
             f"please ensure that your tokenizer supports chat templates.")

+    tool_parser = rolling_batch.get_tool_parser()
     chat_params = ChatProperties(**input_map)
-    exclude = {"messages"}
+
+    if chat_params.tool_choice == "required":
+        raise ValueError("tool_choice = \"required\" is not supported!")
+
+    if is_mistral_tokenizer:
+        maybe_serialize_tool_calls(chat_params)
+    elif chat_params.tool_choice == "auto" and tool_parser is None:
+        raise ValueError(
+            "\"auto\" tool choice requires tool_call_parser to be available")
+
+    should_parse_tools = tool_parser is not None and (hasattr(
+        chat_params, "tool_choice") and chat_params.tool_choice != "none")
+    if should_parse_tools:
+        chat_params = tool_parser(tokenizer).adjust_request(  # type: ignore
+            request=chat_params)
+
+    exclude = {"messages", "tools"}
     param = chat_params.model_dump(exclude_none=True, exclude=exclude)

+    tool_dicts = None if chat_params.tools is None else [
+        tool.model_dump() for tool in chat_params.tools
+    ]
+
     conversation, mm_data = parse_chat_messages(
-        chat_params.messages, rolling_batch.get_model_config(), tokenizer)
+        chat_params.messages,
+        rolling_batch.get_model_config(),
+        tokenizer,
+        content_format="string")

     prompt_data: Union[str, List[int]]
     if is_mistral_tokenizer:
@@ -61,13 +85,15 @@
             messages=chat_params.messages,
             chat_template=chat_template,
             add_generation_prompt=True,
+            tools=tool_dicts,
         )
     else:
         text_inputs = apply_hf_chat_template(
             tokenizer,
             conversation=conversation,
             chat_template=chat_template,
             add_generation_prompt=True,
+            tools=tool_dicts,
         )

     param["details"] = True  # Enable details for chat completions
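
For reference, the request this parser now accepts follows the OpenAI chat-completions tool-calling schema. Below is a minimal, hypothetical payload (the `get_weather` function and its parameter schema are invented for illustration; `messages`, `tools`, and `tool_choice` are the keys the code above actually reads):

```python
# Hypothetical chat-completions payload exercising the new tool-calling path.
# "tools" follows the OpenAI function-calling schema; the function name and
# parameters are illustrative only.
payload = {
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    # "auto" requires a registered tool_call_parser; "required" is
    # explicitly rejected by the validation added above.
    "tool_choice": "auto",
}
```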
1 change: 0 additions & 1 deletion engines/python/setup/djl_python/input_parser.py
@@ -149,7 +149,6 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
         kwargs.get("is_rolling_batch"),
         rolling_batch,
         tokenizer,
-        image_token=image_token,
         configs=configs,
         is_mistral_tokenizer=is_mistral_tokenizer,
     )
@@ -11,12 +11,12 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 import ast
-from enum import Enum
 from typing import Optional, Any, Mapping, Tuple, Dict

 from pydantic import field_validator, model_validator

 from djl_python.properties_manager.properties import Properties
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager


 class VllmRbProperties(Properties):
@@ -72,6 +72,10 @@ class VllmRbProperties(Properties):
     qlora_adapter_name_or_path: Optional[str] = None
     disable_logprobs_during_spec_decoding: Optional[bool] = None

+    # Tool calling properties
+    enable_auto_tool_choice: Optional[bool] = False
+    tool_call_parser: Optional[str] = None
+
     @field_validator('engine')
     def validate_engine(cls, engine):
         if engine != "Python":
@@ -132,3 +136,12 @@ def validate_pipeline_parallel(self):
                 "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
             )
         return self

+    @model_validator(mode='after')
+    def validate_tool_call_parser(self):
+        valid_tool_parses = ToolParserManager.tool_parsers.keys()
+        if self.enable_auto_tool_choice \
+                and self.tool_call_parser not in valid_tool_parses:
+            raise ValueError(
+                f"Invalid tool call parser: {self.tool_call_parser} "
+                f"(choose from {{ {','.join(valid_tool_parses)} }})")
@@ -120,6 +120,12 @@ def use_vllm_chat_completions(self):
         """
         return False

+    def get_tool_parser(self):
+        """
+        :return: the tool call parser if available
+        """
+        return None
+
     @abstractmethod
     def inference(self, new_requests: List[Request]) -> List:
         """
@@ -13,6 +13,8 @@
 from collections import OrderedDict, defaultdict

 from vllm import LLMEngine, SamplingParams
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid, AtomicCounter

 from djl_python.request import Request
@@ -21,7 +23,7 @@
     update_request_cache_with_output, create_lora_request, get_lora_request,
     get_engine_args_from_config, get_prompt_inputs)
 from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
-from typing import List, Optional
+from typing import Callable, List, Optional

 # FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
 VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
@@ -53,6 +55,16 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = {}
         self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
+        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
+        if self.vllm_configs.enable_auto_tool_choice:
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    self.vllm_configs.tool_call_parser)
+            except Exception as e:
+                raise TypeError(
+                    "Error: option.enable_auto_tool_choice requires "
+                    f"tool call parser '{self.vllm_configs.tool_call_parser}' "
+                    "which has not been registered") from e

     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
@@ -66,6 +78,9 @@ def get_huggingface_model_config(self):
     def use_vllm_chat_completions(self):
         return True

+    def get_tool_parser(self):
+        return self.tool_parser
+
     def reset(self) -> None:
         """
         Aborts all requests
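
To see what the try/except above guards against, the registry can be poked at directly. A short sketch (the parser name `hermes` is illustrative; the exact exception raised for an unregistered name comes from vLLM, which is why the constructor catches a broad `Exception`):

```python
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

# The names accepted by option.tool_call_parser are the registry keys that
# validate_tool_call_parser checks against.
print(sorted(ToolParserManager.tool_parsers.keys()))

# get_tool_parser returns a ToolParser subclass; the rolling batch stores it
# as a factory, and the chat parser later instantiates it with the tokenizer
# via tool_parser(tokenizer).adjust_request(request=...).
parser_cls = ToolParserManager.get_tool_parser("hermes")  # illustrative name
```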
2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
 test_requirements = [
     'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
     'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
-    'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
+    'pydantic>=2.0', "objgraph", "vllm"
 ]

 setup(name='djl_python',