From 9f9f2ad9cd5ad471c3612100f3ba392846f5f63c Mon Sep 17 00:00:00 2001 From: nitin Date: Fri, 22 Nov 2024 13:51:36 +0530 Subject: [PATCH] fix: tools support issue --- realtime_agent/agent.py | 24 ++++++++++++- realtime_agent/main.py | 3 ++ realtime_agent/realtime/tools_example.py | 44 ++++++++++++++++++++++++ realtime_agent/tools.py | 16 ++++----- 4 files changed, 76 insertions(+), 11 deletions(-) create mode 100644 realtime_agent/realtime/tools_example.py diff --git a/realtime_agent/agent.py b/realtime_agent/agent.py index 992c991..88eb895 100644 --- a/realtime_agent/agent.py +++ b/realtime_agent/agent.py @@ -11,7 +11,7 @@ from agora_realtime_ai_api.rtc import Channel, ChatMessage, RtcEngine, RtcOptions from .logger import setup_logger -from .realtime.struct import InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreated, ResponseDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json +from .realtime.struct import FunctionCallOutputItemParam, InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreate, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreate, ResponseCreated, ResponseDone, ResponseFunctionCallArgumentsDelta, ResponseFunctionCallArgumentsDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json from .realtime.connection import RealtimeApiConnection from .tools import ClientToolCallResponse, ToolContext from .utils import PCMWriter @@ -240,6 +240,21 @@ async def model_to_rtc(self) -> None: await pcm_writer.flush() raise # Re-raise the cancelled exception to properly exit the task + async def handle_funtion_call(self, message: ResponseFunctionCallArgumentsDone) -> None: + function_call_response = await self.tools.execute_tool(message.name, message.arguments) + logger.info(f"Function call response: {function_call_response}") + await self.connection.send_request( + ItemCreate( + item = FunctionCallOutputItemParam( + call_id=message.call_id, + output=function_call_response.json_encoded_output + ) + ) + ) + await self.connection.send_request( + ResponseCreate() + ) + async def _process_model_messages(self) -> None: async for message in self.connection.listen(): # logger.info(f"Received message {message=}") @@ -312,5 +327,12 @@ async def _process_model_messages(self) -> None: pass case RateLimitsUpdated(): pass + case ResponseFunctionCallArgumentsDone(): + asyncio.create_task( + self.handle_funtion_call(message) + ) + case ResponseFunctionCallArgumentsDelta(): + pass + case _: logger.warning(f"Unhandled message {message=}") diff --git a/realtime_agent/main.py b/realtime_agent/main.py index 484b1de..4d2ce00 100644 --- a/realtime_agent/main.py +++ b/realtime_agent/main.py @@ -9,6 +9,8 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field, ValidationError +from realtime_agent.realtime.tools_example import AgentTools + from .realtime.struct import PCM_CHANNELS, PCM_SAMPLE_RATE, ServerVADUpdateParams, Voices from .agent import InferenceConfig, RealtimeKitAgent @@ -82,6 +84,7 @@ def run_agent_in_process( ), inference_config=inference_config, tools=None, + # tools=AgentTools() # tools example, replace with this line ) ) diff --git a/realtime_agent/realtime/tools_example.py b/realtime_agent/realtime/tools_example.py new file mode 100644 index 0000000..9c9192a --- /dev/null +++ b/realtime_agent/realtime/tools_example.py @@ -0,0 +1,44 @@ + +from typing import Any +from realtime_agent.tools import ToolContext + +# Function calling Example +# This is an example of how to add a new function to the agent tools. + +class AgentTools(ToolContext): + def __init__(self) -> None: + super().__init__() + + # create multiple functions here as per requirement + self.register_function( + name="get_avg_temp", + description="Returns average temperature of a country", + parameters={ + "type": "object", + "properties": { + "country": { + "type": "string", + "description": "Name of country", + }, + }, + "required": ["country"], + }, + fn=self._get_avg_temperature_by_country_name, + ) + + async def _get_avg_temperature_by_country_name( + self, + country: str, + ) -> dict[str, Any]: + try: + result = "24 degree C" # Dummy data (Get the Required value here, like a DB call or API call) + return { + "status": "success", + "message": f"Average temperature of {country} is {result}", + "result": result, + } + except Exception as e: + return { + "status": "error", + "message": f"Failed to get : {str(e)}", + } \ No newline at end of file diff --git a/realtime_agent/tools.py b/realtime_agent/tools.py index b83f230..4f56b5d 100644 --- a/realtime_agent/tools.py +++ b/realtime_agent/tools.py @@ -24,11 +24,9 @@ class LocalFunctionToolDeclaration: def model_description(self) -> dict[str, Any]: return { "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, + "name": self.name, + "description": self.description, + "parameters": self.parameters, } @@ -43,11 +41,9 @@ class PassThroughFunctionToolDeclaration: def model_description(self) -> dict[str, Any]: return { "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, + "name": self.name, + "description": self.description, + "parameters": self.parameters, }