Skip to content

Commit

Permalink
feat: Assistant doesn't interrupt itself when speaking
Browse files Browse the repository at this point in the history
Integrating echo cancellation.
  • Loading branch information
clemlesne committed Dec 11, 2024
1 parent 9e24660 commit 416afd5
Show file tree
Hide file tree
Showing 8 changed files with 436 additions and 137 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ Conversation options are represented as features. They can be configured from Ap
| `slow_llm_for_chat` | Whether to use the slow LLM for chat. | `bool` | false |
| `vad_cutoff_timeout_ms` | The cutoff timeout for voice activity detection in secs. | `int` | 600 |
| `vad_silence_timeout_ms` | The timeout for phone silence in secs. | `int` | 400 |
| `vad_threshold` | The threshold for voice activity detection. | `float` | 0.5 |

### Use an OpenAI compatible model for the LLM

Expand Down
4 changes: 0 additions & 4 deletions app/helpers/call_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ async def on_call_disconnected(

@tracer.start_as_current_span("on_audio_connected")
async def on_audio_connected( # noqa: PLR0913
audio_bits_per_sample: int,
audio_channels: int,
audio_in: asyncio.Queue[bytes],
audio_out: asyncio.Queue[bytes | bool],
audio_sample_rate: int,
Expand All @@ -185,8 +183,6 @@ async def on_audio_connected( # noqa: PLR0913
Starts the real-time conversation with the LLM.
"""
await load_llm_chat(
audio_bits_per_sample=audio_bits_per_sample,
audio_channels=audio_channels,
audio_in=audio_in,
audio_out=audio_out,
audio_sample_rate=audio_sample_rate,
Expand Down
234 changes: 129 additions & 105 deletions app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,9 @@
from azure.cognitiveservices.speech.audio import PushAudioInputStream
from azure.communication.callautomation.aio import CallAutomationClient
from openai import APIError
from pydub import AudioSegment
from pydub.effects import (
high_pass_filter,
low_pass_filter,
)
from webrtcvad import Vad

from app.helpers.call_utils import (
EchoCancellationStream,
handle_media,
handle_realtime_tts,
tts_sentence_split,
Expand Down Expand Up @@ -55,9 +50,7 @@

# TODO: Refacto, this function is too long
@tracer.start_as_current_span("call_load_llm_chat")
async def load_llm_chat( # noqa: PLR0913
audio_bits_per_sample: int,
audio_channels: int,
async def load_llm_chat( # noqa: PLR0913, PLR0915
audio_in: asyncio.Queue[bytes],
audio_out: asyncio.Queue[bytes | bool],
audio_sample_rate: int,
Expand All @@ -70,32 +63,71 @@ async def load_llm_chat( # noqa: PLR0913
# Init language recognition
stt_buffer: list[str] = [] # Temporary buffer for recognition
stt_complete_gate = asyncio.Event() # Gate to wait for the recognition
aec = EchoCancellationStream(
sample_rate=audio_sample_rate,
)
audio_reference: asyncio.Queue[bytes] = asyncio.Queue()

def _stt_callback(text: str) -> None:
async def _send_in_to_aec() -> None:
"""
Store the recognition in the buffer.
Send input audio to the echo cancellation.
"""
# Skip if no text
if not text:
return
while True:
in_chunck = await audio_in.get()
audio_in.task_done()
await aec.push_input(in_chunck)

async def _send_out_to_aec() -> None:
"""
Forward the TTS to the echo cancellation and output.
"""
while True:
out_chunck = await audio_reference.get()
audio_reference.task_done()
await asyncio.gather(
# First, send the audio to the output
audio_out.put(out_chunck),
# Then, send the audio to the echo cancellation
aec.push_reference(out_chunck),
)

stt_buffer.append(text)
def _partial_stt_callback(text: str) -> None:
"""
Store the partial recognition in the buffer.
"""
# Init buffer if empty
if not stt_buffer:
stt_buffer.append("")
# Append the recognition
stt_buffer[-1] += text
logger.debug("Partial recognition: %s", stt_buffer)

def _complete_stt_callback(text: str) -> None:
"""
Store the recognition in the buffer.
"""
# Init buffer if empty
if not stt_buffer:
stt_buffer.append("")
# Store the recognition
stt_buffer[-1] = text
# Add a new buffer for the next partial recognition
stt_buffer.append("")
logger.debug("Complete recognition: %s", stt_buffer)

# Open the recognition gate
stt_complete_gate.set()

async with (
use_stt_client(
audio_bits_per_sample=audio_bits_per_sample,
audio_channels=audio_channels,
audio_sample_rate=audio_sample_rate,
call=call,
callback=_stt_callback,
complete_callback=_complete_stt_callback,
partial_callback=_partial_stt_callback,
) as stt_stream,
use_tts_client(
audio=audio_out,
call=call,
out=audio_reference,
) as tts_client,
):
# Build scheduler
Expand All @@ -120,32 +152,32 @@ async def _timeout_callback() -> None:
)
)

async def _clear_audio_callback() -> None:
async def _stop_callback() -> None:
"""
Triggered when the audio buffer needs to be cleared.
"""
# Close previous response now
if last_response:
await last_response.close(timeout=0)

# Stop TTS, clear the buffer and send a stop signal
tts_client.stop_speaking_async()
while not audio_out.empty():
audio_out.get_nowait()
audio_out.task_done()
await audio_out.put(False)

# Close the recognition gate
# Reset the recognition
stt_buffer.clear()
stt_complete_gate.clear()

# Close previous response if any
if last_response:
await scheduler.spawn(last_response.close(timeout=0))

# Clear the recognition buffer
stt_buffer.clear()
# Send a stop signal
await audio_out.put(False)

async def _commit_answer(tool_blacklist: set[str] | None = None) -> None:
"""
Process the response.
"""
# Store recognition task
# Stop any previous response
await _stop_callback()

# Start chat task
nonlocal last_response
last_response = await scheduler.spawn(
_continue_chat(
Expand All @@ -159,21 +191,31 @@ async def _commit_answer(tool_blacklist: set[str] | None = None) -> None:
)
)

# Wait for the response to be processed
# Wait for its response
await last_response.wait()

async def _response_callback() -> None:
async def _response_callback(_retry: bool = False) -> None:
"""
Triggered when the audio buffer needs to be processed.
If the recognition is empty, retry the recognition once. Otherwise, process the response.
"""
# Wait for the complete recognition
await stt_complete_gate.wait()
# Wait the complete recognition for 50ms maximum
try:
await asyncio.wait_for(stt_complete_gate.wait(), timeout=0.05)
except TimeoutError:
pass

stt_text = " ".join(stt_buffer).strip()

# Skip if no partial recognition
# Ignore empty recognition
if not stt_text:
return
# Skip if already retries
if _retry:
return
# Retry recognition, maybe the user was too fast or the recognition is temporarly slow
await asyncio.sleep(0.2)
return await _response_callback(_retry=True)

# Add it to the call history and update last interaction
logger.info("Voice stored: %s", stt_buffer)
Expand Down Expand Up @@ -211,16 +253,21 @@ async def _response_callback() -> None:
{"end_call"},
)

await _process_chat_audio(
bits_per_sample=audio_bits_per_sample,
call=call,
channels=audio_channels,
clear_audio_callback=_clear_audio_callback,
in_stream=audio_in,
out_stream=stt_stream,
response_callback=_response_callback,
sample_rate=audio_sample_rate,
timeout_callback=_timeout_callback,
await asyncio.gather(
# Start the echo cancellation
aec.process_stream(),
# Apply the echo cancellation
_send_in_to_aec(),
_send_out_to_aec(),
# Detect VAD
_process_audio_for_vad(
call=call,
stop_callback=_stop_callback,
echo_cancellation=aec,
out_stream=stt_stream,
response_callback=_response_callback,
timeout_callback=_timeout_callback,
),
)


Expand Down Expand Up @@ -598,38 +645,41 @@ async def _content_callback(buffer: str) -> None:


# TODO: Refacto and simplify
async def _process_chat_audio( # noqa: PLR0913
bits_per_sample: int,
async def _process_audio_for_vad( # noqa: PLR0913
call: CallStateModel,
channels: int,
clear_audio_callback: Callable[[], Awaitable[None]],
in_stream: asyncio.Queue[bytes],
echo_cancellation: EchoCancellationStream,
out_stream: PushAudioInputStream,
response_callback: Callable[[], Awaitable[None]],
sample_rate: int,
stop_callback: Callable[[], Awaitable[None]],
timeout_callback: Callable[[], Awaitable[None]],
) -> None:
clear_tts_task: asyncio.Task | None = None
"""
Process voice activity and silence detection.
Follows the following steps:
- Detect voice activity and clear the TTS to let the user speak
- Wait for silence and trigger the chat
- Wait for longer silence and trigger the timeout
"""
stop_task: asyncio.Task | None = None
silence_task: asyncio.Task | None = None
vad = Vad(
# Aggressiveness mode (0, 1, 2, or 3)
# Sets the VAD operating mode. A more aggressive (higher mode) VAD is more restrictive in reporting speech. Put in other words the probability of being speech when the VAD returns 1 is increased with increasing mode. As a consequence also the missed detection rate goes up.
mode=3,
)

async def _silence_callback() -> None:
async def _wait_for_silence() -> None:
"""
Flush the audio buffer if no audio is detected for a while and trigger the timeout if required.
Run the chat after a silence.
If the silence is too long, run the timeout.
"""
# Wait before flushing
nonlocal clear_tts_task
nonlocal stop_task
timeout_ms = await vad_silence_timeout_ms()
await asyncio.sleep(timeout_ms / 1000)

# Cancel the clear TTS task if any
if clear_tts_task:
clear_tts_task.cancel()
clear_tts_task = None
# Cancel the clear TTS task
if stop_task:
stop_task.cancel()
stop_task = None

# Flush the audio buffer
logger.debug("Flushing audio buffer after %i ms", timeout_ms)
Expand Down Expand Up @@ -661,11 +711,9 @@ async def _silence_callback() -> None:
logger.info("Silence triggered after %i sec", timeout_sec)
await timeout_callback()

async def _clear_tts_callback() -> None:
async def _wait_for_stop() -> None:
"""
Clear the TTS queue.
Start is the index of the buffer where the TTS was triggered.
Stop the TTS if user speaks for too long.
"""
timeout_ms = await vad_cutoff_timeout_ms()

Expand All @@ -675,55 +723,31 @@ async def _clear_tts_callback() -> None:
logger.debug("Canceling TTS after %i ms", timeout_ms)

# Clear the queue
await clear_audio_callback()
await stop_callback()

# Consumes audio stream
while True:
# Wait for the next audio packet
in_chunck = await in_stream.get()

# Load audio
in_audio: AudioSegment = AudioSegment(
channels=channels,
data=in_chunck,
frame_rate=sample_rate,
sample_width=bits_per_sample // 8,
)

# Apply high-pass and low-pass filters in a simple attempt to reduce noise
in_audio = high_pass_filter(seg=in_audio, cutoff=85)
in_audio = low_pass_filter(seg=in_audio, cutoff=3000)

# Always add the audio to the buffer
assert isinstance(in_audio.raw_data, bytes)
out_stream.write(in_audio.raw_data)
out_chunck, is_speech = await echo_cancellation.pull_audio()

# Confirm ASAP that the event is processed
in_stream.task_done()
# Add audio to the buffer
out_stream.write(out_chunck)

# Use WebRTC VAD algorithm to detect voice
in_empty = False
if not vad.is_speech(
buf=in_audio.raw_data,
sample_rate=in_audio.frame_rate,
):
in_empty = True
# If no speech, init the silence task
if not is_speech:
# Start timeout if not already started
if not silence_task:
silence_task = asyncio.create_task(_silence_callback())

if in_empty:
silence_task = asyncio.create_task(_wait_for_silence())
# Continue to the next audio packet
continue

# Voice detected, cancel the timeout if any
# Voice detected, cancel the timeout task
if silence_task:
silence_task.cancel()
silence_task = None

# Start the TTS clear task
if not clear_tts_task:
clear_tts_task = asyncio.create_task(_clear_tts_callback())
if not stop_task:
stop_task = asyncio.create_task(_wait_for_stop())


def _tts_callback(
Expand Down
Loading

0 comments on commit 416afd5

Please sign in to comment.