-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: preserve extra fields in text event responses #127
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from typing import ( | ||
Any, | ||
AsyncIterable, | ||
Awaitable, | ||
BinaryIO, | ||
|
@@ -297,7 +298,8 @@ | |
filename: Optional[str] = None, | ||
content_type: Optional[str] = None, | ||
is_inline: bool = False, | ||
) -> AttachmentUploadResponse: ... | ||
) -> AttachmentUploadResponse: | ||
... | ||
|
||
# This overload requires all parameters to be passed as keywords | ||
@overload | ||
|
@@ -310,7 +312,8 @@ | |
filename: Optional[str] = None, | ||
content_type: Optional[str] = None, | ||
is_inline: bool = False, | ||
) -> AttachmentUploadResponse: ... | ||
) -> AttachmentUploadResponse: | ||
... | ||
|
||
async def post_message_attachment( | ||
self, | ||
|
@@ -612,8 +615,14 @@ | |
return new_messages | ||
|
||
@staticmethod | ||
def text_event(text: str) -> ServerSentEvent: | ||
return ServerSentEvent(data=json.dumps({"text": text}), event="text") | ||
def text_event( | ||
text: str, data: Optional[Dict[str, Any]] = dict() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's make this |
||
) -> ServerSentEvent: | ||
return ServerSentEvent( | ||
text=text, | ||
data=data, | ||
event="text", | ||
) | ||
|
||
@staticmethod | ||
def replace_response_event(text: str) -> ServerSentEvent: | ||
|
@@ -724,7 +733,7 @@ | |
elif event.is_replace_response: | ||
yield self.replace_response_event(event.text) | ||
else: | ||
yield self.text_event(event.text) | ||
yield self.text_event(event.text, event.data) | ||
except Exception as e: | ||
logger.exception("Error responding to query") | ||
yield self.error_event( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
from typing import Any, AsyncGenerator, Optional, Union | ||
import json | ||
import pytest | ||
from unittest.mock import AsyncMock, patch | ||
from fastapi_poe.client import ( | ||
_BotContext, | ||
BotMessage, | ||
MetaMessage, | ||
BotError, | ||
BotErrorNoRetry, | ||
) | ||
from fastapi_poe.types import QueryRequest | ||
|
||
|
||
class MockSSEEvent: | ||
def __init__(self, event_type: str, data: Union[dict[str, Any], str]) -> None: | ||
self.event: str = event_type | ||
self.data: str = json.dumps(data) if isinstance(data, dict) else data | ||
|
||
|
||
class MockEventSource: | ||
def __init__(self, events: list[MockSSEEvent]) -> None: | ||
self.events = events | ||
|
||
async def aiter_sse(self) -> AsyncGenerator[MockSSEEvent, None]: | ||
for event in self.events: | ||
yield event | ||
|
||
async def __aenter__(self) -> "MockEventSource": | ||
return self | ||
|
||
async def __aexit__(self, *args: Any) -> None: | ||
pass | ||
|
||
|
||
class TestPerformQueryRequest: | ||
@pytest.fixture | ||
def mock_session(self) -> AsyncMock: | ||
return AsyncMock() | ||
|
||
@pytest.fixture | ||
def bot_context(self, mock_session: AsyncMock) -> _BotContext: | ||
return _BotContext( | ||
endpoint="https://test.com/bot", session=mock_session, api_key="test_key" | ||
) | ||
|
||
@pytest.fixture | ||
def base_request(self) -> QueryRequest: | ||
return QueryRequest( | ||
query=[], | ||
user_id="test_user", | ||
conversation_id="test_conv", | ||
message_id="test_message", | ||
version="1.0", | ||
type="query", | ||
) | ||
|
||
async def _run_query_request( | ||
self, events: list[MockSSEEvent], context: _BotContext, request: QueryRequest | ||
) -> list[BotMessage]: | ||
"""Helper method to run query request and collect messages""" | ||
messages: list[BotMessage] = [] | ||
with patch("httpx_sse.aconnect_sse", return_value=MockEventSource(events)): | ||
async for msg in context.perform_query_request( | ||
request=request, tools=None, tool_calls=None, tool_results=None | ||
): | ||
messages.append(msg) | ||
return messages | ||
|
||
@pytest.mark.asyncio | ||
async def test_text_event_with_extra_fields( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent("text", {"text": "Hello", "extra_field": "extra_value"}), | ||
MockSSEEvent("done", {}), | ||
] | ||
|
||
messages = await self._run_query_request(events, bot_context, base_request) | ||
|
||
assert len(messages) == 1 | ||
assert messages[0].text == "Hello" | ||
assert isinstance(messages[0].raw_response, dict) | ||
assert "text" in messages[0].raw_response | ||
assert ( | ||
json.loads(messages[0].raw_response["text"]).get("extra_field") | ||
== "extra_value" | ||
) | ||
|
||
@pytest.mark.asyncio | ||
async def test_replace_response_event( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent("text", {"text": "First"}), | ||
MockSSEEvent("replace_response", {"text": "Replaced"}), | ||
MockSSEEvent("done", {}), | ||
] | ||
|
||
messages = await self._run_query_request(events, bot_context, base_request) | ||
|
||
assert len(messages) == 2 | ||
assert messages[1].is_replace_response == True | ||
assert messages[1].text == "Replaced" | ||
|
||
@pytest.mark.asyncio | ||
async def test_meta_event( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent( | ||
"meta", | ||
{ | ||
"linkify": True, | ||
"suggested_replies": True, | ||
"content_type": "text/markdown", | ||
}, | ||
), | ||
MockSSEEvent("done", {}), | ||
] | ||
|
||
messages = await self._run_query_request(events, bot_context, base_request) | ||
|
||
assert len(messages) == 1 | ||
assert isinstance(messages[0], MetaMessage) | ||
assert messages[0].linkify == True | ||
assert messages[0].suggested_replies == True | ||
assert messages[0].content_type == "text/markdown" | ||
|
||
@pytest.mark.asyncio | ||
async def test_error_event( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent("error", {"message": "Test error", "allow_retry": False}), | ||
] | ||
with pytest.raises(BotErrorNoRetry): | ||
await self._run_query_request(events, bot_context, base_request) | ||
|
||
@pytest.mark.asyncio | ||
async def test_invalid_text_event( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent("text", {"text": None}), # Invalid text field | ||
] | ||
with pytest.raises(BotErrorNoRetry): | ||
await self._run_query_request(events, bot_context, base_request) | ||
|
||
@pytest.mark.asyncio | ||
async def test_suggested_reply_event( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent( | ||
"text", | ||
{"text": "Suggestion", "data": {"extra_field": "extra_value"}}, | ||
), | ||
MockSSEEvent("done", {}), | ||
] | ||
|
||
messages = await self._run_query_request(events, bot_context, base_request) | ||
|
||
assert len(messages) == 1 | ||
assert messages[0].is_suggested_reply == True | ||
assert messages[0].text == "Suggestion" | ||
|
||
async def test_multiple_text_events( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [ | ||
MockSSEEvent( | ||
"text", | ||
{ | ||
"text": "First", | ||
"extra_field": "extra_value", | ||
"waiting_time_us": 1000, | ||
}, | ||
), | ||
MockSSEEvent( | ||
"text", | ||
{ | ||
"text": "Second", | ||
"extra_field": "extra_value", | ||
"waiting_time_us": 2000, | ||
}, | ||
), | ||
MockSSEEvent("done", {}), | ||
] | ||
|
||
messages = await self._run_query_request(events, bot_context, base_request) | ||
|
||
# Check number of messages | ||
assert len(messages) == 2 | ||
|
||
# Check first message | ||
assert messages[0].text == "First" | ||
assert isinstance(messages[0].raw_response, dict) | ||
assert "text" in messages[0].raw_response | ||
parsed_data = json.loads(messages[0].raw_response["text"]) | ||
assert parsed_data.get("extra_field") == "extra_value" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's have different values for these between the messages |
||
assert parsed_data.get("waiting_time_us") == 1000 | ||
assert messages[0].is_replace_response == False | ||
assert messages[0].is_suggested_reply == False | ||
# Check second message | ||
assert messages[1].text == "Second" | ||
assert isinstance(messages[1].raw_response, dict) | ||
assert "text" in messages[1].raw_response | ||
parsed_data = json.loads(messages[1].raw_response["text"]) | ||
assert parsed_data.get("extra_field") == "extra_value" | ||
assert parsed_data.get("waiting_time_us") == 2000 | ||
assert messages[1].is_replace_response == False | ||
assert messages[1].is_suggested_reply == False | ||
|
||
@pytest.mark.asyncio | ||
async def test_unknown_event_type( | ||
self, bot_context: _BotContext, base_request: QueryRequest | ||
) -> None: | ||
events = [MockSSEEvent("unknown", {"data": "test"}), MockSSEEvent("done", {})] | ||
|
||
messages = await self._run_query_request(events, bot_context, base_request) | ||
|
||
assert len(messages) == 0 # Unknown event should be ignored |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lol lint