diff --git a/core/just_agents/protocols/litellm_protocol.py b/core/just_agents/protocols/litellm_protocol.py index e384932..e5ee1e7 100644 --- a/core/just_agents/protocols/litellm_protocol.py +++ b/core/just_agents/protocols/litellm_protocol.py @@ -126,8 +126,23 @@ def tool_calls_from_message(self, message: MessageDict) -> List[LiteLLMFunctionC LiteLLMFunctionCall(**tool_call) for tool_call in tool_calls ] + @staticmethod + def reenumerate_tool_call_chunks(chunks : List[Any]): + tool_calls = [] + message = None + for chunk in chunks: + if ( + len(chunk["choices"]) > 0 + and "tool_calls" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["tool_calls"] + ): + message = stream_chunk_builder(chunks=[chunk,chunks[-1]]) + tool_calls.append(message.choices[0].message.tool_calls[0]) + message.choices[0].message.tool_calls = tool_calls + return message def response_from_deltas(self, chunks: List[Any]) -> ModelResponse: - return stream_chunk_builder(chunks) - #complete_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages) + if "llama" in chunks[-1]["model"] and chunks[-1].choices[0].finish_reason=="tool_calls": + return self.reenumerate_tool_call_chunks(chunks) # workaround: llama streams mis-index tool_call deltas, so stream_chunk_builder merges distinct calls; rebuild them one per chunk + return stream_chunk_builder(chunks=chunks) diff --git a/tests/test_stream.py b/tests/test_stream.py index f18579e..5880bda 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -3,14 +3,25 @@ import pytest from typing import Callable, Any from just_agents.base_agent import BaseAgent -from just_agents.llm_options import LLMOptions, LLAMA3_3, OPENAI_GPT4oMINI +from just_agents.llm_options import LLMOptions, LLAMA3_3, LLAMA3_2_VISION, OPENAI_GPT4oMINI from just_agents.just_tool import JustToolsBus @pytest.fixture(scope="module", autouse=True) def load_env(): load_dotenv(override=True) def get_current_weather(location: str): - """Gets the current weather in a given location""" + """ + Gets the current weather in a given location + + Args: + location (str): The name 
of the location for which to get the weather. + + Returns: + str: A JSON-encoded string with the following keys: + - "location" (str): The location name. + - "temperature" (str): The temperature value, or "unknown" if not recognized. + - "unit" (str, optional): The unit of measurement for temperature (e.g., "celsius", "fahrenheit"). + """ if "tokyo" in location.lower(): return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"}) elif "san francisco" in location.lower(): @@ -90,4 +101,7 @@ def test_query_tool(): def test_stream_tool(): validate_tool_call(agent_call, OPENAI_GPT4oMINI, False) +def test_stream_tool_groq(): + validate_tool_call(agent_call, LLAMA3_3, False) +