From ec7a0013b4db8cf711740735e332978a4b2b4cbc Mon Sep 17 00:00:00 2001
From: ticoAg <1627635056@qq.com>
Date: Sat, 16 Nov 2024 13:04:53 +0800
Subject: [PATCH] bugfix for empty cache with diff device && ChatTTS process

---
 .gitignore                | 3 ++-
 STT/paraformer_handler.py | 5 ++++-
 TTS/chatTTS_handler.py    | 12 ++++++++----
 3 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index 33b7875..3e653c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 __pycache__
 tmp
-cache
\ No newline at end of file
+cache
+asset
diff --git a/STT/paraformer_handler.py b/STT/paraformer_handler.py
index 99fd6ac..e576158 100644
--- a/STT/paraformer_handler.py
+++ b/STT/paraformer_handler.py
@@ -53,7 +53,10 @@ def process(self, spoken_prompt):
         pred_text = (
             self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
         )
-        torch.mps.empty_cache()
+        if self.device == "cuda":
+            torch.cuda.empty_cache()
+        elif self.device == "mps":
+            torch.mps.empty_cache()
 
         logger.debug("finished paraformer inference")
         console.print(f"[yellow]USER: {pred_text}")
diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py
index 12bcf66..879b540 100644
--- a/TTS/chatTTS_handler.py
+++ b/TTS/chatTTS_handler.py
@@ -1,10 +1,12 @@
-import ChatTTS
 import logging
-from baseHandler import BaseHandler
+
+import ChatTTS
 import librosa
 import numpy as np
-from rich.console import Console
 import torch
+from rich.console import Console
+
+from baseHandler import BaseHandler
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -40,6 +42,8 @@ def warmup(self):
         _ = self.model.infer("text")
 
     def process(self, llm_sentence):
+        if isinstance(llm_sentence, tuple):
+            llm_sentence, language_code = llm_sentence
         console.print(f"[green]ASSISTANT: {llm_sentence}")
         if self.device == "mps":
             import time
@@ -62,7 +66,7 @@ def process(self, llm_sentence):
                     self.should_listen.set()
                     return
                 audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
-                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
+                audio_chunk = (audio_chunk * 32768).astype(np.int16)
                 while len(audio_chunk) > self.chunk_size:
                     yield audio_chunk[: self.chunk_size]  # Return the first chunk_size samples of the audio data
                     audio_chunk = audio_chunk[self.chunk_size :]  # Remove the samples that have already been returned
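
Note on the device-aware cache fix: the paraformer hunk replaces an unconditional torch.mps.empty_cache() with a per-device branch, because the MPS call only works on Apple-silicon hosts while CUDA boxes need torch.cuda.empty_cache(). The sketch below restates that pattern as a stand-alone helper; the name clear_device_cache and the availability guards are illustrative assumptions, not part of this patch.

    import torch


    def clear_device_cache(device: str) -> None:
        # Release cached allocator memory on the backend actually in use.
        # Guard with availability checks so the call is a no-op on hosts
        # that lack the corresponding backend (assumption, not in the patch).
        if device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif device == "mps" and torch.backends.mps.is_available():
            torch.mps.empty_cache()
        # "cpu" (or any other device string) needs no explicit cache release.


    # Usage mirroring STT/paraformer_handler.py after the patch:
    # clear_device_cache(self.device)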