Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: preserve extra fields in text event responses #127

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/fastapi_poe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Any,
AsyncIterable,
Awaitable,
BinaryIO,
Expand Down Expand Up @@ -297,7 +298,8 @@
filename: Optional[str] = None,
content_type: Optional[str] = None,
is_inline: bool = False,
) -> AttachmentUploadResponse: ...
) -> AttachmentUploadResponse:
...

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol lint


# This overload requires all parameters to be passed as keywords
@overload
Expand All @@ -310,7 +312,8 @@
filename: Optional[str] = None,
content_type: Optional[str] = None,
is_inline: bool = False,
) -> AttachmentUploadResponse: ...
) -> AttachmentUploadResponse:
...

async def post_message_attachment(
self,
Expand Down Expand Up @@ -612,8 +615,14 @@
return new_messages

@staticmethod
def text_event(text: str) -> ServerSentEvent:
return ServerSentEvent(data=json.dumps({"text": text}), event="text")
def text_event(
text: str, data: Optional[Dict[str, Any]] = dict()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make this data: Optional[MappingProxyType] = None since we don't want default args and if we're passing this dict around, we prob want it to be read-only? (have ppl copy it like {**data} if they need to create a mutable copy)

) -> ServerSentEvent:
return ServerSentEvent(
text=text,

Check failure on line 622 in src/fastapi_poe/base.py

View workflow job for this annotation

GitHub Actions / pyright

No parameter named "text" (reportCallIssue)
data=data,
event="text",
)

@staticmethod
def replace_response_event(text: str) -> ServerSentEvent:
Expand Down Expand Up @@ -724,7 +733,7 @@
elif event.is_replace_response:
yield self.replace_response_event(event.text)
else:
yield self.text_event(event.text)
yield self.text_event(event.text, event.data)
except Exception as e:
logger.exception("Error responding to query")
yield self.error_event(
Expand Down
19 changes: 18 additions & 1 deletion src/fastapi_poe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,25 @@
)
return
elif event.event == "text":
text = await self._get_single_json_field(
data_dict = await self._load_json_dict(
event.data, "text", message_id
)
text = data_dict.get("text")
if not isinstance(text, str):
await self.report_error(
"Expected string in 'text' field for 'text' event",
{"data": data_dict, "message_id": message_id},
)
raise BotErrorNoRetry("Expected string in 'text' event")
chunks.append(text)
yield BotMessage(
text=text,
raw_response={"type": event.event, "text": event.data},
full_prompt=repr(request),
data=data_dict,
is_replace_response=(event.event == "replace_response"),
)
continue
elif event.event == "replace_response":
text = await self._get_single_json_field(
event.data, "replace_response", message_id
Expand Down Expand Up @@ -249,6 +265,7 @@
text=text,
raw_response={"type": event.event, "text": event.data},
full_prompt=repr(request),
data=event.data,

Check failure on line 268 in src/fastapi_poe/client.py

View workflow job for this annotation

GitHub Actions / pyright

Argument of type "str" cannot be assigned to parameter "data" of type "Dict[str, Any] | None" in function "__init__"   Type "str" is not assignable to type "Dict[str, Any] | None"     "str" is not assignable to "Dict[str, Any]"     "str" is not assignable to "None" (reportArgumentType)
is_replace_response=(event.event == "replace_response"),
)
await self.report_error(
Expand Down
223 changes: 223 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
from typing import Any, AsyncGenerator, Optional, Union
import json
import pytest
from unittest.mock import AsyncMock, patch
from fastapi_poe.client import (
_BotContext,
BotMessage,
MetaMessage,
BotError,
BotErrorNoRetry,
)
from fastapi_poe.types import QueryRequest


class MockSSEEvent:
def __init__(self, event_type: str, data: Union[dict[str, Any], str]) -> None:

Check failure on line 16 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / pyright

Subscript for class "dict" will generate runtime exception; enclose type expression in quotes (reportIndexIssue)
self.event: str = event_type
self.data: str = json.dumps(data) if isinstance(data, dict) else data


class MockEventSource:
def __init__(self, events: list[MockSSEEvent]) -> None:

Check failure on line 22 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / pyright

Subscript for class "list" will generate runtime exception; enclose type expression in quotes (reportIndexIssue)
self.events = events

async def aiter_sse(self) -> AsyncGenerator[MockSSEEvent, None]:
for event in self.events:
yield event

async def __aenter__(self) -> "MockEventSource":
return self

async def __aexit__(self, *args: Any) -> None:
pass


class TestPerformQueryRequest:
@pytest.fixture
def mock_session(self) -> AsyncMock:
return AsyncMock()

@pytest.fixture
def bot_context(self, mock_session: AsyncMock) -> _BotContext:
return _BotContext(
endpoint="https://test.com/bot", session=mock_session, api_key="test_key"
)

@pytest.fixture
def base_request(self) -> QueryRequest:
return QueryRequest(
query=[],
user_id="test_user",
conversation_id="test_conv",
message_id="test_message",
version="1.0",
type="query",
)

async def _run_query_request(
self, events: list[MockSSEEvent], context: _BotContext, request: QueryRequest

Check failure on line 59 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / pyright

Subscript for class "list" will generate runtime exception; enclose type expression in quotes (reportIndexIssue)
) -> list[BotMessage]:

Check failure on line 60 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / pyright

Subscript for class "list" will generate runtime exception; enclose type expression in quotes (reportIndexIssue)
"""Helper method to run query request and collect messages"""
messages: list[BotMessage] = []
with patch("httpx_sse.aconnect_sse", return_value=MockEventSource(events)):
async for msg in context.perform_query_request(
request=request, tools=None, tool_calls=None, tool_results=None
):
messages.append(msg)
return messages

@pytest.mark.asyncio
async def test_text_event_with_extra_fields(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent("text", {"text": "Hello", "extra_field": "extra_value"}),
MockSSEEvent("done", {}),
]

messages = await self._run_query_request(events, bot_context, base_request)

assert len(messages) == 1
assert messages[0].text == "Hello"
assert isinstance(messages[0].raw_response, dict)
assert "text" in messages[0].raw_response
assert (
json.loads(messages[0].raw_response["text"]).get("extra_field")
== "extra_value"
)

@pytest.mark.asyncio
async def test_replace_response_event(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent("text", {"text": "First"}),
MockSSEEvent("replace_response", {"text": "Replaced"}),
MockSSEEvent("done", {}),
]

messages = await self._run_query_request(events, bot_context, base_request)

assert len(messages) == 2
assert messages[1].is_replace_response == True
assert messages[1].text == "Replaced"

@pytest.mark.asyncio
async def test_meta_event(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent(
"meta",
{
"linkify": True,
"suggested_replies": True,
"content_type": "text/markdown",
},
),
MockSSEEvent("done", {}),
]

messages = await self._run_query_request(events, bot_context, base_request)

assert len(messages) == 1
assert isinstance(messages[0], MetaMessage)
assert messages[0].linkify == True
assert messages[0].suggested_replies == True
assert messages[0].content_type == "text/markdown"

@pytest.mark.asyncio
async def test_error_event(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent("error", {"message": "Test error", "allow_retry": False}),
]
with pytest.raises(BotErrorNoRetry):
await self._run_query_request(events, bot_context, base_request)

@pytest.mark.asyncio
async def test_invalid_text_event(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent("text", {"text": None}), # Invalid text field
]
with pytest.raises(BotErrorNoRetry):
await self._run_query_request(events, bot_context, base_request)

@pytest.mark.asyncio
async def test_suggested_reply_event(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent(
"text",
{"text": "Suggestion", "data": {"extra_field": "extra_value"}},
),
MockSSEEvent("done", {}),
]

messages = await self._run_query_request(events, bot_context, base_request)

assert len(messages) == 1
assert messages[0].is_suggested_reply == True
assert messages[0].text == "Suggestion"

async def test_multiple_text_events(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [
MockSSEEvent(
"text",
{
"text": "First",
"extra_field": "extra_value",
"waiting_time_us": 1000,
},
),
MockSSEEvent(
"text",
{
"text": "Second",
"extra_field": "extra_value",
"waiting_time_us": 2000,
},
),
MockSSEEvent("done", {}),
]

messages = await self._run_query_request(events, bot_context, base_request)

# Check number of messages
assert len(messages) == 2

# Check first message
assert messages[0].text == "First"
assert isinstance(messages[0].raw_response, dict)
assert "text" in messages[0].raw_response
parsed_data = json.loads(messages[0].raw_response["text"])
assert parsed_data.get("extra_field") == "extra_value"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's have different values for these between the messages

assert parsed_data.get("waiting_time_us") == 1000
assert messages[0].is_replace_response == False
assert messages[0].is_suggested_reply == False
# Check second message
assert messages[1].text == "Second"
assert isinstance(messages[1].raw_response, dict)
assert "text" in messages[1].raw_response
parsed_data = json.loads(messages[1].raw_response["text"])
assert parsed_data.get("extra_field") == "extra_value"
assert parsed_data.get("waiting_time_us") == 2000
assert messages[1].is_replace_response == False
assert messages[1].is_suggested_reply == False

@pytest.mark.asyncio
async def test_unknown_event_type(
self, bot_context: _BotContext, base_request: QueryRequest
) -> None:
events = [MockSSEEvent("unknown", {"data": "test"}), MockSSEEvent("done", {})]

messages = await self._run_query_request(events, bot_context, base_request)

assert len(messages) == 0 # Unknown event should be ignored
Loading