diff --git a/bolna/agent_manager/task_manager.py b/bolna/agent_manager/task_manager.py index d53b5a4..07e37c2 100644 --- a/bolna/agent_manager/task_manager.py +++ b/bolna/agent_manager/task_manager.py @@ -510,10 +510,12 @@ def __setup_output_handlers(self, turn_based_conversation, output_queue): def __setup_input_handlers(self, turn_based_conversation, input_queue, should_record): if self.task_config["tools_config"]["input"]["provider"] in SUPPORTED_INPUT_HANDLERS.keys(): logger.info(f"Connected through dashboard {turn_based_conversation}") - input_kwargs = {"queues": self.queues, - "websocket": self.websocket, - "input_types": get_required_input_types(self.task_config), - "mark_set": self.mark_set} + input_kwargs = { + "queues": self.queues, + "websocket": self.websocket, + "input_types": get_required_input_types(self.task_config), + "mark_set": self.mark_set + } if self.task_config["tools_config"]["input"]["provider"] == "daily": input_kwargs['room_url'] = self.room_url @@ -581,7 +583,8 @@ def __setup_synthesizer(self, llm_config=None): self.task_config["tools_config"]["synthesizer"]["stream"] = True if self.enforce_streaming else False #Hardcode stream to be False as we don't want to get blocked by a __listen_synthesizer co-routine self.tools["synthesizer"] = synthesizer_class(**self.task_config["tools_config"]["synthesizer"], **provider_config, **self.kwargs, caching=caching) - self.synthesizer_monitor_task = asyncio.create_task(self.tools['synthesizer'].monitor_connection()) + if not self.turn_based_conversation: + self.synthesizer_monitor_task = asyncio.create_task(self.tools['synthesizer'].monitor_connection()) if self.task_config["tools_config"]["llm_agent"] is not None and llm_config is not None: llm_config["buffer_size"] = self.task_config["tools_config"]["synthesizer"].get('buffer_size') @@ -928,7 +931,6 @@ def __update_preprocessed_tree_node(self): # LLM task ############################################################## async def _handle_llm_output(self, next_step, text_chunk, should_bypass_synth, meta_info, is_filler = False): - logger.info("received text from LLM for output processing: {} which belongs to sequence id {}".format(text_chunk, meta_info['sequence_id'])) if "request_id" not in meta_info: meta_info["request_id"] = str(uuid.uuid4()) @@ -1192,7 +1194,7 @@ async def _process_conversation_task(self, message, sequence, meta_info): logger.info("agent flow is not preprocessed") start_time = time.time() - should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] == True + should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] is True next_step = self._get_next_step(sequence, "llm") meta_info['llm_start_time'] = time.time() route = None @@ -1926,7 +1928,15 @@ async def __first_message(self, timeout=10.0): text = self.kwargs.get('agent_welcome_message', None) logger.info(f"Generating {text}") meta_info = {'io': self.tools["output"].get_provider(), 'message_category': 'agent_welcome_message', 'stream_sid': stream_sid, "request_id": str(uuid.uuid4()), "cached": True, "sequence_id": -1, 'format': self.task_config["tools_config"]["output"]["format"], 'text': text} - await self._synthesize(create_ws_data_packet(text, meta_info=meta_info)) + if self.turn_based_conversation: + meta_info['type'] = 'text' + bos_packet = create_ws_data_packet("", meta_info) + await self.tools["output"].handle(bos_packet) + await self.tools["output"].handle(create_ws_data_packet(text, meta_info)) + eos_packet = create_ws_data_packet("", meta_info) + await self.tools["output"].handle(eos_packet) + else: + await self._synthesize(create_ws_data_packet(text, meta_info=meta_info)) break else: logger.info(f"Stream id is still None, so not passing it") @@ -1990,12 +2000,9 @@ async def run(self): logger.info(f"Starting the first message task {self.enforce_streaming}") self.output_task = asyncio.create_task(self.__process_output_loop()) + self.first_message_task = asyncio.create_task(self.__first_message()) if not self.turn_based_conversation or self.enforce_streaming: logger.info(f"Setting up other servers") - self.first_message_task = asyncio.create_task(self.__first_message()) - #if not self.use_llm_to_determine_hangup : - # By default we will hang up after x amount of silence - # We still need to self.hangup_task = asyncio.create_task(self.__check_for_completion()) if self.should_backchannel: diff --git a/bolna/input_handlers/default.py b/bolna/input_handlers/default.py index aca6982..ef982ba 100644 --- a/bolna/input_handlers/default.py +++ b/bolna/input_handlers/default.py @@ -1,6 +1,7 @@ import asyncio import base64 import time +import uuid from dotenv import load_dotenv from bolna.helpers.logger_config import configure_logger from bolna.helpers.utils import create_ws_data_packet @@ -28,6 +29,9 @@ async def stop_handler(self): except Exception as e: logger.error(f"Error closing WebSocket: {e}") + def get_stream_sid(self): + return str(uuid.uuid4()) + def __process_audio(self, audio): data = base64.b64decode(audio) ws_data_packet = create_ws_data_packet(