Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
JohntheLi committed Dec 15, 2023
2 parents 117f4ba + 386a8e3 commit 357fb5e
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/fastapi_poe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"PartialResponse",
"ErrorResponse",
"MetaResponse",
"ToolDefinition",
]

from .base import PoeBot, make_app, run
Expand All @@ -38,4 +39,5 @@
ReportFeedbackRequest,
SettingsRequest,
SettingsResponse,
ToolDefinition,
)
175 changes: 166 additions & 9 deletions src/fastapi_poe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
ProtocolMessage,
QueryRequest,
SettingsResponse,
ToolCallDefinition,
ToolDefinition,
ToolResultDefinition,
)

PROTOCOL_VERSION = "1.0"
Expand Down Expand Up @@ -119,24 +122,31 @@ async def fetch_settings(self) -> SettingsResponse:
return resp.json()

async def perform_query_request(
self, request: QueryRequest
self,
*,
request: QueryRequest,
tools: List[ToolDefinition],
tool_calls: List[ToolCallDefinition],
tool_results: List[ToolResultDefinition],
) -> AsyncGenerator[BotMessage, None]:
chunks: List[str] = []
message_id = request.message_id
event_count = 0
error_reported = False
payload = request.model_dump()
payload["tools"] = [tool.model_dump() for tool in tools]
payload["tool_calls"] = [tool_call.model_dump() for tool_call in tool_calls]
payload["tool_results"] = [
tool_result.model_dump() for tool_result in tool_results
]
async with httpx_sse.aconnect_sse(
self.session,
"POST",
self.endpoint,
headers=self.headers,
json=request.dict(),
self.session, "POST", self.endpoint, headers=self.headers, json=payload
) as event_source:
async for event in event_source.aiter_sse():
event_count += 1
if event.event == "done":
# Don't send a report if we already told the bot about some other mistake.
if not chunks and not error_reported:
if not chunks and not error_reported and not tools:
await self.report_error(
"Bot returned no text in response",
{"message_id": message_id},
Expand All @@ -162,6 +172,11 @@ async def perform_query_request(
is_suggested_reply=True,
)
continue
elif event.event == "json":
yield BotMessage(
text="", data=json.loads(event.data), full_prompt=repr(request)
)
continue
elif event.event == "meta":
if event_count != 1:
# spec says a meta event that is not the first event is ignored
Expand Down Expand Up @@ -278,6 +293,140 @@ async def stream_request(
bot_name: str,
api_key: str = "",
*,
tools: Optional[List[ToolDefinition]] = None,
tool_executables: Optional[List[Callable]] = None,
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
on_error: ErrorHandler = _default_error_handler,
num_tries: int = 2,
retry_sleep_time: float = 0.5,
base_url: str = "https://api.poe.com/bot/",
) -> AsyncGenerator[BotMessage, None]:
tool_calls = None
tool_results = None
if tools is not None:
assert tool_executables is not None
tool_calls = await _get_tool_calls(
request=request,
bot_name=bot_name,
api_key=api_key,
tools=tools,
access_key=access_key,
access_key_deprecation_warning_stacklevel=access_key_deprecation_warning_stacklevel,
session=session,
on_error=on_error,
num_tries=num_tries,
retry_sleep_time=retry_sleep_time,
base_url=base_url,
)
tool_results = _get_tool_results(
tool_executables=tool_executables, tool_calls=tool_calls
)
async for message in stream_request_base(
request=request,
bot_name=bot_name,
api_key=api_key,
tools=tools,
tool_calls=tool_calls,
tool_results=tool_results,
access_key=access_key,
access_key_deprecation_warning_stacklevel=access_key_deprecation_warning_stacklevel,
session=session,
on_error=on_error,
num_tries=num_tries,
retry_sleep_time=retry_sleep_time,
base_url=base_url,
):
yield message


def _get_tool_results(
tool_executables: List[Callable], tool_calls: List[ToolCallDefinition]
) -> List[ToolResultDefinition]:
tool_executables_dict = {
executable.__name__: executable for executable in tool_executables
}
tool_results = []
for tool_call in tool_calls:
tool_call_id = tool_call.id
name = tool_call.function.name
arguments = json.loads(tool_call.function.arguments)
content = tool_executables_dict[name](**arguments)
tool_results.append(
ToolResultDefinition(
role="tool",
tool_call_id=tool_call_id,
name=name,
content=json.dumps(content),
)
)
return tool_results


async def _get_tool_calls(
request: QueryRequest,
bot_name: str,
api_key: str = "",
*,
tools: List[ToolDefinition],
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
on_error: ErrorHandler = _default_error_handler,
num_tries: int = 2,
retry_sleep_time: float = 0.5,
base_url: str = "https://api.poe.com/bot/",
) -> List[ToolCallDefinition]:
tool_call_object_dict: Dict[int, Dict[str, Any]] = {}
async for message in stream_request_base(
request=request,
bot_name=bot_name,
api_key=api_key,
tools=tools,
access_key=access_key,
access_key_deprecation_warning_stacklevel=access_key_deprecation_warning_stacklevel,
session=session,
on_error=on_error,
num_tries=num_tries,
retry_sleep_time=retry_sleep_time,
base_url=base_url,
):
if message.data is not None:
finish_reason = message.data["choices"][0]["finish_reason"]
if finish_reason is None:
try:
tool_call_object = message.data["choices"][0]["delta"][
"tool_calls"
][0]
index = tool_call_object.pop("index")
if index not in tool_call_object_dict:
tool_call_object_dict[index] = tool_call_object
else:
function_info = tool_call_object["function"]
tool_call_object_dict[index]["function"][
"arguments"
] += function_info["arguments"]
except KeyError:
continue
tool_call_object_list = [
tool_call_object
for index, tool_call_object in sorted(tool_call_object_dict.items())
]
return [
ToolCallDefinition(**tool_call_object)
for tool_call_object in tool_call_object_list
]


async def stream_request_base(
request: QueryRequest,
bot_name: str,
api_key: str = "",
*,
tools: Optional[List[ToolDefinition]] = None,
tool_calls: Optional[List[ToolCallDefinition]] = None,
tool_results: Optional[List[ToolResultDefinition]] = None,
access_key: str = "",
access_key_deprecation_warning_stacklevel: int = 2,
session: Optional[httpx.AsyncClient] = None,
Expand All @@ -286,7 +435,6 @@ async def stream_request(
retry_sleep_time: float = 0.5,
base_url: str = "https://api.poe.com/bot/",
) -> AsyncGenerator[BotMessage, None]:
"""Streams BotMessages from a Poe bot."""
if access_key != "":
warnings.warn(
"the access_key param is no longer necessary when using this function.",
Expand All @@ -304,7 +452,12 @@ async def stream_request(
got_response = False
for i in range(num_tries):
try:
async for message in ctx.perform_query_request(request):
async for message in ctx.perform_query_request(
request=request,
tools=tools if tools is not None else [],
tool_calls=tool_calls if tool_calls is not None else [],
tool_results=tool_results if tool_results is not None else [],
):
got_response = True
yield message
break
Expand Down Expand Up @@ -333,6 +486,8 @@ def get_bot_response(
bot_name: str,
api_key: str,
*,
tools: Optional[List[ToolDefinition]] = None,
tool_executables: Optional[List[Callable]] = None,
temperature: Optional[float] = None,
skip_system_prompt: Optional[bool] = None,
logit_bias: Optional[Dict[str, float]] = None,
Expand Down Expand Up @@ -364,6 +519,8 @@ def get_bot_response(
request=query,
bot_name=bot_name,
api_key=api_key,
tools=tools,
tool_executables=tool_executables,
base_url=base_url,
session=session,
)
Expand Down
35 changes: 35 additions & 0 deletions src/fastapi_poe/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ class PartialResponse(BaseModel):
"""

data: Optional[Dict[str, Any]] = None
"""Used when a bot returns the json event."""

raw_response: object = None
"""For debugging, the raw response from the bot."""

Expand Down Expand Up @@ -133,3 +136,35 @@ class MetaResponse(PartialResponse):
suggested_replies: bool = True
content_type: ContentType = "text/markdown"
refetch_settings: bool = False


class ToolDefinition(BaseModel):
class FunctionDefinition(BaseModel):
class ParametersDefinition(BaseModel):
type: str
properties: Dict[str, object]
required: Optional[List[str]] = None

name: str
description: str
parameters: ParametersDefinition

type: str
function: FunctionDefinition


class ToolCallDefinition(BaseModel):
class FunctionDefinition(BaseModel):
name: str
arguments: str

id: str
type: str
function: FunctionDefinition


class ToolResultDefinition(BaseModel):
role: str
name: str
tool_call_id: str
content: str

0 comments on commit 357fb5e

Please sign in to comment.