Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use json schema from openapi doc and recover in case of assistant protocol error #64

Open
wants to merge 4 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions aidial_assistant/commands/open_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,35 @@
ExecutionCallback,
ResultObject,
)
from aidial_assistant.open_api.requester import OpenAPIEndpointRequester
from aidial_assistant.open_api.requester import (
OpenAPIEndpointRequester,
ParamMapping,
)


class OpenAPIChatCommand(Command):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has nothing to do with Chat strictly speaking. Maybe simply OpenAPICommand?

@staticmethod
def token() -> str:
return "open-api-chat-command"

def __init__(self, op: APIOperation, plugin_auth: str | None):
self.op = op
self.plugin_auth = plugin_auth
def __init__(self, requester: OpenAPIEndpointRequester):
self.requester = requester

@override
async def execute(
self, args: dict[str, Any], execution_callback: ExecutionCallback
) -> ResultObject:
return await OpenAPIEndpointRequester(
self.op, self.plugin_auth
).execute(args)
return await self.requester.execute(args)

@classmethod
def create(
cls, base_url: str, operation: APIOperation, auth: str | None
) -> "OpenAPIChatCommand":
path = base_url.rstrip("/") + operation.path
method = operation.method
param_mapping = ParamMapping(
query_params=operation.query_params,
body_params=operation.body_params,
path_params=operation.path_params,
)
return cls(OpenAPIEndpointRequester(path, method, param_mapping, auth))
2 changes: 1 addition & 1 deletion aidial_assistant/commands/plugin_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def on_state(self, request: str, response: str):

@override
def on_error(self, title: str, error: str):
pass
self.callback(f"```\n{title}: {error}\n```\n")

@property
def result(self) -> str:
Expand Down
27 changes: 18 additions & 9 deletions aidial_assistant/commands/run_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
ModelClient,
ReasonLengthException,
)
from aidial_assistant.open_api.operation_selector import collect_operations
from aidial_assistant.utils.open_ai import user_message
from aidial_assistant.utils.open_ai_plugin import OpenAIPluginInfo

Expand Down Expand Up @@ -62,25 +61,35 @@ async def _run_plugin(
self, query: str, execution_callback: ExecutionCallback
) -> ResultObject:
info = self.plugin.info
ops = collect_operations(info.open_api, info.ai_plugin.api.url)
api_schema = "\n\n".join([op.to_typescript() for op in ops.values()]) # type: ignore
spec = info.open_api
spec_url = info.get_full_spec_url()
operations = [
APIOperation.from_openapi_spec(spec, path, method)
for path in spec.paths
for method in spec.get_methods_for_path(path)
]
api_schema = "\n\n".join([op.to_typescript() for op in operations]) # type: ignore

def create_command(op: APIOperation):
return lambda: OpenAPIChatCommand(op, self.plugin.auth)

command_dict: dict[str, CommandConstructor] = {}
for name, op in ops.items():
# The function is necessary to capture the current value of op.
# Otherwise, only first op will be used for all commands
command_dict[name] = create_command(op)
return lambda: OpenAPIChatCommand.create(
Copy link
Contributor

@adubovik adubovik Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a way to get rid of this function:

command_dict: CommandDict = {
    op.operation_id: lambda op=op: OpenAPIChatCommand.create(
        spec_url, op, self.plugin.auth
    )
    for op in operations
}

spec_url, op, self.plugin.auth
)

command_dict: dict[str, CommandConstructor] = {
operation.operation_id: create_command(operation)
for operation in operations
}
command_names = command_dict.keys()
if Reply.token() in command_dict:
Exception(f"Operation with name '{Reply.token()}' is not allowed.")

command_dict[Reply.token()] = Reply

