Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

different approach for passing http context information #59

Merged
merged 4 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "fastapi_poe"
version = "0.0.30"
version = "0.0.31"
authors = [
{ name="Lida Li", email="[email protected]" },
{ name="Jelle Zijlstra", email="[email protected]" },
Expand Down
2 changes: 2 additions & 0 deletions src/fastapi_poe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"ErrorResponse",
"MetaResponse",
"ToolDefinition",
"RequestContext",
]

from .base import PoeBot, make_app, run
Expand All @@ -37,6 +38,7 @@
QueryRequest,
ReportErrorRequest,
ReportFeedbackRequest,
RequestContext,
SettingsRequest,
SettingsResponse,
ToolDefinition,
Expand Down
56 changes: 44 additions & 12 deletions src/fastapi_poe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
QueryRequest,
ReportErrorRequest,
ReportFeedbackRequest,
RequestContext,
SettingsRequest,
SettingsResponse,
)
Expand Down Expand Up @@ -96,20 +97,43 @@ def auth_user(

class PoeBot:
# Override these for your bot

async def get_response_with_context(
self, request: QueryRequest, context: RequestContext
) -> AsyncIterable[Union[PartialResponse, ServerSentEvent]]:
async for event in self.get_response(request):
yield event

async def get_response(
self, request: QueryRequest
) -> AsyncIterable[Union[PartialResponse, ServerSentEvent]]:
"""Override this to return a response to user queries."""
yield self.text_event("hello")

async def get_settings_with_context(
self, setting: SettingsRequest, context: RequestContext
) -> SettingsResponse:
settings = await self.get_settings(setting)
return settings

async def get_settings(self, setting: SettingsRequest) -> SettingsResponse:
"""Override this to return non-standard settings."""
return SettingsResponse()

async def on_feedback_with_context(
self, feedback_request: ReportFeedbackRequest, context: RequestContext
) -> None:
await self.on_feedback(feedback_request)

async def on_feedback(self, feedback_request: ReportFeedbackRequest) -> None:
"""Override this to record feedback from the user."""
pass

async def on_error_with_context(
self, error_request: ReportErrorRequest, context: RequestContext
) -> None:
await self.on_error(error_request)

async def on_error(self, error_request: ReportErrorRequest) -> None:
"""Override this to record errors from the Poe server."""
logger.error(f"Error from Poe server: {error_request}")
Expand Down Expand Up @@ -272,26 +296,28 @@ def error_event(
# Internal handlers

async def handle_report_feedback(
self, feedback_request: ReportFeedbackRequest
self, feedback_request: ReportFeedbackRequest, context: RequestContext
) -> JSONResponse:
await self.on_feedback(feedback_request)
await self.on_feedback_with_context(feedback_request, context)
return JSONResponse({})

async def handle_report_error(
self, error_request: ReportErrorRequest
self, error_request: ReportErrorRequest, context: RequestContext
) -> JSONResponse:
await self.on_error(error_request)
await self.on_error_with_context(error_request, context)
return JSONResponse({})

async def handle_settings(self, settings_request: SettingsRequest) -> JSONResponse:
settings = await self.get_settings(settings_request)
async def handle_settings(
self, settings_request: SettingsRequest, context: RequestContext
) -> JSONResponse:
settings = await self.get_settings_with_context(settings_request, context)
return JSONResponse(settings.dict())

async def handle_query(
self, request: QueryRequest
self, request: QueryRequest, context: RequestContext
) -> AsyncIterable[ServerSentEvent]:
try:
async for event in self.get_response(request):
async for event in self.get_response_with_context(request, context):
if isinstance(event, ServerSentEvent):
yield event
elif isinstance(event, ErrorResponse):
Expand Down Expand Up @@ -420,18 +446,24 @@ async def poe_post(request: Request, dict=Depends(auth_user)) -> Response:
"access_key": auth_key or "<missing>",
"api_key": auth_key or "<missing>",
}
)
),
RequestContext(http_request=request),
)
)
elif request_body["type"] == "settings":
return await bot.handle_settings(SettingsRequest.parse_obj(request_body))
return await bot.handle_settings(
SettingsRequest.parse_obj(request_body),
RequestContext(http_request=request),
)
elif request_body["type"] == "report_feedback":
return await bot.handle_report_feedback(
ReportFeedbackRequest.parse_obj(request_body)
ReportFeedbackRequest.parse_obj(request_body),
RequestContext(http_request=request),
)
elif request_body["type"] == "report_error":
return await bot.handle_report_error(
ReportErrorRequest.parse_obj(request_body)
ReportErrorRequest.parse_obj(request_body),
RequestContext(http_request=request),
)
else:
raise HTTPException(status_code=501, detail="Unsupported request type")
Expand Down
4 changes: 1 addition & 3 deletions src/fastapi_poe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ async def perform_query_request(
message_id = request.message_id
event_count = 0
error_reported = False
# http_request is not JSON-serializable and isn't relevant for sending
# onwards
payload = request.model_dump(exclude={"http_request"})
payload = request.model_dump()
payload["tools"] = [tool.model_dump() for tool in tools]
payload["tool_calls"] = [tool_call.model_dump() for tool_call in tool_calls]
payload["tool_results"] = [
Expand Down
11 changes: 7 additions & 4 deletions src/fastapi_poe/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ class ProtocolMessage(BaseModel):
attachments: List[Attachment] = Field(default_factory=list)


class BaseRequest(BaseModel):
"""Common data for all requests."""

class RequestContext(BaseModel):
class Config:
arbitrary_types_allowed = True

http_request: Request


class BaseRequest(BaseModel):
"""Common data for all requests."""

version: str
type: Literal["query", "settings", "report_feedback", "report_error"]
http_request: Optional[Request] = None


class QueryRequest(BaseRequest):
Expand Down
Loading