From fe88db69fb0ecb6b6f37bc06fd5eb2e7ee7dfbad Mon Sep 17 00:00:00 2001 From: Tijs Zwinkels Date: Fri, 24 May 2024 16:59:47 +0200 Subject: [PATCH 1/6] Adopt to new function calling format --- sgpt/function.py | 15 +++++++++++++-- sgpt/handlers/handler.py | 23 ++++++++++++++--------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/sgpt/function.py b/sgpt/function.py index ed4c99a5..aee653fc 100644 --- a/sgpt/function.py +++ b/sgpt/function.py @@ -2,7 +2,7 @@ import sys from abc import ABCMeta from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Union from .config import cfg @@ -59,4 +59,15 @@ def get_function(name: str) -> Callable[..., Any]: def get_openai_schemas() -> List[Dict[str, Any]]: - return [function.openai_schema for function in functions] + transformed_schemas = [] + for function in functions: + schema = { + "type": "function", + "function": { + "name": function.openai_schema["name"], + "description": function.openai_schema.get("description", ""), + "parameters": function.openai_schema.get("parameters", {}), + }, + } + transformed_schemas.append(schema) + return transformed_schemas diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index ab03fd47..4193a1fb 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -9,6 +9,7 @@ from ..role import DefaultRoles, SystemRole completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None] + base_url = cfg.get("API_BASE_URL") use_litellm = cfg.get("USE_LITELLM") == "true" additional_kwargs = { @@ -89,6 +90,7 @@ def get_completion( messages: List[Dict[str, Any]], functions: Optional[List[Dict[str, str]]], ) -> Generator[str, None, None]: + name = arguments = "" is_shell_role = self.role.name == DefaultRoles.SHELL.value is_code_role = self.role.name == DefaultRoles.CODE.value @@ -101,7 +103,8 @@ def get_completion( temperature=temperature, top_p=top_p, messages=messages, - functions=functions, + tools=functions, + tool_choice="auto", stream=True, **additional_kwargs, ) @@ -109,16 +112,18 @@ def get_completion( try: for chunk in response: delta = chunk.choices[0].delta + # LiteLLM uses dict instead of Pydantic object like OpenAI does. - function_call = ( - delta.get("function_call") if use_litellm else delta.function_call + tool_calls = ( + delta.get("tool_calls") if use_litellm else delta.tool_calls ) - if function_call: - if function_call.name: - name = function_call.name - if function_call.arguments: - arguments += function_call.arguments - if chunk.choices[0].finish_reason == "function_call": + if tool_calls: + for tool_call in tool_calls: + if tool_call.function.name: + name = tool_call.function.name + if tool_call.function.arguments: + arguments += tool_call.function.arguments + if chunk.choices[0].finish_reason == "tool_calls": yield from self.handle_function_call(messages, name, arguments) yield from self.get_completion( model=model, From db61625774e22744d6bdb7face3a033321257bff Mon Sep 17 00:00:00 2001 From: Tijs Zwinkels Date: Wed, 12 Jun 2024 21:53:33 +0200 Subject: [PATCH 2/6] Process review comments: - Prevent crash when user has empty functions directory - Prevent parallel function calling from the LLM as we don't handle that correctly - The option to disable parallel function calling is added in a later release of the openai library, so upgraded it --- pyproject.toml | 2 +- sgpt/handlers/handler.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2789082..d53a0060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "openai >= 1.6.1, < 2.0.0", + "openai >= 1.34.0, < 2.0.0", "typer >= 0.7.0, < 1.0.0", "click >= 7.1.1, < 9.0.0", "rich >= 13.1.0, < 14.0.0", diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index 4193a1fb..1cbc4ae7 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -98,13 +98,16 @@ def get_completion( if is_shell_role or is_code_role or is_dsc_shell_role: functions = None + if functions: + additional_kwargs["tool_choice"] = "auto" + additional_kwargs["tools"] = functions + additional_kwargs["parallel_tool_calls"] = False + response = completion( model=model, temperature=temperature, top_p=top_p, messages=messages, - tools=functions, - tool_choice="auto", stream=True, **additional_kwargs, ) From bab51dc76c672336c6496daa831122f5845778d3 Mon Sep 17 00:00:00 2001 From: Tijs Zwinkels Date: Fri, 24 May 2024 16:59:47 +0200 Subject: [PATCH 3/6] Adopt to new function calling format --- sgpt/function.py | 15 +++++++++++++-- sgpt/handlers/handler.py | 23 ++++++++++++++--------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/sgpt/function.py b/sgpt/function.py index ed4c99a5..aee653fc 100644 --- a/sgpt/function.py +++ b/sgpt/function.py @@ -2,7 +2,7 @@ import sys from abc import ABCMeta from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Union from .config import cfg @@ -59,4 +59,15 @@ def get_function(name: str) -> Callable[..., Any]: def get_openai_schemas() -> List[Dict[str, Any]]: - return [function.openai_schema for function in functions] + transformed_schemas = [] + for function in functions: + schema = { + "type": "function", + "function": { + "name": function.openai_schema["name"], + "description": function.openai_schema.get("description", ""), + "parameters": function.openai_schema.get("parameters", {}), + }, + } + transformed_schemas.append(schema) + return transformed_schemas diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index ab03fd47..4193a1fb 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -9,6 +9,7 @@ from ..role import DefaultRoles, SystemRole completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None] + base_url = cfg.get("API_BASE_URL") use_litellm = cfg.get("USE_LITELLM") == "true" additional_kwargs = { @@ -89,6 +90,7 @@ def get_completion( messages: List[Dict[str, Any]], functions: Optional[List[Dict[str, str]]], ) -> Generator[str, None, None]: + name = arguments = "" is_shell_role = self.role.name == DefaultRoles.SHELL.value is_code_role = self.role.name == DefaultRoles.CODE.value @@ -101,7 +103,8 @@ def get_completion( temperature=temperature, top_p=top_p, messages=messages, - functions=functions, + tools=functions, + tool_choice="auto", stream=True, **additional_kwargs, ) @@ -109,16 +112,18 @@ def get_completion( try: for chunk in response: delta = chunk.choices[0].delta + # LiteLLM uses dict instead of Pydantic object like OpenAI does. - function_call = ( - delta.get("function_call") if use_litellm else delta.function_call + tool_calls = ( + delta.get("tool_calls") if use_litellm else delta.tool_calls ) - if function_call: - if function_call.name: - name = function_call.name - if function_call.arguments: - arguments += function_call.arguments - if chunk.choices[0].finish_reason == "function_call": + if tool_calls: + for tool_call in tool_calls: + if tool_call.function.name: + name = tool_call.function.name + if tool_call.function.arguments: + arguments += tool_call.function.arguments + if chunk.choices[0].finish_reason == "tool_calls": yield from self.handle_function_call(messages, name, arguments) yield from self.get_completion( model=model, From 6167d18fe00cc391b263bfef2bc246bc5ae5f2a3 Mon Sep 17 00:00:00 2001 From: Tijs Zwinkels Date: Wed, 12 Jun 2024 21:53:33 +0200 Subject: [PATCH 4/6] Process review comments: - Prevent crash when user has empty functions directory - Prevent parallel function calling from the LLM as we don't handle that correctly - The option to disable parallel function calling is added in a later release of the openai library, so upgraded it --- pyproject.toml | 2 +- sgpt/handlers/handler.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2789082..d53a0060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "openai >= 1.6.1, < 2.0.0", + "openai >= 1.34.0, < 2.0.0", "typer >= 0.7.0, < 1.0.0", "click >= 7.1.1, < 9.0.0", "rich >= 13.1.0, < 14.0.0", diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index 4193a1fb..1cbc4ae7 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -98,13 +98,16 @@ def get_completion( if is_shell_role or is_code_role or is_dsc_shell_role: functions = None + if functions: + additional_kwargs["tool_choice"] = "auto" + additional_kwargs["tools"] = functions + additional_kwargs["parallel_tool_calls"] = False + response = completion( model=model, temperature=temperature, top_p=top_p, messages=messages, - tools=functions, - tool_choice="auto", stream=True, **additional_kwargs, ) From 03e0524d28952b8f511bb18dc89ed157b153b96a Mon Sep 17 00:00:00 2001 From: Hejia Zhang Date: Mon, 8 Jul 2024 21:52:02 +0800 Subject: [PATCH 5/6] Fix linter error --- sgpt/function.py | 2 +- sgpt/handlers/handler.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sgpt/function.py b/sgpt/function.py index aee653fc..d1156ae2 100644 --- a/sgpt/function.py +++ b/sgpt/function.py @@ -2,7 +2,7 @@ import sys from abc import ABCMeta from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List from .config import cfg diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index 1cbc4ae7..c6302132 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -90,7 +90,6 @@ def get_completion( messages: List[Dict[str, Any]], functions: Optional[List[Dict[str, str]]], ) -> Generator[str, None, None]: - name = arguments = "" is_shell_role = self.role.name == DefaultRoles.SHELL.value is_code_role = self.role.name == DefaultRoles.CODE.value From 510030b94057cdd208c67194a85c308061f3015f Mon Sep 17 00:00:00 2001 From: Hejia Zhang Date: Mon, 8 Jul 2024 22:41:32 +0800 Subject: [PATCH 6/6] Fix tests --- tests/test_default.py | 1 - tests/utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/test_default.py b/tests/test_default.py index a9b11a25..0e5e4a57 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -209,7 +209,6 @@ def test_llm_options(completion): model=args["--model"], temperature=args["--temperature"], top_p=args["--top-p"], - functions=None, ) completion.assert_called_once_with(**expected_args) assert result.exit_code == 0 diff --git a/tests/utils.py b/tests/utils.py index 8da7d947..c1a89864 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -54,7 +54,6 @@ def comp_args(role, prompt, **kwargs): "model": cfg.get("DEFAULT_MODEL"), "temperature": 0.0, "top_p": 1.0, - "functions": None, "stream": True, **kwargs, }