diff --git a/pyproject.toml b/pyproject.toml
index a2789082..d53a0060 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ classifiers = [
     "Programming Language :: Python :: 3.11",
 ]
 dependencies = [
-    "openai >= 1.6.1, < 2.0.0",
+    "openai >= 1.34.0, < 2.0.0",
     "typer >= 0.7.0, < 1.0.0",
     "click >= 7.1.1, < 9.0.0",
     "rich >= 13.1.0, < 14.0.0",
diff --git a/sgpt/function.py b/sgpt/function.py
index ed4c99a5..d1156ae2 100644
--- a/sgpt/function.py
+++ b/sgpt/function.py
@@ -59,4 +59,15 @@ def get_function(name: str) -> Callable[..., Any]:
 
 
 def get_openai_schemas() -> List[Dict[str, Any]]:
-    return [function.openai_schema for function in functions]
+    transformed_schemas = []
+    for function in functions:
+        schema = {
+            "type": "function",
+            "function": {
+                "name": function.openai_schema["name"],
+                "description": function.openai_schema.get("description", ""),
+                "parameters": function.openai_schema.get("parameters", {}),
+            },
+        }
+        transformed_schemas.append(schema)
+    return transformed_schemas
diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py
index ab03fd47..c6302132 100644
--- a/sgpt/handlers/handler.py
+++ b/sgpt/handlers/handler.py
@@ -9,6 +9,7 @@ from ..role import DefaultRoles, SystemRole
 
 completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None]
 
+base_url = cfg.get("API_BASE_URL")
 use_litellm = cfg.get("USE_LITELLM") == "true"
 additional_kwargs = {
     "timeout": int(cfg.get("REQUEST_TIMEOUT")),
@@ -96,12 +97,16 @@ def get_completion(
         if is_shell_role or is_code_role or is_dsc_shell_role:
             functions = None
 
+        if functions:
+            additional_kwargs["tool_choice"] = "auto"
+            additional_kwargs["tools"] = functions
+            additional_kwargs["parallel_tool_calls"] = False
+
         response = completion(
             model=model,
             temperature=temperature,
             top_p=top_p,
             messages=messages,
-            functions=functions,
             stream=True,
             **additional_kwargs,
         )
@@ -109,16 +114,18 @@
         try:
             for chunk in response:
                 delta = chunk.choices[0].delta
+
                 # LiteLLM uses dict instead of Pydantic object like OpenAI does.
-                function_call = (
-                    delta.get("function_call") if use_litellm else delta.function_call
+                tool_calls = (
+                    delta.get("tool_calls") if use_litellm else delta.tool_calls
                 )
-                if function_call:
-                    if function_call.name:
-                        name = function_call.name
-                    if function_call.arguments:
-                        arguments += function_call.arguments
-                if chunk.choices[0].finish_reason == "function_call":
+                if tool_calls:
+                    for tool_call in tool_calls:
+                        if tool_call.function.name:
+                            name = tool_call.function.name
+                        if tool_call.function.arguments:
+                            arguments += tool_call.function.arguments
+                if chunk.choices[0].finish_reason == "tool_calls":
                     yield from self.handle_function_call(messages, name, arguments)
                     yield from self.get_completion(
                         model=model,
diff --git a/tests/test_default.py b/tests/test_default.py
index a9b11a25..0e5e4a57 100644
--- a/tests/test_default.py
+++ b/tests/test_default.py
@@ -209,7 +209,6 @@ def test_llm_options(completion):
         model=args["--model"],
         temperature=args["--temperature"],
         top_p=args["--top-p"],
-        functions=None,
     )
     completion.assert_called_once_with(**expected_args)
     assert result.exit_code == 0
diff --git a/tests/utils.py b/tests/utils.py
index 8da7d947..c1a89864 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -54,7 +54,6 @@ def comp_args(role, prompt, **kwargs):
         "model": cfg.get("DEFAULT_MODEL"),
         "temperature": 0.0,
         "top_p": 1.0,
-        "functions": None,
         "stream": True,
         **kwargs,
     }
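
For reviewers who want to exercise the new request shape outside of sgpt, below is a minimal standalone sketch of the streaming tools flow this diff adopts, written against the OpenAI Python SDK (>= 1.34). The `get_current_weather` schema, the `gpt-4o-mini` model name, and the prompt are illustrative placeholders only, not part of this change:

```python
# Hypothetical, minimal sketch of the tools API streaming flow used in handler.py.
# The tool schema, model name, and prompt are placeholders for illustration.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # placeholder function
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # any tools-capable model
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
    parallel_tool_calls=False,  # one call at a time, as in the handler change above
    stream=True,
)

# Accumulate the function name and JSON arguments from streamed deltas,
# then act once the model signals finish_reason == "tool_calls".
name, arguments = "", ""
for chunk in stream:
    choice = chunk.choices[0]
    for tool_call in choice.delta.tool_calls or []:
        if tool_call.function.name:
            name = tool_call.function.name
        if tool_call.function.arguments:
            arguments += tool_call.function.arguments
    if choice.finish_reason == "tool_calls":
        print(f"model requested {name}({arguments})")
```

As in the handler above, the function name usually arrives in a single delta while the JSON arguments stream across many, so they are concatenated until the chunk whose `finish_reason` is `"tool_calls"`.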