-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: use json schema from openapi doc and recover in case of assistant protocol error #64
base: development
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,6 @@ | |
ModelClient, | ||
ReasonLengthException, | ||
) | ||
from aidial_assistant.open_api.operation_selector import collect_operations | ||
from aidial_assistant.utils.open_ai import user_message | ||
from aidial_assistant.utils.open_ai_plugin import OpenAIPluginInfo | ||
|
||
|
@@ -62,25 +61,35 @@ async def _run_plugin( | |
self, query: str, execution_callback: ExecutionCallback | ||
) -> ResultObject: | ||
info = self.plugin.info | ||
ops = collect_operations(info.open_api, info.ai_plugin.api.url) | ||
api_schema = "\n\n".join([op.to_typescript() for op in ops.values()]) # type: ignore | ||
spec = info.open_api | ||
spec_url = info.get_full_spec_url() | ||
operations = [ | ||
APIOperation.from_openapi_spec(spec, path, method) | ||
for path in spec.paths | ||
for method in spec.get_methods_for_path(path) | ||
] | ||
api_schema = "\n\n".join([op.to_typescript() for op in operations]) # type: ignore | ||
|
||
def create_command(op: APIOperation): | ||
return lambda: OpenAPIChatCommand(op, self.plugin.auth) | ||
|
||
command_dict: dict[str, CommandConstructor] = {} | ||
for name, op in ops.items(): | ||
# The function is necessary to capture the current value of op. | ||
# Otherwise, only first op will be used for all commands | ||
command_dict[name] = create_command(op) | ||
return lambda: OpenAPIChatCommand.create( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a way to get rid of this function: command_dict: CommandDict = {
op.operation_id: lambda op=op: OpenAPIChatCommand.create(
spec_url, op, self.plugin.auth
)
for op in operations
} |
||
spec_url, op, self.plugin.auth | ||
) | ||
|
||
command_dict: dict[str, CommandConstructor] = { | ||
operation.operation_id: create_command(operation) | ||
for operation in operations | ||
} | ||
command_names = command_dict.keys() | ||
if Reply.token() in command_dict: | ||
Exception(f"Operation with name '{Reply.token()}' is not allowed.") | ||
|
||
command_dict[Reply.token()] = Reply | ||
|
||
history = History( | ||
assistant_system_message_template=ADDON_SYSTEM_DIALOG_MESSAGE.build( | ||
command_names=ops.keys(), | ||
command_names=command_names, | ||
api_description=info.ai_plugin.description_for_model, | ||
api_schema=api_schema, | ||
), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,11 @@ | ||
from typing import Any | ||
|
||
from langchain_community.tools.openapi.utils.api_models import ( | ||
APIOperation, | ||
APIPropertyBase, | ||
) | ||
from openai.types.chat import ChatCompletionToolParam | ||
from langchain_community.tools.openapi.utils.api_models import APIOperation | ||
from typing_extensions import override | ||
|
||
from aidial_assistant.commands.base import ( | ||
Command, | ||
CommandConstructor, | ||
ExecutionCallback, | ||
ResultObject, | ||
TextResult, | ||
|
@@ -21,46 +18,9 @@ | |
ModelClient, | ||
ReasonLengthException, | ||
) | ||
from aidial_assistant.open_api.operation_selector import collect_operations | ||
from aidial_assistant.tools_chain.tools_chain import ( | ||
CommandTool, | ||
CommandToolDict, | ||
ToolsChain, | ||
) | ||
from aidial_assistant.utils.open_ai import ( | ||
construct_tool, | ||
system_message, | ||
user_message, | ||
) | ||
|
||
|
||
def _construct_property(p: APIPropertyBase) -> dict[str, Any]: | ||
parameter = { | ||
"type": p.type, | ||
"description": p.description, | ||
} | ||
return {k: v for k, v in parameter.items() if v is not None} | ||
|
||
|
||
def _construct_tool(op: APIOperation) -> ChatCompletionToolParam: | ||
properties = {} | ||
required = [] | ||
for p in op.properties: | ||
properties[p.name] = _construct_property(p) | ||
|
||
if p.required: | ||
required.append(p.name) | ||
|
||
if op.request_body is not None: | ||
for p in op.request_body.properties: | ||
properties[p.name] = _construct_property(p) | ||
|
||
if p.required: | ||
required.append(p.name) | ||
|
||
return construct_tool( | ||
op.operation_id, op.description or "", properties, required | ||
) | ||
from aidial_assistant.tools_chain.tools_chain import CommandToolDict, ToolsChain | ||
from aidial_assistant.utils.open_ai import system_message, user_message | ||
from aidial_assistant.utils.open_api import construct_tool_from_spec | ||
|
||
|
||
class RunTool(Command): | ||
|
@@ -81,17 +41,26 @@ async def execute( | |
) -> ResultObject: | ||
query = get_required_field(args, "query") | ||
|
||
ops = collect_operations( | ||
self.plugin.info.open_api, self.plugin.info.ai_plugin.api.url | ||
) | ||
spec = self.plugin.info.open_api | ||
spec_url = self.plugin.info.get_full_spec_url() | ||
|
||
def create_command_tool(op: APIOperation) -> CommandTool: | ||
return lambda: OpenAPIChatCommand( | ||
op, self.plugin.auth | ||
), _construct_tool(op) | ||
def create_command(operation: APIOperation) -> CommandConstructor: | ||
# The function is necessary to capture the current value of op. | ||
# Otherwise, only first op will be used for all commands | ||
return lambda: OpenAPIChatCommand.create( | ||
spec_url, operation, self.plugin.auth | ||
) | ||
|
||
commands: CommandToolDict = { | ||
name: create_command_tool(op) for name, op in ops.items() | ||
operation.operation_id: (create_command(operation), tool) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hard to read: three for's where only two are required. Maybe rewrite it as two explicit for-loops. And you can remove the |
||
for path in spec.paths | ||
for operation, tool in ( | ||
( | ||
APIOperation.from_openapi_spec(spec, path, method), | ||
construct_tool_from_spec(spec, path, method), | ||
) | ||
for method in spec.get_methods_for_path(path) | ||
) | ||
} | ||
|
||
chain = ToolsChain(self.model, commands, self.max_completion_tokens) | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has nothing to do with Chat strictly speaking. Maybe simply
OpenAPICommand
?