diff --git a/src/fastapi_poe/__init__.py b/src/fastapi_poe/__init__.py index 90dbfd5..ca3c868 100644 --- a/src/fastapi_poe/__init__.py +++ b/src/fastapi_poe/__init__.py @@ -18,6 +18,8 @@ "ErrorResponse", "MetaResponse", "ToolDefinition", + "AttachFileResponse", + "ImageResponse" ] from .base import PoeBot, make_app, run @@ -29,8 +31,10 @@ stream_request, ) from .types import ( + AttachFileResponse, Attachment, ErrorResponse, + ImageResponse, MetaResponse, PartialResponse, ProtocolMessage, diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index d764ad8..b10b3ae 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -4,6 +4,7 @@ import json import logging import os +import re import sys import warnings from typing import Any, AsyncIterable, BinaryIO, Dict, Optional, Union @@ -17,10 +18,13 @@ from starlette.middleware.base import BaseHTTPMiddleware from fastapi_poe.types import ( + AttachFileResponse, + AttachmentUploadError, AttachmentUploadResponse, ContentType, ErrorResponse, Identifier, + InvalidParameterError, MetaResponse, PartialResponse, QueryRequest, @@ -33,14 +37,6 @@ logger = logging.getLogger("uvicorn.default") -class InvalidParameterError(Exception): - pass - - -class AttachmentUploadError(Exception): - pass - - class LoggingMiddleware(BaseHTTPMiddleware): async def set_body(self, request: Request): receive_ = await request._receive() @@ -98,7 +94,7 @@ class PoeBot: # Override these for your bot async def get_response( self, request: QueryRequest - ) -> AsyncIterable[Union[PartialResponse, ServerSentEvent]]: + ) -> AsyncIterable[Union[PartialResponse, AttachFileResponse, ServerSentEvent]]: """Override this to return a response to user queries.""" yield self.text_event("hello") @@ -161,7 +157,6 @@ async def _make_file_attachment_request( content_type: Optional[str] = None, is_inline: bool = False, ) -> AttachmentUploadResponse: - url = "https://www.quora.com/poe_api/file_attachment_3RD_PARTY_POST" async with httpx.AsyncClient(timeout=120) as client: try: @@ -176,7 +171,12 @@ async def _make_file_attachment_request( "is_inline": is_inline, "download_url": download_url, } - request = httpx.Request("POST", url, data=data, headers=headers) + request = httpx.Request( + "POST", + self._attachment_upload_url, + data=data, + headers=headers + ) elif file_data and filename: data = {"message_id": message_id, "is_inline": is_inline} files = { @@ -187,7 +187,11 @@ async def _make_file_attachment_request( ) } request = httpx.Request( - "POST", url, files=files, data=data, headers=headers + "POST", + self._attachment_upload_url, + files=files, + data=data, + headers=headers ) else: raise InvalidParameterError( @@ -208,10 +212,11 @@ async def _make_file_attachment_request( logger.error("An HTTP error occurred when attempting to attach file") raise - async def _process_pending_attachment_requests(self, message_id): + async def _process_pending_attachment_requests(self, request: QueryRequest) -> None: try: await asyncio.gather( - *self._pending_file_attachment_tasks.pop(message_id, []) + *self._pending_file_attachment_tasks.pop(request.message_id, []), + *request._pending_tasks, ) except Exception: logger.error("Error processing pending attachment requests") @@ -269,8 +274,22 @@ def error_event( data["error_type"] = error_type return ServerSentEvent(data=json.dumps(data), event="error") + @staticmethod + def inline_attachment_event( + *, + inline_ref: str, + description: Optional[str] = None, + ): + if description: + text = f"![{_markdown_escape(description)}][{inline_ref}]" + else: + text = f"![{inline_ref}]" + return ServerSentEvent(data=json.dumps({"text": text}), event="text") + # Internal handlers + _attachment_upload_url = "https://www.quora.com/poe_api/file_attachment_3RD_PARTY_POST" + async def handle_report_feedback( self, feedback_request: ReportFeedbackRequest ) -> JSONResponse: @@ -307,6 +326,25 @@ async def handle_query( linkify=event.linkify, suggested_replies=event.suggested_replies, ) + elif isinstance(event, AttachFileResponse): + upload_task = asyncio.create_task( + request.post_message_attachment( + file_data=event.file_data, + filename=event.filename, + content_type=event.content_type, + is_inline=event.is_inline, + ) + ) + if event.is_inline: + upload_response = await upload_task + if not upload_response.inline_ref: + raise AttachmentUploadError( + "Attachment upload failed, no inline_ref returned." + ) + yield self.inline_attachment_event( + inline_ref=upload_response.inline_ref, + description=event.description or event.filename, + ) elif event.is_suggested_reply: yield self.suggested_reply_event(event.text) elif event.is_replace_response: @@ -317,13 +355,21 @@ async def handle_query( logger.exception("Error responding to query") yield self.error_event(repr(e), allow_retry=False) try: - await self._process_pending_attachment_requests(request.message_id) + await self._process_pending_attachment_requests(request) except Exception as e: logger.exception("Error processing pending attachment requests") yield self.error_event(repr(e), allow_retry=False) yield self.done_event() +ASCII_PUNCTUATION_CAPTURE_REGEX = re.compile(r"""([!"#$%&'()*+,\-.\/:;<=>?@\[\\\]^_`{|}~])""") + +def _markdown_escape(text: str) -> str: + return ASCII_PUNCTUATION_CAPTURE_REGEX.sub( + r"\\\1", text + ) + + def _find_access_key(*, access_key: str, api_key: str) -> Optional[str]: """Figures out the access key. diff --git a/src/fastapi_poe/types.py b/src/fastapi_poe/types.py index 4a89959..4e328bd 100644 --- a/src/fastapi_poe/types.py +++ b/src/fastapi_poe/types.py @@ -1,8 +1,14 @@ -from typing import Any, Dict, List, Optional +import asyncio +import httpx +import logging +from dataclasses import dataclass +from typing import Any, BinaryIO, Dict, List, Optional, Set, Union from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal, TypeAlias +logger = logging.getLogger("uvicorn.default") + Identifier: TypeAlias = str FeedbackType: TypeAlias = Literal["like", "dislike"] ContentType: TypeAlias = Literal["text/markdown", "text/plain"] @@ -22,6 +28,10 @@ class Attachment(BaseModel): name: str +class AttachmentUploadResponse(BaseModel): + inline_ref: Optional[str] + + class ProtocolMessage(BaseModel): """A message as used in the Poe protocol.""" @@ -34,6 +44,14 @@ class ProtocolMessage(BaseModel): attachments: List[Attachment] = Field(default_factory=list) +class InvalidParameterError(Exception): + pass + + +class AttachmentUploadError(Exception): + pass + + class BaseRequest(BaseModel): """Common data for all requests.""" @@ -56,6 +74,96 @@ class QueryRequest(BaseRequest): logit_bias: Dict[str, float] = {} stop_sequences: List[str] = [] + _pending_tasks: Set[asyncio.Task] = set() + _attachment_upload_url = "https://www.quora.com/poe_api/file_attachment_3RD_PARTY_POST" + + async def post_message_attachment( + self, + *, + download_url: Optional[str] = None, + file_data: Optional[Union[bytes, BinaryIO]] = None, + filename: Optional[str] = None, + content_type: Optional[str] = None, + is_inline: bool = False, + ) -> AttachmentUploadResponse: + task = asyncio.create_task( + self._make_file_attachment_request( + download_url=download_url, + file_data=file_data, + filename=filename, + content_type=content_type, + is_inline=is_inline, + ) + ) + self._pending_tasks.add(task) + try: + return await task + finally: + self._pending_tasks.remove(task) + + async def _make_file_attachment_request( + self, + *, + download_url: Optional[str] = None, + file_data: Optional[Union[bytes, BinaryIO]] = None, + filename: Optional[str] = None, + content_type: Optional[str] = None, + is_inline: bool = False, + ) -> AttachmentUploadResponse: + async with httpx.AsyncClient(timeout=120) as client: + try: + headers = {"Authorization": f"{self.access_key}"} + if download_url: + if file_data or filename: + raise InvalidParameterError( + "Cannot provide filename or file_data if download_url is provided." + ) + data = { + "message_id": self.message_id, + "is_inline": is_inline, + "download_url": download_url, + } + request = httpx.Request( + "POST", + self._attachment_upload_url, + data=data, + headers=headers + ) + elif file_data and filename: + data = {"message_id": self.message_id, "is_inline": is_inline} + files = { + "file": ( + (filename, file_data) + if content_type is None + else (filename, file_data, content_type) + ) + } + request = httpx.Request( + "POST", + self._attachment_upload_url, + files=files, + data=data, + headers=headers + ) + else: + raise InvalidParameterError( + "Must provide either download_url or file_data and filename." + ) + response = await client.send(request) + + if response.status_code != 200: + raise AttachmentUploadError( + f"{response.status_code}: {response.reason_phrase}" + ) + + return AttachmentUploadResponse( + inline_ref=response.json().get("inline_ref") + ) + + except httpx.HTTPError: + logger.error("An HTTP error occurred when attempting to attach file") + raise + class SettingsRequest(BaseRequest): """Request parameters for a settings request.""" @@ -87,10 +195,6 @@ class SettingsResponse(BaseModel): introduction_message: str = "" -class AttachmentUploadResponse(BaseModel): - inline_ref: Optional[str] - - class PartialResponse(BaseModel): """Representation of a (possibly partial) response from a bot.""" @@ -142,6 +246,22 @@ class MetaResponse(PartialResponse): refetch_settings: bool = False +@dataclass +class AttachFileResponse: + """Communicate attachment files from server bots.""" + + file_data: Union[bytes, BinaryIO] + filename: str + content_type: Optional[str] = None + is_inline: bool = False + description: Optional[str] = None + + +@dataclass +class ImageResponse(AttachFileResponse): + is_inline: bool = True + + class ToolDefinition(BaseModel): class FunctionDefinition(BaseModel): class ParametersDefinition(BaseModel):