diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index 19cf7a8..89d1c3c 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -6,7 +6,18 @@ import os import sys import warnings -from typing import AsyncIterable, Awaitable, BinaryIO, Callable, Dict, Optional, Union +from collections import defaultdict +from dataclasses import dataclass +from typing import ( + AsyncIterable, + Awaitable, + BinaryIO, + Callable, + Dict, + Optional, + Sequence, + Union, +) import httpx from fastapi import Depends, FastAPI, HTTPException, Request, Response @@ -86,20 +97,11 @@ async def http_exception_handler(request: Request, ex: Exception) -> Response: http_bearer = HTTPBearer() -def auth_user( - authorization: HTTPAuthorizationCredentials = Depends(http_bearer), -) -> None: - if auth_key is None: - return - if authorization.scheme != "Bearer" or authorization.credentials != auth_key: - raise HTTPException( - status_code=401, - detail="Invalid access key", - headers={"WWW-Authenticate": "Bearer"}, - ) - - +@dataclass class PoeBot: + path: str = "/" # Path where this bot will be exposed + access_key: Optional[str] = None # Access key for this bot + # Override these for your bot async def get_response_with_context( @@ -143,7 +145,7 @@ async def on_error(self, error_request: ReportErrorRequest) -> None: logger.error(f"Error from Poe server: {error_request}") # Helpers for generating responses - def __init__(self) -> None: + def __post_init__(self) -> None: self._pending_file_attachment_tasks = {} async def post_message_attachment( @@ -415,23 +417,7 @@ def _verify_access_key( return _access_key -def make_app( - bot: PoeBot, - access_key: str = "", - *, - api_key: str = "", - allow_without_key: bool = False, -) -> FastAPI: - """Create an app object. Arguments are as for run().""" - app = FastAPI() - app.add_exception_handler(RequestValidationError, http_exception_handler) - - global auth_key - auth_key = _verify_access_key( - access_key=access_key, api_key=api_key, allow_without_key=allow_without_key - ) - - @app.get("/") +def _add_routes_for_bot(app: FastAPI, bot: PoeBot) -> None: async def index() -> Response: url = "https://poe.com/create_bot?server=1" return HTMLResponse( @@ -440,7 +426,21 @@ async def index() -> Response: f' href="{url}">{url}.

' ) - @app.post("/") + def auth_user( + authorization: HTTPAuthorizationCredentials = Depends(http_bearer), + ) -> None: + if bot.access_key is None: + return + if ( + authorization.scheme != "Bearer" + or authorization.credentials != bot.access_key + ): + raise HTTPException( + status_code=401, + detail="Invalid access key", + headers={"WWW-Authenticate": "Bearer"}, + ) + async def poe_post(request: Request, dict: object = Depends(auth_user)) -> Response: request_body = await request.json() request_body["http_request"] = request @@ -450,8 +450,8 @@ async def poe_post(request: Request, dict: object = Depends(auth_user)) -> Respo QueryRequest.parse_obj( { **request_body, - "access_key": auth_key or "", - "api_key": auth_key or "", + "access_key": bot.access_key or "", + "api_key": bot.access_key or "", } ), RequestContext(http_request=request), @@ -475,13 +475,67 @@ async def poe_post(request: Request, dict: object = Depends(auth_user)) -> Respo else: raise HTTPException(status_code=501, detail="Unsupported request type") + app.get(bot.path)(index) + app.post(bot.path)(poe_post) + + +def make_app( + bot: Union[PoeBot, Sequence[PoeBot]], + access_key: str = "", + *, + api_key: str = "", + allow_without_key: bool = False, +) -> FastAPI: + """Create an app object. Arguments are as for run().""" + app = FastAPI() + app.add_exception_handler(RequestValidationError, http_exception_handler) + + if isinstance(bot, PoeBot): + if bot.access_key is None: + bot.access_key = _verify_access_key( + access_key=access_key, + api_key=api_key, + allow_without_key=allow_without_key, + ) + elif access_key: + raise ValueError( + "Cannot provide access_key if the bot object already has an access key" + ) + elif api_key: + raise ValueError( + "Cannot provide api_key if the bot object already has an access key" + ) + bots = [bot] + else: + if access_key or api_key: + raise ValueError( + "When serving multiple bots, the access_key must be set on each bot" + ) + bots = bot + + # Ensure paths are unique + path_to_bots = defaultdict(list) + for bot in bots: + path_to_bots[bot.path].append(bot) + for path, bots_of_path in path_to_bots.items(): + if len(bots_of_path) > 1: + raise ValueError( + f"Multiple bots are trying to use the same path: {path}: {bots_of_path}. " + "Please use a different path for each bot." + ) + + for bot_obj in bots: + if bot_obj.access_key is None and not allow_without_key: + raise ValueError(f"Missing access key on {bot_obj}") + _add_routes_for_bot(app, bot_obj) + # Uncomment this line to print out request and response # app.add_middleware(LoggingMiddleware) return app def run( - bot: PoeBot, + bot: Union[PoeBot, Sequence[PoeBot]], access_key: str = "", *, api_key: str = "", @@ -490,10 +544,11 @@ def run( """ Run a Poe bot server using FastAPI. - :param bot: The bot object. + :param bot: The bot object or a list of bot objects. :param access_key: The access key to use. If not provided, the server tries to read the POE_ACCESS_KEY environment variable. If that is not set, the server will - refuse to start, unless *allow_without_key* is True. + refuse to start, unless *allow_without_key* is True. If multiple bots are provided, + the access key must be provided as part of the bot object. :param api_key: The previous name of access_key. This param is deprecated and will be removed in a future version :param allow_without_key: If True, the server will start even if no access key