diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index 01a8a53..ced6269 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -31,6 +31,10 @@ logger = logging.getLogger("uvicorn.default") +class InvalidParameterError(Exception): + pass + + class LoggingMiddleware(BaseHTTPMiddleware): async def set_body(self, request: Request): receive_ = await request._receive() @@ -126,32 +130,28 @@ async def post_message_attachment( async def _make_file_attachment_request( self, access_key, message_id, download_url=None, file_data=None, filename=None ): - assert download_url or ( - file_data and filename - ), "Must provide either download_url or file_data and filename." - - if file_data: - assert filename, "Must provide filename if providing file_data." - - if download_url: - assert ( - not file_data and not filename - ), "Cannot provide file_data or filename if providing download_url." - url = "https://www.quora.com/poe_api/file_attachment_POST" async with httpx.AsyncClient(timeout=120) as client: try: headers = {"Authorization": f"{access_key}"} if download_url: + if file_data or filename: + raise InvalidParameterError( + "Cannot provide filename or file_data if download_url is provided." + ) data = {"message_id": message_id, "download_url": download_url} request = httpx.Request("POST", url, data=data, headers=headers) - else: + elif file_data and filename: data = {"message_id": message_id} files = {"file": (filename, file_data)} request = httpx.Request( "POST", url, files=files, data=data, headers=headers ) + else: + raise InvalidParameterError( + "Must provide either download_url or file_data and filename." + ) response = await client.send(request) if response.status_code != 200: @@ -162,15 +162,18 @@ async def _make_file_attachment_request( return response except httpx.HTTPError: logger.error("An HTTP error occurred when attempting to attach file") + raise async def _process_pending_attachment_requests(self, message_id): try: await asyncio.gather(*self.pending_file_attachments.get(message_id, [])) except Exception: logger.error("Error processing pending attachment requests") - # clear the pending attachments for the message - async with self.file_attachment_lock: - self.pending_file_attachments.pop(message_id, None) + raise + finally: + # clear the pending attachments for the message + async with self.file_attachment_lock: + self.pending_file_attachments.pop(message_id, None) @staticmethod def text_event(text: str) -> ServerSentEvent: @@ -271,7 +274,11 @@ async def handle_query( except Exception as e: logger.exception("Error responding to query") yield self.error_event(repr(e), allow_retry=False) - await self._process_pending_attachment_requests(request.message_id) + try: + await self._process_pending_attachment_requests(request.message_id) + except Exception as e: + logger.exception("Error processing pending attachment requests") + yield self.error_event(repr(e), allow_retry=False) yield self.done_event()