Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tools support issue #11

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion realtime_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from agora_realtime_ai_api.rtc import Channel, ChatMessage, RtcEngine, RtcOptions

from .logger import setup_logger
from .realtime.struct import InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreated, ResponseDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json
from .realtime.struct import FunctionCallOutputItemParam, InputAudioBufferCommitted, InputAudioBufferSpeechStarted, InputAudioBufferSpeechStopped, InputAudioTranscription, ItemCreate, ItemCreated, ItemInputAudioTranscriptionCompleted, RateLimitsUpdated, ResponseAudioDelta, ResponseAudioDone, ResponseAudioTranscriptDelta, ResponseAudioTranscriptDone, ResponseContentPartAdded, ResponseContentPartDone, ResponseCreate, ResponseCreated, ResponseDone, ResponseFunctionCallArgumentsDelta, ResponseFunctionCallArgumentsDone, ResponseOutputItemAdded, ResponseOutputItemDone, ServerVADUpdateParams, SessionUpdate, SessionUpdateParams, SessionUpdated, Voices, to_json
from .realtime.connection import RealtimeApiConnection
from .tools import ClientToolCallResponse, ToolContext
from .utils import PCMWriter
Expand Down Expand Up @@ -240,6 +240,21 @@ async def model_to_rtc(self) -> None:
await pcm_writer.flush()
raise # Re-raise the cancelled exception to properly exit the task

async def handle_funtion_call(self, message: ResponseFunctionCallArgumentsDone) -> None:
    """Run the tool requested by the model and send its output back.

    Executes the tool named by ``message.name`` with ``message.arguments``,
    sends the result to the model as a function-call output item tied to
    ``message.call_id``, then requests a new response from the model.

    This coroutine is started with ``asyncio.create_task`` and the task is
    not awaited, so an exception raised here would never be observed by the
    caller. Failures are therefore logged here instead of propagating.

    Args:
        message: The "function call arguments done" event received from the
            model connection.
    """
    # The agent can be constructed with ``tools=None``; without this guard
    # the attribute access below would raise inside the background task.
    if self.tools is None:
        logger.warning(
            f"Ignoring function call {message.name!r}: no tools are configured"
        )
        return

    try:
        function_call_response = await self.tools.execute_tool(
            message.name, message.arguments
        )
        logger.info(f"Function call response: {function_call_response}")

        # Attach the tool output to the conversation under the same call id.
        await self.connection.send_request(
            ItemCreate(
                item=FunctionCallOutputItemParam(
                    call_id=message.call_id,
                    output=function_call_response.json_encoded_output,
                )
            )
        )
        # Ask the model to continue now that the tool output is available.
        await self.connection.send_request(ResponseCreate())
    except Exception:
        # Task boundary: ``asyncio.CancelledError`` is not an ``Exception``
        # subclass, so cancellation still propagates normally.
        logger.exception(
            f"Function call {message.name!r} (call_id={message.call_id}) failed"
        )

async def _process_model_messages(self) -> None:
async for message in self.connection.listen():
# logger.info(f"Received message {message=}")
Expand Down Expand Up @@ -312,5 +327,12 @@ async def _process_model_messages(self) -> None:
pass
case RateLimitsUpdated():
pass
case ResponseFunctionCallArgumentsDone():
asyncio.create_task(
self.handle_funtion_call(message)
)
case ResponseFunctionCallArgumentsDelta():
pass

case _:
logger.warning(f"Unhandled message {message=}")
3 changes: 3 additions & 0 deletions realtime_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError

from realtime_agent.realtime.tools_example import AgentTools

from .realtime.struct import PCM_CHANNELS, PCM_SAMPLE_RATE, ServerVADUpdateParams, Voices

from .agent import InferenceConfig, RealtimeKitAgent
Expand Down Expand Up @@ -82,6 +84,7 @@ def run_agent_in_process(
),
inference_config=inference_config,
tools=None,
# tools=AgentTools()  # example: use this instead of tools=None to enable the sample tools
)
)

Expand Down
44 changes: 44 additions & 0 deletions realtime_agent/realtime/tools_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

from typing import Any
from realtime_agent.tools import ToolContext

# Function calling Example
# This is an example of how to add a new function to the agent tools.

class AgentTools(ToolContext):
    """Example tool context that registers one sample function.

    Shows how to expose a local coroutine to the agent: declare a name, a
    description and a JSON-schema for the arguments, then pass the coroutine
    as ``fn`` to ``register_function``.
    """

    def __init__(self) -> None:
        super().__init__()

        # JSON-schema for the arguments of ``get_avg_temp``.
        country_schema = {
            "type": "object",
            "properties": {
                "country": {
                    "type": "string",
                    "description": "Name of country",
                },
            },
            "required": ["country"],
        }

        # Register further functions here as needed, one call per function.
        self.register_function(
            name="get_avg_temp",
            description="Returns average temperature of a country",
            parameters=country_schema,
            fn=self._get_avg_temperature_by_country_name,
        )

    async def _get_avg_temperature_by_country_name(
        self,
        country: str,
    ) -> dict[str, Any]:
        """Return a status dict with the average temperature for ``country``.

        The temperature is a hard-coded placeholder. On any exception an
        error dict with ``status`` and ``message`` is returned instead.
        """
        try:
            # Placeholder value; fetch the real figure here (DB or API call).
            result = "24 degree C"
            response = {
                "status": "success",
                "message": f"Average temperature of {country} is {result}",
                "result": result,
            }
            return response
        except Exception as e:
            failure = {
                "status": "error",
                "message": f"Failed to get : {str(e)}",
            }
            return failure
16 changes: 6 additions & 10 deletions realtime_agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ class LocalFunctionToolDeclaration:
def model_description(self) -> dict[str, Any]:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
},
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}


Expand All @@ -43,11 +41,9 @@ class PassThroughFunctionToolDeclaration:
def model_description(self) -> dict[str, Any]:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
},
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}


Expand Down