Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

forcing initial message in non-stream ws #105

Merged
merged 3 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions bolna/agent_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,12 @@ def __setup_output_handlers(self, turn_based_conversation, output_queue):
def __setup_input_handlers(self, turn_based_conversation, input_queue, should_record):
if self.task_config["tools_config"]["input"]["provider"] in SUPPORTED_INPUT_HANDLERS.keys():
logger.info(f"Connected through dashboard {turn_based_conversation}")
input_kwargs = {"queues": self.queues,
"websocket": self.websocket,
"input_types": get_required_input_types(self.task_config),
"mark_set": self.mark_set}
input_kwargs = {
"queues": self.queues,
"websocket": self.websocket,
"input_types": get_required_input_types(self.task_config),
"mark_set": self.mark_set
}

if self.task_config["tools_config"]["input"]["provider"] == "daily":
input_kwargs['room_url'] = self.room_url
Expand Down Expand Up @@ -581,7 +583,8 @@ def __setup_synthesizer(self, llm_config=None):
self.task_config["tools_config"]["synthesizer"]["stream"] = True if self.enforce_streaming else False #Hardcode stream to be False as we don't want to get blocked by a __listen_synthesizer co-routine

self.tools["synthesizer"] = synthesizer_class(**self.task_config["tools_config"]["synthesizer"], **provider_config, **self.kwargs, caching=caching)
self.synthesizer_monitor_task = asyncio.create_task(self.tools['synthesizer'].monitor_connection())
if not self.turn_based_conversation:
self.synthesizer_monitor_task = asyncio.create_task(self.tools['synthesizer'].monitor_connection())
if self.task_config["tools_config"]["llm_agent"] is not None and llm_config is not None:
llm_config["buffer_size"] = self.task_config["tools_config"]["synthesizer"].get('buffer_size')

Expand Down Expand Up @@ -928,7 +931,6 @@ def __update_preprocessed_tree_node(self):
# LLM task
##############################################################
async def _handle_llm_output(self, next_step, text_chunk, should_bypass_synth, meta_info, is_filler = False):

logger.info("received text from LLM for output processing: {} which belongs to sequence id {}".format(text_chunk, meta_info['sequence_id']))
if "request_id" not in meta_info:
meta_info["request_id"] = str(uuid.uuid4())
Expand Down Expand Up @@ -1192,7 +1194,7 @@ async def _process_conversation_task(self, message, sequence, meta_info):
logger.info("agent flow is not preprocessed")

start_time = time.time()
should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] == True
should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] is True
next_step = self._get_next_step(sequence, "llm")
meta_info['llm_start_time'] = time.time()
route = None
Expand Down Expand Up @@ -1926,7 +1928,15 @@ async def __first_message(self, timeout=10.0):
text = self.kwargs.get('agent_welcome_message', None)
logger.info(f"Generating {text}")
meta_info = {'io': self.tools["output"].get_provider(), 'message_category': 'agent_welcome_message', 'stream_sid': stream_sid, "request_id": str(uuid.uuid4()), "cached": True, "sequence_id": -1, 'format': self.task_config["tools_config"]["output"]["format"], 'text': text}
await self._synthesize(create_ws_data_packet(text, meta_info=meta_info))
if self.turn_based_conversation:
meta_info['type'] = 'text'
bos_packet = create_ws_data_packet("<beginning_of_stream>", meta_info)
await self.tools["output"].handle(bos_packet)
await self.tools["output"].handle(create_ws_data_packet(text, meta_info))
eos_packet = create_ws_data_packet("<end_of_stream>", meta_info)
await self.tools["output"].handle(eos_packet)
else:
await self._synthesize(create_ws_data_packet(text, meta_info=meta_info))
break
else:
logger.info(f"Stream id is still None, so not passing it")
Expand Down Expand Up @@ -1990,12 +2000,9 @@ async def run(self):

logger.info(f"Starting the first message task {self.enforce_streaming}")
self.output_task = asyncio.create_task(self.__process_output_loop())
self.first_message_task = asyncio.create_task(self.__first_message())
if not self.turn_based_conversation or self.enforce_streaming:
logger.info(f"Setting up other servers")
self.first_message_task = asyncio.create_task(self.__first_message())
#if not self.use_llm_to_determine_hangup :
# By default we will hang up after x amount of silence
# We still need to
self.hangup_task = asyncio.create_task(self.__check_for_completion())

if self.should_backchannel:
Expand Down
4 changes: 4 additions & 0 deletions bolna/input_handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import base64
import time
import uuid
from dotenv import load_dotenv
from bolna.helpers.logger_config import configure_logger
from bolna.helpers.utils import create_ws_data_packet
Expand Down Expand Up @@ -28,6 +29,9 @@ async def stop_handler(self):
except Exception as e:
logger.error(f"Error closing WebSocket: {e}")

def get_stream_sid(self):
    """Return a newly generated random UUID (version 4) rendered as a string.

    Every call produces a fresh value; nothing is cached or stored on the
    instance by this method.
    """
    # NOTE(review): presumably this gives the default (websocket) input handler
    # a non-None stream id for callers that wait on one — confirm against callers.
    return str(uuid.uuid4())

def __process_audio(self, audio):
data = base64.b64decode(audio)
ws_data_packet = create_ws_data_packet(
Expand Down