Skip to content

Commit

Permalink
feat: Add audio subscription control to HumanInput
Browse files Browse the repository at this point in the history
Add a new `should_subscribe_to_audio` parameter to control when the agent subscribes to participant audio tracks. This can be either a boolean or a callback function that takes a RemoteTrackPublication and returns a boolean, allowing for fine-grained control over audio track subscriptions.
  • Loading branch information
Mohamed Boussaid committed Jan 24, 2025
1 parent 695f7b5 commit b724571
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
11 changes: 10 additions & 1 deletion livekit-agents/livekit/agents/pipeline/human_input.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Literal
from typing import Literal, Callable

from livekit import rtc

Expand All @@ -28,6 +28,7 @@ def __init__(
stt: speech_to_text.STT,
participant: rtc.RemoteParticipant,
transcription: bool,
should_subscribe_to_audio: Callable[[rtc.RemoteTrackPublication], bool] | bool,
) -> None:
super().__init__()
self._room, self._vad, self._stt, self._participant, self._transcription = (
Expand All @@ -47,6 +48,7 @@ def __init__(
self._room.on("track_published", self._subscribe_to_microphone)
self._room.on("track_subscribed", self._subscribe_to_microphone)
self._subscribe_to_microphone()
self._should_subscribe_to_audio = should_subscribe_to_audio

async def aclose(self) -> None:
if self._closed:
Expand Down Expand Up @@ -77,6 +79,13 @@ def _subscribe_to_microphone(self, *args, **kwargs) -> None:
if publication.source != rtc.TrackSource.SOURCE_MICROPHONE:
continue

if self._should_subscribe_to_audio is not None:
if callable(self._should_subscribe_to_audio):
if not self._should_subscribe_to_audio(publication):
continue
elif not self._should_subscribe_to_audio:
continue

if not publication.subscribed:
publication.set_subscribed(True)

Expand Down
5 changes: 4 additions & 1 deletion livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class _ImplOptions:
before_tts_cb: BeforeTTSCallback
plotting: bool
transcription: AgentTranscriptionOptions

should_subscribe_to_audio: Callable[[rtc.RemoteTrackPublication], bool] | bool

@dataclass(frozen=True)
class AgentTranscriptionOptions:
Expand Down Expand Up @@ -201,6 +201,7 @@ def __init__(
loop: asyncio.AbstractEventLoop | None = None,
# backward compatibility
will_synthesize_assistant_reply: WillSynthesizeAssistantReply | None = None,
should_subscribe_to_audio: Callable[[rtc.RemoteTrackPublication], bool] | bool
) -> None:
"""
Create a new VoicePipelineAgent.
Expand Down Expand Up @@ -255,6 +256,7 @@ def __init__(
transcription=transcription,
before_llm_cb=before_llm_cb,
before_tts_cb=before_tts_cb,
should_subscribe_to_audio=should_subscribe_to_audio,
)
self._plotter = AssistantPlotter(self._loop)

Expand Down Expand Up @@ -557,6 +559,7 @@ def _link_participant(self, identity: str) -> None:
stt=self._stt,
participant=participant,
transcription=self._opts.transcription.user_transcription,
should_subscribe_to_audio=self._opts.should_subscribe_to_audio,
)

def _on_start_of_speech(ev: vad.VADEvent) -> None:
Expand Down

0 comments on commit b724571

Please sign in to comment.