Python: Yield FunctionResultContent in streaming chat completion path. Update tests. #9974

Merged: 13 commits, Dec 18, 2024
Changes from 3 commits
@@ -12,7 +12,10 @@

from semantic_kernel.connectors.ai.function_call_behavior import FunctionCallBehavior
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.function_calling_utils import merge_function_results
from semantic_kernel.connectors.ai.function_calling_utils import (
merge_function_results,
merge_streaming_function_results,
)
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior, FunctionChoiceType
from semantic_kernel.const import AUTO_FUNCTION_INVOCATION_SPAN_NAME
from semantic_kernel.contents.annotation_content import AnnotationContent
@@ -303,8 +306,18 @@ async def get_streaming_chat_message_contents(
],
)

# Merge and yield the function results, regardless of the termination status
# Include the ai_model_id so we can later add two streaming messages together
# Some settings may not have an ai_model_id, so we need to check for it
ai_model_id = self._get_ai_model_id(settings)
function_result_messages = merge_streaming_function_results(
messages=chat_history.messages[-len(results) :],
ai_model_id=ai_model_id, # type: ignore
)
if self._yield_function_result_messages(function_result_messages):
yield function_result_messages

if any(result.terminate for result in results if result is not None):
yield merge_function_results(chat_history.messages[-len(results) :]) # type: ignore
break

async def get_streaming_chat_message_content(
@@ -415,4 +428,12 @@ def _start_auto_function_invocation_activity(self, kernel: "Kernel", settings: "

return span

def _get_ai_model_id(self, settings: "PromptExecutionSettings") -> str:
"""Retrieve the AI model ID from settings if available."""
return getattr(settings, "ai_model_id", "")

def _yield_function_result_messages(self, function_result_messages: list) -> bool:
"""Determine if the function result messages should be yielded."""
return len(function_result_messages) > 0 and len(function_result_messages[0].items) > 0

# endregion
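
Taken together, the base-class changes mean a streaming consumer now receives the merged function results as their own yielded chunk instead of having them swallowed by the auto-invoke loop. A minimal consumer sketch, assuming a configured chat completion service, a kernel with auto function invocation enabled, and a populated chat history (all setup names here are hypothetical):

from semantic_kernel.contents.function_result_content import FunctionResultContent

async def consume_stream(service, chat_history, settings, kernel):
    """Hypothetical consumer; `service` is any ChatCompletionClientBase subclass."""
    async for messages in service.get_streaming_chat_message_contents(
        chat_history, settings, kernel=kernel
    ):
        for message in messages:
            for item in message.items:
                if isinstance(item, FunctionResultContent):
                    # New with this PR: function results surface mid-stream.
                    print(f"result for call {item.id}: {item.result}")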
28 changes: 28 additions & 0 deletions python/semantic_kernel/connectors/ai/function_calling_utils.py
@@ -5,6 +5,7 @@

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError

@@ -95,3 +96,30 @@ def merge_function_results(
items=items,
)
]


def merge_streaming_function_results(
messages: list[ChatMessageContent | StreamingChatMessageContent],
ai_model_id: str,
) -> list[StreamingChatMessageContent]:
"""Combine multiple streaming function result content types to one streaming chat message content type.

This method combines the FunctionResultContent items from separate StreamingChatMessageContent messages,
and is used in the event that the `context.terminate = True` condition is met.

Args:
messages: The list of streaming chat message content types.
ai_model_id: The AI model ID.

Returns:
The combined streaming chat message content type.
"""
items: list[Any] = []
for message in messages:
items.extend([item for item in message.items if isinstance(item, FunctionResultContent)])

# If we want to be able to support adding the streaming message chunks together, then the author role needs to be
# `Assistant`, as the `Tool` role will cause the add method to break.
return [
StreamingChatMessageContent(role=AuthorRole.ASSISTANT, items=items, choice_index=0, ai_model_id=ai_model_id)
]
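
A quick sketch of the merge behavior, assuming two tool-role chunks each carrying one function result (the call IDs, results, and model ID are hypothetical):

from semantic_kernel.connectors.ai.function_calling_utils import merge_streaming_function_results
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole

chunk_a = StreamingChatMessageContent(
    role=AuthorRole.TOOL,
    choice_index=0,
    items=[FunctionResultContent(id="call_1", result="72F")],
)
chunk_b = StreamingChatMessageContent(
    role=AuthorRole.TOOL,
    choice_index=0,
    items=[FunctionResultContent(id="call_2", result="sunny")],
)

merged = merge_streaming_function_results([chunk_a, chunk_b], ai_model_id="gpt-4o")
assert len(merged) == 1                        # one combined message
assert len(merged[0].items) == 2               # both results preserved
assert merged[0].role == AuthorRole.ASSISTANT  # so chunks can be added together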
3 changes: 2 additions & 1 deletion python/semantic_kernel/contents/function_call_content.py
@@ -149,7 +149,8 @@ def parse_arguments(self) -> Mapping[str, Any] | None:
if isinstance(self.arguments, Mapping):
return self.arguments
try:
return json.loads(self.arguments)
sanitized_arguments = self.arguments.replace("'", '"')
return json.loads(sanitized_arguments)
except json.JSONDecodeError as exc:
raise FunctionCallInvalidArgumentsException("Function Call arguments are not valid JSON.") from exc

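The quote sanitization is a best-effort recovery for models that emit Python-style dict strings instead of JSON; a small sketch of what the replace does and does not handle:

import json

# Python-style arguments, like the updated mocks below, now parse:
raw = "{'arg1': 'test_value'}"
print(json.loads(raw.replace("'", '"')))  # {'arg1': 'test_value'}

# Caveat: the blanket replace also rewrites apostrophes inside values, so
# already-valid JSON such as {"note": "it's fine"} would stop parsing and
# fall through to FunctionCallInvalidArgumentsException.
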
7 changes: 7 additions & 0 deletions python/semantic_kernel/contents/function_result_content.py
@@ -17,6 +17,7 @@
if TYPE_CHECKING:
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.functions.function_result import FunctionResult

TAG_CONTENT_MAP = {
@@ -157,6 +158,12 @@ def to_chat_message_content(self) -> "ChatMessageContent":

return ChatMessageContent(role=AuthorRole.TOOL, items=[self])

def to_streaming_chat_message_content(self) -> "StreamingChatMessageContent":
"""Convert the instance to a StreamingChatMessageContent."""
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent

return StreamingChatMessageContent(role=AuthorRole.TOOL, choice_index=0, items=[self])

def to_dict(self) -> dict[str, str]:
"""Convert the instance to a dictionary."""
return {
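The new converter mirrors `to_chat_message_content`, differing only in the streaming type and the required `choice_index`. A small round-trip sketch with hypothetical values:

from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.utils.author_role import AuthorRole

frc = FunctionResultContent(id="call_1", result="42")
streaming_message = frc.to_streaming_chat_message_content()

assert streaming_message.role == AuthorRole.TOOL
assert streaming_message.choice_index == 0
assert streaming_message.items == [frc]
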
@@ -6,7 +6,7 @@
from html import unescape
from typing import TYPE_CHECKING, Any

import yaml
import yaml # type: ignore
from pydantic import Field, ValidationError, model_validator

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
8 changes: 7 additions & 1 deletion python/semantic_kernel/kernel.py
@@ -10,6 +10,7 @@
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.streaming_content_mixin import StreamingContentMixin
from semantic_kernel.exceptions import (
FunctionCallInvalidArgumentsException,
@@ -398,7 +399,12 @@ async def invoke_function_call(
frc = FunctionResultContent.from_function_call_content_and_result(
function_call_content=function_call, result=invocation_context.function_result
)
chat_history.add_message(message=frc.to_chat_message_content())

is_streaming = any(isinstance(message, StreamingChatMessageContent) for message in chat_history.messages)

message = frc.to_streaming_chat_message_content() if is_streaming else frc.to_chat_message_content()

chat_history.add_message(message=message)

return invocation_context if invocation_context.terminate else None

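The streaming/non-streaming split rests on a single scan of the history: if any prior message is a StreamingChatMessageContent, the function result is stored in its streaming form so that later chunk addition works. The check in isolation, as a sketch:

from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent

def history_is_streaming(chat_history: ChatHistory) -> bool:
    # Mirrors the check added above: one streaming message anywhere in the
    # history means the caller is on the streaming path.
    return any(
        isinstance(message, StreamingChatMessageContent)
        for message in chat_history.messages
    )

This is also why the updated test_kernel.py test below seeds its mocked history with a StreamingChatMessageContent.
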
@@ -252,8 +252,8 @@ def mock_azure_ai_inference_streaming_chat_completion_response_with_tool_call(mo
ChatCompletionsToolCall(
id="test_id",
function=FunctionCall(
name="test_function",
arguments={"test_arg": "test_value"},
name="getLightStatus",
arguments={"arg1": "test_value"},
),
),
],
@@ -20,6 +20,7 @@
ServiceInvalidExecutionSettingsError,
)
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel
from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT


@@ -492,11 +493,12 @@ async def test_azure_ai_inference_streaming_chat_completion_with_function_choice
async def test_azure_ai_inference_streaming_chat_completion_with_function_choice_behavior(
mock_complete,
azure_ai_inference_service,
kernel,
kernel: Kernel,
chat_history: ChatHistory,
mock_azure_ai_inference_streaming_chat_completion_response_with_tool_call,
decorated_native_function,
) -> None:
"""Test streaming completion of AzureAIInferenceChatCompletion with function choice behavior"""
"""Test streaming completion of AzureAIInferenceChatCompletion with function choice behavior."""
user_message_content: str = "Hello"
chat_history.add_user_message(user_message_content)

@@ -507,20 +509,31 @@ async def test_azure_ai_inference_streaming_chat_completion_with_function_choice

mock_complete.return_value = mock_azure_ai_inference_streaming_chat_completion_response_with_tool_call

kernel.add_function(plugin_name="TestPlugin", function=decorated_native_function)

all_messages = []
async for messages in azure_ai_inference_service.get_streaming_chat_message_contents(
chat_history,
settings,
kernel=kernel,
arguments=KernelArguments(),
):
assert len(messages) == 1
assert messages[0].role == "assistant"
assert messages[0].content == ""
assert messages[0].finish_reason == FinishReason.TOOL_CALLS
all_messages.extend(messages)

# Assert the number of total messages
assert len(all_messages) == 2, f"Expected 2 messages, got {len(all_messages)}"

# Validate the first message
assert all_messages[0].role == "assistant", f"Unexpected role for first message: {all_messages[0].role}"
assert all_messages[0].content == "", f"Unexpected content for first message: {all_messages[0].content}"
assert all_messages[0].finish_reason == FinishReason.TOOL_CALLS, (
f"Unexpected finish reason for first message: {all_messages[0].finish_reason}"
)

# Streaming completion with tool call does not invoke the model
# after maximum_auto_invoke_attempts is reached
assert mock_complete.call_count == 1
# Validate the second message
assert all_messages[1].role == "assistant", f"Unexpected role for second message: {all_messages[1].role}"
assert all_messages[1].content == "", f"Unexpected content for second message: {all_messages[1].content}"
assert all_messages[1].finish_reason is None


@pytest.mark.parametrize(
@@ -128,8 +128,8 @@ async def mock_google_ai_streaming_chat_completion_response_with_tool_call() ->
parts=[
protos.Part(
function_call=protos.FunctionCall(
name="test_function",
args={"test_arg": "test_value"},
name="getLightStatus",
args={"arg1": "test_value"},
)
)
],
@@ -20,6 +20,7 @@
ServiceInitializationError,
ServiceInvalidExecutionSettingsError,
)
from semantic_kernel.kernel import Kernel


# region init
@@ -259,9 +260,10 @@ async def test_google_ai_streaming_chat_completion_with_function_choice_behavior
async def test_google_ai_streaming_chat_completion_with_function_choice_behavior(
mock_google_ai_model_generate_content_async,
google_ai_unit_test_env,
kernel,
kernel: Kernel,
chat_history: ChatHistory,
mock_google_ai_streaming_chat_completion_response_with_tool_call,
decorated_native_function,
) -> None:
"""Test streaming chat completion of GoogleAIChatCompletion with function choice behavior"""
mock_google_ai_model_generate_content_async.return_value = (
@@ -275,20 +277,29 @@ async def test_google_ai_streaming_chat_completion_with_function_choice_behavior

google_ai_chat_completion = GoogleAIChatCompletion()

kernel.add_function(plugin_name="TestPlugin", function=decorated_native_function)

all_messages = []
async for messages in google_ai_chat_completion.get_streaming_chat_message_contents(
chat_history,
settings,
kernel=kernel,
):
assert len(messages) == 1
assert messages[0].role == "assistant"
assert messages[0].content == ""
# Google returns STOP rather than TOOL_CALLS as the finish reason for tool calls
assert messages[0].finish_reason == FinishReason.STOP
all_messages.extend(messages)

assert len(all_messages) == 2, f"Expected 2 messages, got {len(all_messages)}"

# Validate the first message
assert all_messages[0].role == "assistant", f"Unexpected role for first message: {all_messages[0].role}"
assert all_messages[0].content == "", f"Unexpected content for first message: {all_messages[0].content}"
assert all_messages[0].finish_reason == FinishReason.STOP, (
f"Unexpected finish reason for first message: {all_messages[0].finish_reason}"
)

# Streaming completion with tool call does not invoke the model
# after maximum_auto_invoke_attempts is reached
assert mock_google_ai_model_generate_content_async.call_count == 1
# Validate the second message
assert all_messages[1].role == "assistant", f"Unexpected role for second message: {all_messages[1].role}"
assert all_messages[1].content == "", f"Unexpected content for second message: {all_messages[1].content}"
assert all_messages[1].finish_reason is None


@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock)
@@ -117,8 +117,8 @@ def mock_vertex_ai_streaming_chat_completion_response_with_tool_call() -> AsyncI
parts=[
Part(
function_call=FunctionCall(
name="test_function",
args={"test_arg": "test_value"},
name="getLightStatus",
args={"arg1": "test_value"},
)
)
],
@@ -20,6 +20,7 @@
ServiceInitializationError,
ServiceInvalidExecutionSettingsError,
)
from semantic_kernel.kernel import Kernel


# region init
@@ -259,9 +260,10 @@ async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior
async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior(
mock_vertex_ai_model_generate_content_async,
vertex_ai_unit_test_env,
kernel,
kernel: Kernel,
chat_history: ChatHistory,
mock_vertex_ai_streaming_chat_completion_response_with_tool_call,
decorated_native_function,
) -> None:
"""Test streaming chat completion of VertexAIChatCompletion with function choice behavior"""
mock_vertex_ai_model_generate_content_async.return_value = (
@@ -275,20 +277,29 @@ async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior

vertex_ai_chat_completion = VertexAIChatCompletion()

kernel.add_function(plugin_name="TestPlugin", function=decorated_native_function)

all_messages = []
async for messages in vertex_ai_chat_completion.get_streaming_chat_message_contents(
chat_history,
settings,
kernel=kernel,
):
assert len(messages) == 1
assert messages[0].role == "assistant"
assert messages[0].content == ""
# Google returns STOP rather than TOOL_CALLS as the finish reason for tool calls
assert messages[0].finish_reason == FinishReason.STOP
all_messages.extend(messages)

assert len(all_messages) == 2, f"Expected 2 messages, got {len(all_messages)}"

# Validate the first message
assert all_messages[0].role == "assistant", f"Unexpected role for first message: {all_messages[0].role}"
assert all_messages[0].content == "", f"Unexpected content for first message: {all_messages[0].content}"
assert all_messages[0].finish_reason == FinishReason.STOP, (
f"Unexpected finish reason for first message: {all_messages[0].finish_reason}"
)

# Streaming completion with tool call does not invoke the model
# after maximum_auto_invoke_attempts is reached
assert mock_vertex_ai_model_generate_content_async.call_count == 1
# Validate the second message
assert all_messages[1].role == "assistant", f"Unexpected role for second message: {all_messages[1].role}"
assert all_messages[1].content == "", f"Unexpected content for second message: {all_messages[1].content}"
assert all_messages[1].finish_reason is None


@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock)
2 changes: 2 additions & 0 deletions python/tests/unit/kernel/test_kernel.py
@@ -18,6 +18,7 @@
from semantic_kernel.contents import ChatMessageContent
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.exceptions import KernelFunctionAlreadyExistsError, KernelServiceNotFoundError
from semantic_kernel.exceptions.content_exceptions import FunctionCallInvalidArgumentsException
from semantic_kernel.exceptions.kernel_exceptions import (
@@ -299,6 +300,7 @@ async def test_invoke_function_call_throws_during_invoke(kernel: Kernel, get_too
result_mock = MagicMock(spec=ChatMessageContent)
result_mock.items = [tool_call_mock]
chat_history_mock = MagicMock(spec=ChatHistory)
chat_history_mock.messages = [MagicMock(spec=StreamingChatMessageContent)]

func_mock = AsyncMock(spec=KernelFunction)
func_meta = KernelFunctionMetadata(name="function", is_prompt=False)