diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000..b4ed6e5 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index 52a1ac4..d1ab536 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -31,6 +31,10 @@ logger = logging.getLogger("uvicorn.default") +class InvalidParameterError(Exception): + pass + + class LoggingMiddleware(BaseHTTPMiddleware): async def set_body(self, request: Request): receive_ = await request._receive() @@ -64,7 +68,7 @@ async def dispatch(self, request: Request, call_next): return response -def exception_handler(request: Request, ex: HTTPException): +async def http_exception_handler(request, ex): logger.error(ex) @@ -86,7 +90,6 @@ def auth_user( class PoeBot: # Override these for your bot - async def get_response( self, request: QueryRequest ) -> AsyncIterable[Union[PartialResponse, ServerSentEvent]]: @@ -111,11 +114,11 @@ def __init__(self): self.file_attachment_lock = asyncio.Lock() async def post_message_attachment( - self, message_id, file_data, filename, access_key + self, access_key, message_id, download_url=None, file_data=None, filename=None ): task = asyncio.create_task( self._make_file_attachment_request( - message_id, file_data, filename, access_key + access_key, message_id, download_url, file_data, filename ) ) async with self.file_attachment_lock: @@ -124,19 +127,30 @@ async def post_message_attachment( self.pending_file_attachments[message_id] = files_for_message async def _make_file_attachment_request( - self, message_id, file_data, filename, access_key + self, access_key, message_id, download_url=None, file_data=None, filename=None ): url = "https://www.quora.com/poe_api/file_attachment_POST" async with httpx.AsyncClient(timeout=120) as client: try: - files = {"file": (filename, file_data)} - data = {"message_id": message_id} headers = {"Authorization": f"{access_key}"} - - request = httpx.Request( - "POST", url, files=files, data=data, headers=headers - ) + if download_url: + if file_data or filename: + raise InvalidParameterError( + "Cannot provide filename or file_data if download_url is provided." + ) + data = {"message_id": message_id, "download_url": download_url} + request = httpx.Request("POST", url, data=data, headers=headers) + elif file_data and filename: + data = {"message_id": message_id} + files = {"file": (filename, file_data)} + request = httpx.Request( + "POST", url, files=files, data=data, headers=headers + ) + else: + raise InvalidParameterError( + "Must provide either download_url or file_data and filename." + ) response = await client.send(request) if response.status_code != 200: @@ -147,15 +161,18 @@ async def _make_file_attachment_request( return response except httpx.HTTPError: logger.error("An HTTP error occurred when attempting to attach file") + raise async def _process_pending_attachment_requests(self, message_id): try: await asyncio.gather(*self.pending_file_attachments.get(message_id, [])) except Exception: logger.error("Error processing pending attachment requests") - # clear the pending attachments for the message - async with self.file_attachment_lock: - self.pending_file_attachments.pop(message_id, None) + raise + finally: + # clear the pending attachments for the message + async with self.file_attachment_lock: + self.pending_file_attachments.pop(message_id, None) @staticmethod def text_event(text: str) -> ServerSentEvent: @@ -256,7 +273,11 @@ async def handle_query( except Exception as e: logger.exception("Error responding to query") yield self.error_event(repr(e), allow_retry=False) - await self._process_pending_attachment_requests(request.message_id) + try: + await self._process_pending_attachment_requests(request.message_id) + except Exception as e: + logger.exception("Error processing pending attachment requests") + yield self.error_event(repr(e), allow_retry=False) yield self.done_event() @@ -327,7 +348,7 @@ def make_app( ) -> FastAPI: """Create an app object. Arguments are as for run().""" app = FastAPI() - app.add_exception_handler(RequestValidationError, exception_handler) + app.add_exception_handler(RequestValidationError, http_exception_handler) global auth_key auth_key = _verify_access_key(