history = History(
assistant_system_message_template=ADDON_SYSTEM_DIALOG_MESSAGE.build(
command_names=ops.keys(),
command_names=command_names,
api_description=info.ai_plugin.description_for_model,
api_schema=api_schema,
),
Expand Down
75 changes: 22 additions & 53 deletions aidial_assistant/commands/run_tool.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from typing import Any

from langchain_community.tools.openapi.utils.api_models import (
APIOperation,
APIPropertyBase,
)
from openai.types.chat import ChatCompletionToolParam
from langchain_community.tools.openapi.utils.api_models import APIOperation
from typing_extensions import override

from aidial_assistant.commands.base import (
Command,
CommandConstructor,
ExecutionCallback,
ResultObject,
TextResult,
Expand All @@ -21,46 +18,9 @@
ModelClient,
ReasonLengthException,
)
from aidial_assistant.open_api.operation_selector import collect_operations
from aidial_assistant.tools_chain.tools_chain import (
CommandTool,
CommandToolDict,
ToolsChain,
)
from aidial_assistant.utils.open_ai import (
construct_tool,
system_message,
user_message,
)


def _construct_property(p: APIPropertyBase) -> dict[str, Any]:
parameter = {
"type": p.type,
"description": p.description,
}
return {k: v for k, v in parameter.items() if v is not None}


def _construct_tool(op: APIOperation) -> ChatCompletionToolParam:
properties = {}
required = []
for p in op.properties:
properties[p.name] = _construct_property(p)

if p.required:
required.append(p.name)

if op.request_body is not None:
for p in op.request_body.properties:
properties[p.name] = _construct_property(p)

if p.required:
required.append(p.name)

return construct_tool(
op.operation_id, op.description or "", properties, required
)
from aidial_assistant.tools_chain.tools_chain import CommandToolDict, ToolsChain
from aidial_assistant.utils.open_ai import system_message, user_message
from aidial_assistant.utils.open_api import construct_tool_from_spec


class RunTool(Command):
Expand All @@ -81,17 +41,26 @@ async def execute(
) -> ResultObject:
query = get_required_field(args, "query")

ops = collect_operations(
self.plugin.info.open_api, self.plugin.info.ai_plugin.api.url
)
spec = self.plugin.info.open_api
spec_url = self.plugin.info.get_full_spec_url()

def create_command_tool(op: APIOperation) -> CommandTool:
return lambda: OpenAPIChatCommand(
op, self.plugin.auth
), _construct_tool(op)
def create_command(operation: APIOperation) -> CommandConstructor:
# The function is necessary to capture the current value of op.
# Otherwise, only first op will be used for all commands
return lambda: OpenAPIChatCommand.create(
spec_url, operation, self.plugin.auth
)

commands: CommandToolDict = {
name: create_command_tool(op) for name, op in ops.items()
operation.operation_id: (create_command(operation), tool)
Copy link
Contributor

@adubovik adubovik Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard to read: three for's where only two are required.

Maybe rewrite it as two explicit for-loops.

And you can remove the create_command method as I suggested in another comment.

for path in spec.paths
for operation, tool in (
(
APIOperation.from_openapi_spec(spec, path, method),
construct_tool_from_spec(spec, path, method),
)
for method in spec.get_methods_for_path(path)
)
}

chain = ToolsChain(self.model, commands, self.max_completion_tokens)
Expand Down
51 changes: 0 additions & 51 deletions aidial_assistant/open_api/operation_selector.py

This file was deleted.

27 changes: 14 additions & 13 deletions aidial_assistant/open_api/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

import aiohttp.client_exceptions
from aiohttp import hdrs
from langchain.tools.openapi.utils.api_models import APIOperation

from aidial_assistant.commands.base import JsonResult, ResultObject, TextResult
from aidial_assistant.utils.requests import arequest

logger = logging.getLogger(__name__)


class _ParamMapping(NamedTuple):
class ParamMapping(NamedTuple):
"""Mapping from parameter name to parameter value."""

query_params: List[str]
Expand All @@ -25,18 +24,21 @@ class OpenAPIEndpointRequester:
Based on OpenAPIEndpointChain from LangChain.
"""

def __init__(self, operation: APIOperation, plugin_auth: str | None):
self.operation = operation
self.param_mapping = _ParamMapping(
query_params=operation.query_params, # type: ignore
body_params=operation.body_params, # type: ignore
path_params=operation.path_params, # type: ignore
)
def __init__(
self,
url: str,
method: str,
param_mapping: ParamMapping,
plugin_auth: str | None,
):
self.url = url
self.method = method
self.param_mapping = param_mapping
self.plugin_auth = plugin_auth

def _construct_path(self, args: Dict[str, str]) -> str:
"""Construct the path from the deserialized input."""
path = self.operation.base_url.rstrip("/") + self.operation.path # type: ignore
path = self.url
for param in self.param_mapping.path_params:
path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
return path
Expand Down Expand Up @@ -87,17 +89,16 @@ async def execute(
)
logger.debug(f"Request args: {request_args}")
async with arequest(
self.operation.method.value, headers=headers, **request_args # type: ignore
self.method, headers=headers, **request_args # type: ignore
) as response:
if response.status != 200:
try:
return JsonResult(json.dumps(await response.json()))
except aiohttp.ContentTypeError:
method_str = str(self.operation.method.value) # type: ignore
error_object = {
"reason": response.reason,
"status_code": response.status,
"method:": method_str.upper(),
"method:": self.method.upper(),
"url": request_args["url"],
"params": request_args["params"],
}
Expand Down
Loading
Loading