From b724571748f2b453c7c2f259e2cae5e5e82c6790 Mon Sep 17 00:00:00 2001 From: Mohamed Boussaid Date: Fri, 24 Jan 2025 14:24:43 +0100 Subject: [PATCH] feat: Add audio subscription control to HumanInput Add a new `should_subscribe_to_audio` parameter to control when the agent subscribes to participant audio tracks. This can be either a boolean or a callback function that takes a RemoteTrackPublication and returns a boolean, allowing for fine-grained control over audio track subscriptions. --- livekit-agents/livekit/agents/pipeline/human_input.py | 11 ++++++++++- .../livekit/agents/pipeline/pipeline_agent.py | 5 ++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/human_input.py b/livekit-agents/livekit/agents/pipeline/human_input.py index b54ba6f28..aec721975 100644 --- a/livekit-agents/livekit/agents/pipeline/human_input.py +++ b/livekit-agents/livekit/agents/pipeline/human_input.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import Literal +from typing import Literal, Callable from livekit import rtc @@ -28,6 +28,7 @@ def __init__( stt: speech_to_text.STT, participant: rtc.RemoteParticipant, transcription: bool, + should_subscribe_to_audio: Callable[[rtc.RemoteTrackPublication], bool] | bool, ) -> None: super().__init__() self._room, self._vad, self._stt, self._participant, self._transcription = ( @@ -47,6 +48,7 @@ def __init__( self._room.on("track_published", self._subscribe_to_microphone) self._room.on("track_subscribed", self._subscribe_to_microphone) self._subscribe_to_microphone() + self._should_subscribe_to_audio = should_subscribe_to_audio async def aclose(self) -> None: if self._closed: @@ -77,6 +79,13 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None: if publication.source != rtc.TrackSource.SOURCE_MICROPHONE: continue + if self._should_subscribe_to_audio is not None: + if callable(self._should_subscribe_to_audio): + if not self._should_subscribe_to_audio(publication): + continue + elif not self._should_subscribe_to_audio: + continue + if not publication.subscribed: publication.set_subscribed(True) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index a9742ff92..3bd5c630f 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -137,7 +137,7 @@ class _ImplOptions: before_tts_cb: BeforeTTSCallback plotting: bool transcription: AgentTranscriptionOptions - + should_subscribe_to_audio: Callable[[rtc.RemoteTrackPublication], bool] | bool @dataclass(frozen=True) class AgentTranscriptionOptions: @@ -201,6 +201,7 @@ def __init__( loop: asyncio.AbstractEventLoop | None = None, # backward compatibility will_synthesize_assistant_reply: WillSynthesizeAssistantReply | None = None, + should_subscribe_to_audio: Callable[[rtc.RemoteTrackPublication], bool] | bool ) -> None: """ Create a new VoicePipelineAgent. @@ -255,6 +256,7 @@ def __init__( transcription=transcription, before_llm_cb=before_llm_cb, before_tts_cb=before_tts_cb, + should_subscribe_to_audio=should_subscribe_to_audio, ) self._plotter = AssistantPlotter(self._loop) @@ -557,6 +559,7 @@ def _link_participant(self, identity: str) -> None: stt=self._stt, participant=participant, transcription=self._opts.transcription.user_transcription, + should_subscribe_to_audio=self._opts.should_subscribe_to_audio, ) def _on_start_of_speech(ev: vad.VADEvent) -> None: