Skip to content

Commit

Permalink
Litellm bug mitigation for tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
winternewt committed Jan 8, 2025
1 parent e431395 commit 54782fd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
19 changes: 17 additions & 2 deletions core/just_agents/protocols/litellm_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,23 @@ def tool_calls_from_message(self, message: MessageDict) -> List[LiteLLMFunctionC
LiteLLMFunctionCall(**tool_call)
for tool_call in tool_calls
]
@staticmethod
def reenumerate_tool_call_chunks(chunks : List[Any]):
tool_calls = []
message = None
for chunk in chunks:
if (
len(chunk["choices"]) > 0
and "tool_calls" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["tool_calls"]
):
message = stream_chunk_builder(chunks=[chunk,chunks[-1]])
tool_calls.append(message.choices[0].message.tool_calls[0])
message.choices[0].message.tool_calls = tool_calls
return message

def response_from_deltas(self, chunks: List[Any]) -> ModelResponse:
return stream_chunk_builder(chunks)
#complete_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
if "llama" in chunks[-1]["model"] and chunks[-1].choices[0].finish_reason=="tool_calls":
return self.reenumerate_tool_call_chunks(chunks) # bug fix
return stream_chunk_builder(chunks=chunks)

18 changes: 16 additions & 2 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@
import pytest
from typing import Callable, Any
from just_agents.base_agent import BaseAgent
from just_agents.llm_options import LLMOptions, LLAMA3_3, OPENAI_GPT4oMINI
from just_agents.llm_options import LLMOptions, LLAMA3_3, LLAMA3_2_VISION, OPENAI_GPT4oMINI
from just_agents.just_tool import JustToolsBus
@pytest.fixture(scope="module", autouse=True)
def load_env():
load_dotenv(override=True)

def get_current_weather(location: str):
"""Gets the current weather in a given location"""
"""
Gets the current weather in a given location
Args:
location (str): The name of the location for which to get the weather.
Returns:
str: A JSON-encoded string with the following keys:
- "location" (str): The location name.
- "temperature" (str): The temperature value, or "unknown" if not recognized.
- "unit" (str, optional): The unit of measurement for temperature (e.g., "celsius", "fahrenheit").
"""
if "tokyo" in location.lower():
return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
elif "san francisco" in location.lower():
Expand Down Expand Up @@ -90,4 +101,7 @@ def test_query_tool():
def test_stream_tool():
validate_tool_call(agent_call, OPENAI_GPT4oMINI, False)

def test_stream_tool_grok():
validate_tool_call(agent_call, LLAMA3_3, False)


0 comments on commit 54782fd

Please sign in to comment.