From 3bcfaf9f9f6dd078a58842b37664c0133ec8375d Mon Sep 17 00:00:00 2001 From: JarbasAi Date: Fri, 8 Dec 2023 18:02:51 +0000 Subject: [PATCH 01/15] refactor/tts_cache fix kwargs handling in synth method move methods around for readability and group them based on functionality add more deprecation warnings lang from session move cache to TTSContext --- .github/workflows/unit_tests.yml | 2 +- ovos_plugin_manager/templates/tts.py | 800 +++++++++++++-------------- test/unittests/test_tts.py | 86 ++- 3 files changed, 472 insertions(+), 416 deletions(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 89dc6e84..69a78d3a 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -33,7 +33,7 @@ jobs: strategy: max-parallel: 2 matrix: - python-version: [ 3.7, 3.8, 3.9, "3.10" ] + python-version: [ 3.8, 3.9, "3.10" ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 6f91085b..b7e2b078 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -1,30 +1,6 @@ -""" -this module is meant to enable usage of mycroft plugins inside and outside -mycroft, importing from here will make things work as planned in mycroft, -but if outside mycroft things will still work - -The main use case is for plugins to be used across different projects - -## Differences from upstream - -TTS: -- added automatic guessing of phonemes/visime calculation, enabling mouth -movements for all TTS engines (only mimic implements this in upstream) -- playback start call has been omitted and moved to init method -- init is called by mycroft, but non mycroft usage wont call it -- outside mycroft the enclosure is not set, bus is dummy and playback thread is not used - - playback queue is not wanted when some module is calling get_tts - - if playback was started on init then python scripts would never stop - from 
mycroft.tts import TTSFactory - engine = TTSFactory.create() - engine.get_tts("hello world", "hello_world." + engine.audio_ext) - # would hang here - engine.playback.stop() -""" import abc import asyncio import inspect -import random import re import subprocess from os.path import isfile, join @@ -37,13 +13,15 @@ import requests from ovos_bus_client.apis.enclosure import EnclosureAPI from ovos_bus_client.message import Message, dig_for_message +from ovos_bus_client.session import SessionManager from ovos_config import Configuration +from ovos_config.locations import get_xdg_cache_save_path from ovos_utils import classproperty from ovos_utils.fakebus import FakeBus from ovos_utils.file_utils import get_cache_directory from ovos_utils.file_utils import resolve_resource_file from ovos_utils.lang.visimes import VISIMES -from ovos_utils.log import LOG +from ovos_utils.log import LOG, deprecated from ovos_utils.metrics import Stopwatch from ovos_utils.process_utils import RuntimeRequirements @@ -56,83 +34,42 @@ SSML_TAGS = re.compile(r'<[^>]*>') -class PlaybackThread(Thread): - """ PlaybackThread moved to ovos_audio.playback - standalone plugin usage should rely on self.get_tts - ovos-audio relies on self.execute and needs this class - - this class was only in ovos-plugin-manager in order to - patch usage of our plugins in mycroft-core""" - - def __new__(self, *args, **kwargs): - LOG.warning("PlaybackThread moved to ovos_audio.playback") - try: - from ovos_audio.playback import PlaybackThread - return PlaybackThread(*args, **kwargs) - except ImportError: - raise ImportError("please install ovos-audio for playback handling") - - class TTSContext: - """ parses kwargs for valid signatures and extracts voice/lang optional parameters - it will look for a requested voice in kwargs and inside the source Message data if available. 
- voice can also be defined by a combination of language and gender, - in that case the helper method get_voice will be used to resolve the final voice_id - """ - - def __init__(self, engine): - self.engine = engine - - def get_message(self, kwargs): - msg = kwargs.get("message") or dig_for_message() - if msg and isinstance(msg, Message): - return msg - - def get_lang(self, kwargs): - # parse requested language for this TTS request - # NOTE: this is ovos only functionality, not in mycroft-core! - lang = kwargs.get("lang") - message = self.get_message(kwargs) - if not lang and message: - # get lang from message object if possible - lang = message.data.get("lang") or \ - message.context.get("lang") - return lang or self.engine.lang - - def get_gender(self, kwargs): - gender = kwargs.get("gender") - message = self.get_message(kwargs) - if not gender and message: - # get gender from message object if possible - gender = message.data.get("gender") or \ - message.context.get("gender") - return gender + _caches = {} - def get_voice(self, kwargs): - # parse requested voice for this TTS request - # NOTE: this is ovos only functionality, not in mycroft-core! 
- voice = kwargs.get("voice") - message = self.get_message(kwargs) - if not voice and message: - # get voice from message object if possible - voice = message.data.get("voice") or \ - message.context.get("voice") + def __init__(self, plugin_id: str, lang: str, voice: str): + self.plugin_id = plugin_id + self.lang = lang + self.voice = voice - if not voice: - gender = self.get_gender(kwargs) - if gender: - lang = self.get_lang(kwargs) - voice = self.engine.get_voice(gender, lang) - - return voice or self.engine.voice - - def get(self, kwargs=None): - kwargs = kwargs or {} - return self.get_lang(kwargs), self.get_voice(kwargs) + @property + def tts_id(self): + return join(self.plugin_id, self.voice, self.lang) + + def get_cache(self, audio_ext="wav", cache_config=None): + cache_config = cache_config or { + "min_free_percent": 75, + "persist_cache": False, + "persist_thresh": 1, + "preloaded_cache": f"{get_xdg_cache_save_path()}/{self.tts_id}" + } + if self.tts_id not in TTSContext._caches: + TTSContext._caches[self.tts_id] = TextToSpeechCache( + cache_config, self.tts_id, audio_ext + ) + return self._caches[self.tts_id] - def get_cache(self, kwargs=None): - lang, voice = self.get(kwargs) - return self.engine.get_cache(voice, lang) + def get_from_cache(self, sentence, audio_ext="wav", cache_config=None): + sentence_hash = hash_sentence(sentence) + phonemes = None + cache = self.get_cache(audio_ext, cache_config) + if sentence_hash not in cache: + raise FileNotFoundError(f"sentence is not cached, {sentence_hash}.{audio_ext}") + audio_file, pho_file = cache.cached_sentences[sentence_hash] + LOG.info(f"Found {audio_file.name} in TTS cache") + if pho_file: + phonemes = pho_file.load() + return audio_file, phonemes class TTS: @@ -151,7 +88,7 @@ class TTS: queue = None playback = None - def __init__(self, lang="en-us", config=None, validator=None, + def __init__(self, lang=None, config=None, validator=None, audio_ext='wav', phonetic_spelling=True, ssml_tags=None): 
self.log_timestamps = False @@ -159,199 +96,44 @@ def __init__(self, lang="en-us", config=None, validator=None, self.stopwatch = Stopwatch() self.tts_name = self.__class__.__name__ - self.bus = FakeBus() # initialized in "init" step - self.lang = lang or self.config.get("lang") or 'en-us' + self.validator = validator or TTSValidator(self) self.phonetic_spelling = phonetic_spelling self.audio_ext = audio_ext self.ssml_tags = ssml_tags or [] self.log_timestamps = self.config.get("log_timestamps", False) - self.enable_cache = self.config.get("enable_cache", True) - - self.voice = self.config.get("voice") or "default" - # TODO can self.filename be deprecated ? is it used anywhere at all? - cache_dir = get_cache_directory(self.tts_name) - self.filename = join(cache_dir, 'tts.' + self.audio_ext) - - random.seed() + self.enable_cache = self.config.get("enable_cache", False) if TTS.queue is None: TTS.queue = Queue() - self.context = TTSContext(self) - - # NOTE: self.playback.start() was moved to init method - # playback queue is not wanted if we only care about get_tts - # init is called by mycroft, but non mycroft usage wont call it, - # outside mycroft the enclosure is not set, bus is dummy and - # playback thread is not used self.spellings = self.load_spellings() - - self.caches = { - self.tts_id: TextToSpeechCache( - self.config, self.tts_id, self.audio_ext - )} - - cfg = Configuration() - g2pm = self.config.get("g2p_module") - if g2pm: - if g2pm in find_g2p_plugins(): - cfg.setdefault("g2p", {}) - globl = cfg["g2p"].get("module") or g2pm - if globl != g2pm: - LOG.info(f"TTS requested {g2pm} explicitly, ignoring global module {globl} ") - cfg["g2p"]["module"] = g2pm - else: - LOG.warning(f"TTS selected {g2pm}, but it is not available!") - - try: - self.g2p = OVOSG2PFactory.create(cfg) - except: - LOG.exception("G2P plugin not loaded, there will be no mouth movements") - self.g2p = None - - self.cache.curate() + self._init_g2p() self.add_metric({"metric_type": 
"tts.init"}) + # unused by plugins, assigned in init method by ovos-audio, + # only present for backwards compat reasons + self.bus = None + + # methods for individual plugins to override @classproperty def runtime_requirements(self): - """ skill developers should override this if they do not require connectivity - some examples: - IOT plugin that controls devices via LAN could return: - scans_on_init = True - RuntimeRequirements(internet_before_load=False, - network_before_load=scans_on_init, - requires_internet=False, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=False) - online search plugin with a local cache: - has_cache = False - RuntimeRequirements(internet_before_load=not has_cache, - network_before_load=not has_cache, - requires_internet=True, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=True) - a fully offline plugin: - RuntimeRequirements(internet_before_load=False, - network_before_load=False, - requires_internet=False, - requires_network=False, - no_internet_fallback=True, - no_network_fallback=True) - """ + """ WIP - currently unused, + placeholder to allow plugins to request internet/gui before load + refer to skills to see how it is used""" return RuntimeRequirements() @property - def tts_id(self): - lang, voice = self.context.get() - return join(self.tts_name, voice, lang) - - @property - def cache(self): - return self.caches.get(self.tts_id) or \ - self.get_cache() - - @cache.setter - def cache(self, val): - self.caches[self.tts_id] = val - - def get_cache(self, voice=None, lang=None): - lang = lang or self.lang - voice = voice or self.voice or "default" - tts_id = join(self.tts_name, voice, lang) - if tts_id not in self.caches: - self.caches[tts_id] = TextToSpeechCache( - self.config, tts_id, self.audio_ext - ) - return self.caches[tts_id] - - def handle_metric(self, metadata=None): - """ receive timing metrics for diagnostics - does nothing by default but plugins might use it, eg, 
NeonCore""" - - def add_metric(self, metadata=None): - """ wraps handle_metric to catch exceptions and log timestamps """ - try: - self.handle_metric(metadata) - if self.log_timestamps: - LOG.debug(f"time delta: {self.stopwatch.delta} metric: {metadata}") - except Exception as e: - LOG.exception(e) - - def load_spellings(self, config=None): - """Load phonetic spellings of words as dictionary.""" - path = join('text', self.lang.lower(), 'phonetic_spellings.txt') - try: - spellings_file = resolve_resource_file(path, config=config or Configuration()) - except: - LOG.debug('Failed to locate phonetic spellings resouce file.') - return {} - if not spellings_file: - return {} - try: - with open(spellings_file) as f: - lines = filter(bool, f.read().split('\n')) - lines = [i.split(':') for i in lines] - return {key.strip(): value.strip() for key, value in lines} - except ValueError: - LOG.exception('Failed to load phonetic spellings.') - return {} - - def begin_audio(self): - """Helper function for child classes to call in execute()""" - self.stopwatch.start() - self.add_metric({"metric_type": "tts.start"}) - - def end_audio(self, listen=False): - """Helper cleanup function for child classes to call in execute(). - - Arguments: - listen (bool): DEPRECATED: indication if listening trigger should be sent. - """ - self.add_metric({"metric_type": "tts.end"}) - self.stopwatch.stop() - - def init(self, bus=None, playback=None): - """ Performs intial setup of TTS object. - - Arguments: - bus: OpenVoiceOS messagebus connection + def available_languages(self) -> set: + """Return languages supported by this TTS implementation in this state + This property should be overridden by the derived class to advertise + what languages that engine supports. 
+ Returns: + set: supported languages """ - self.bus = bus or FakeBus() - if playback is None: - LOG.warning("PlaybackThread should be inited by ovos-audio, initing via plugin has been deprecated, " - "please pass playback=PlaybackThread() to TTS.init") - if TTS.playback: - playback.shutdown() - playback = PlaybackThread(TTS.queue, self.bus) # compat - playback.start() - self._init_playback(playback) - self.add_metric({"metric_type": "tts.setup"}) - - def _init_playback(self, playback): - TTS.playback = playback - TTS.playback.set_bus(self.bus) - TTS.playback.attach_tts(self) - if not TTS.playback.enclosure: - TTS.playback.enclosure = EnclosureAPI(self.bus) - - if not TTS.playback.is_alive(): - TTS.playback.start() - - @property - def enclosure(self): - if not TTS.playback.enclosure: - bus = TTS.playback.bus or self.bus - TTS.playback.enclosure = EnclosureAPI(bus) - return TTS.playback.enclosure - - @enclosure.setter - def enclosure(self, val): - TTS.playback.enclosure = val + return set() @abc.abstractmethod def get_tts(self, sentence, wav_file, lang=None): @@ -369,6 +151,23 @@ def get_tts(self, sentence, wav_file, lang=None): """ return "", None + def preprocess_sentence(self, sentence): + """Default preprocessing is a sentence_tokenizer, + ie. splits the utterance into sub-sentences using quebra_frases + + This method can be overridden to create chunks suitable to the + TTS engine in question. + + Arguments: + sentence (str): sentence to preprocess + + Returns: + list: list of sentence parts + """ + if self.config.get("sentence_tokenize"): # TODO default to True on next major release + return quebra_frases.sentence_tokenize(sentence) + return [sentence] + def modify_tag(self, tag): """Override to modify each supported ssml tag. 
@@ -377,6 +176,36 @@ def modify_tag(self, tag): """ return tag + def handle_metric(self, metadata=None): + """ receive timing metrics for diagnostics + does nothing by default but plugins might use it, eg, NeonCore""" + + # properties that reflect bus message session + @property + def voice(self): + message = dig_for_message() + if message: + # TODO - get from tts_prefs in session + pass + return self.config.get("voice") or "default" + + @voice.setter + def voice(self, val): + self.config["voice"] = val + + @property + def lang(self): + message = dig_for_message() + if message: + sess = SessionManager.get() + return sess.lang + return self.config.get("lang") or 'en-us' + + @lang.setter + def lang(self, val): + LOG.warning("self.lang can not be set! it comes from the bus message") + + # SSML helpers @staticmethod def remove_ssml(text): """Removes SSML tags from a string. @@ -463,22 +292,99 @@ def validate_ssml(self, utterance): # return text with supported ssml tags only return utterance.replace(" ", " ") - def _preprocess_sentence(self, sentence): - """Default preprocessing is a sentence_tokenizer, - ie. splits the utterance into sub-sentences using quebra_frases + # init helpers + def _init_g2p(self): + cfg = Configuration() + g2pm = self.config.get("g2p_module") + if g2pm: + if g2pm in find_g2p_plugins(): + cfg.setdefault("g2p", {}) + globl = cfg["g2p"].get("module") or g2pm + if globl != g2pm: + LOG.info(f"TTS requested {g2pm} explicitly, ignoring global module {globl} ") + cfg["g2p"]["module"] = g2pm + else: + LOG.warning(f"TTS selected {g2pm}, but it is not available!") - This method can be overridden to create chunks suitable to the - TTS engine in question. + try: + self.g2p = OVOSG2PFactory.create(cfg) + except: + LOG.debug("G2P plugin not loaded, there will be no mouth movements") + self.g2p = None + + def init(self, bus=None, playback=None): + """ Connects TTS object to PlaybackQueue in ovos-audio. 
+ + This method needs to be called in order for self.execute to do anything + + not needed if using get_tts / synth methods directly as intended in standalone usage Arguments: - sentence (str): sentence to preprocess + bus: OpenVoiceOS messagebus connection + """ + self.bus = bus or FakeBus() + if playback is None: + LOG.warning("PlaybackThread should be inited by ovos-audio, initing via plugin has been deprecated, " + "please pass playback=PlaybackThread() to TTS.init") + if TTS.playback: + playback.shutdown() + playback = PlaybackThread(TTS.queue, self.bus) # compat + playback.start() + self._init_playback(playback) + self.add_metric({"metric_type": "tts.setup"}) - Returns: - list: list of sentence parts + def _init_playback(self, playback): + TTS.playback = playback + TTS.playback.set_bus(self.bus) + TTS.playback.attach_tts(self) + if not TTS.playback.enclosure: + TTS.playback.enclosure = EnclosureAPI(self.bus) + + if not TTS.playback.is_alive(): + TTS.playback.start() + + def load_spellings(self, config=None): + """Load phonetic spellings of words as dictionary.""" + path = join('text', self.lang.lower(), 'phonetic_spellings.txt') + try: + spellings_file = resolve_resource_file(path, config=config or Configuration()) + except: + LOG.debug('Failed to locate phonetic spellings resource file.') + return {} + if not spellings_file: + return {} + try: + with open(spellings_file) as f: + lines = filter(bool, f.read().split('\n')) + lines = [i.split(':') for i in lines] + return {key.strip(): value.strip() for key, value in lines} + except ValueError: + LOG.exception('Failed to load phonetic spellings.') + return {} + + ## execution events + def add_metric(self, metadata=None): + """ wraps handle_metric to catch exceptions and log timestamps """ + try: + self.handle_metric(metadata) + if self.log_timestamps: + LOG.debug(f"time delta: {self.stopwatch.delta} metric: {metadata}") + except Exception as e: + LOG.exception(e) + + def begin_audio(self): + """Helper function 
for child classes to call in execute()""" + self.stopwatch.start() + self.add_metric({"metric_type": "tts.start"}) + + def end_audio(self, listen=False): + """Helper cleanup function for child classes to call in execute(). + + Arguments: + listen (bool): DEPRECATED: indication if listening trigger should be sent. """ - if self.config.get("sentence_tokenize"): # TODO default to True on next major release - return quebra_frases.sentence_tokenize(sentence) - return [sentence] + self.add_metric({"metric_type": "tts.end"}) + self.stopwatch.stop() def execute(self, sentence, ident=None, listen=False, **kwargs): """Convert sentence to speech, preprocessing out unsupported ssml @@ -498,6 +404,7 @@ def execute(self, sentence, ident=None, listen=False, **kwargs): self._execute(sentence, ident, listen, **kwargs) self.end_audio() + ## synth def _replace_phonetic_spellings(self, sentence): if self.phonetic_spelling: for word in re.findall(r"[\w']+", sentence): @@ -506,52 +413,77 @@ def _replace_phonetic_spellings(self, sentence): sentence = sentence.replace(word, spelled) return sentence + def _get_visemes(self, phonemes, sentence, ctxt): + # get visemes/mouth movements + viseme = [] + if phonemes: + viseme = self.viseme(phonemes) + elif self.g2p is not None: + try: + viseme = self.g2p.utterance2visemes(sentence, ctxt.lang) + except OutOfVocabulary: + pass + except: + # this one is unplanned, let devs know all the info so they can fix it + LOG.exception(f"Unexpected failure in G2P plugin: {self.g2p}") + + if not viseme: + # Debug level because this is expected in default installs + LOG.debug(f"no mouth movements available! 
unknown visemes for {sentence}") + return viseme + + def _get_ctxt(self, kwargs=None): + kwargs = kwargs or {} + # get request specific synth params + message = kwargs.get("message") or dig_for_message() + lang = kwargs.get("lang") + voice = kwargs.get("voice") + if message and not lang: + sess = SessionManager.get(message) + lang = lang or sess.lang + return TTSContext(plugin_id=self.tts_name, # TODO this should be the OPM name at some point + lang=lang or self.lang, + voice=voice or self.voice) + def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): if preprocess: + # pre-process sentence = self._replace_phonetic_spellings(sentence) - chunks = self._preprocess_sentence(sentence) + chunks = self.preprocess_sentence(sentence) # Apply the listen flag to the last chunk, set the rest to False chunks = [(chunks[i], listen if i == len(chunks) - 1 else False) for i in range(len(chunks))] + + # metrics timing callback self.add_metric({"metric_type": "tts.preprocessed", "n_chunks": len(chunks)}) else: chunks = [(sentence, listen)] - lang, voice = self.context.get(kwargs) - tts_id = join(self.tts_name, voice, lang) + # get request specific synth params + ctxt = self._get_ctxt(kwargs) + + message = kwargs.get("message") or \ + dig_for_message() or \ + Message("speak", context={"session": {"session_id": ident}}) # synth -> queue for playback for sentence, l in chunks: # load from cache or synth + cache - audio_file, phonemes = self.synth(sentence, **kwargs) + audio_file, phonemes = self.synth(sentence, ctxt, **kwargs) # get visemes/mouth movements - viseme = [] - if phonemes: - viseme = self.viseme(phonemes) - elif self.g2p is not None: - try: - viseme = self.g2p.utterance2visemes(sentence, lang) - except OutOfVocabulary: - pass - except: - # this one is unplanned, let devs know all the info so they can fix it - LOG.exception(f"Unexpected failure in G2P plugin: {self.g2p}") - - if not viseme: - # Debug level because this is expected in default installs - 
LOG.debug(f"no mouth movements available! unknown visemes for {sentence}") - - message = kwargs.get("message") or \ - dig_for_message() or \ - Message("speak", context={"session": {"session_id": ident}}) + viseme = self._get_visemes(phonemes, sentence, ctxt) + + # queue audio for playback TTS.queue.put( - (str(audio_file), viseme, l, tts_id, message) + (str(audio_file), viseme, l, ctxt.tts_id, message) ) + + # metrics timing callback self.add_metric({"metric_type": "tts.queued"}) - def synth(self, sentence, **kwargs): + def synth(self, sentence, ctxt: TTSContext = None, **kwargs): """ This method wraps get_tts several optional keyword arguments are supported sentence will be read/saved to cache""" @@ -559,24 +491,19 @@ def synth(self, sentence, **kwargs): sentence_hash = hash_sentence(sentence) # parse requested language for this TTS request - # NOTE: this is ovos/neon only functionality, not in mycroft-core! - lang, voice = self.context.get(kwargs) - kwargs["lang"] = lang - kwargs["voice"] = voice - - cache = self.get_cache(voice, lang) # cache per tts_id (lang/voice combo) + ctxt = ctxt or self._get_ctxt(kwargs) + cache = ctxt.get_cache(self.audio_ext, self.config) # load from cache if self.enable_cache and sentence_hash in cache: - audio, phonemes = self.get_from_cache(sentence, **kwargs) + audio, phonemes = ctxt.get_from_cache(sentence, cache) self.add_metric({"metric_type": "tts.synth.finished", "cache": True}) return audio, phonemes # synth + cache audio = cache.define_audio_file(sentence_hash) - # filter kwargs per plugin, different plugins expose different options - # mycroft-core -> no kwargs + # filter kwargs per plugin, different plugins expose different kwargs # ovos -> lang + voice optional kwargs # neon-core -> message kwargs = {k: v for k, v in kwargs.items() @@ -589,51 +516,10 @@ def synth(self, sentence, **kwargs): # cache sentence + phonemes if self.enable_cache: - self._cache_sentence(sentence, audio, phonemes, sentence_hash, - voice=voice, 
lang=lang) + self._cache_sentence(sentence, audio, cache, + phonemes, sentence_hash) return audio, phonemes - def _cache_phonemes(self, sentence, phonemes=None, sentence_hash=None): - sentence_hash = sentence_hash or hash_sentence(sentence) - if not phonemes and self.g2p is not None: - try: - phonemes = self.g2p.utterance2arpa(sentence, self.lang) - self.add_metric({"metric_type": "tts.phonemes.g2p"}) - except Exception as e: - self.add_metric({"metric_type": "tts.phonemes.g2p.error", "error": str(e)}) - if phonemes: - return self.save_phonemes(sentence_hash, phonemes) - return None - - def _cache_sentence(self, sentence, audio_file, phonemes=None, sentence_hash=None, - voice=None, lang=None): - sentence_hash = sentence_hash or hash_sentence(sentence) - # RANT: why do you hate strings ChrisV? - if isinstance(audio_file.path, str): - audio_file.path = Path(audio_file.path) - pho_file = self._cache_phonemes(sentence, phonemes, sentence_hash) - cache = self.get_cache(voice=voice, lang=lang) - cache.cached_sentences[sentence_hash] = (audio_file, pho_file) - self.add_metric({"metric_type": "tts.synth.cached"}) - - def get_from_cache(self, sentence, **kwargs): - sentence_hash = hash_sentence(sentence) - phonemes = None - cache = self.context.get_cache(kwargs) - audio_file, pho_file = cache.cached_sentences[sentence_hash] - LOG.info(f"Found {audio_file.name} in TTS cache") - if not pho_file: - # guess phonemes from sentence + cache them - pho_file = self._cache_phonemes(sentence, sentence_hash) - if pho_file: - phonemes = pho_file.load() - return audio_file, phonemes - - def get_voice(self, gender, lang=None): - """ map a language and gender to a valid voice for this TTS engine """ - lang = lang or self.lang - return gender - def viseme(self, phonemes): """Create visemes from phonemes. 
@@ -660,10 +546,117 @@ def viseme(self, phonemes): float(0.2))) return visimes or None + ## cache + def _cache_phonemes(self, sentence, cache: TextToSpeechCache = None, phonemes=None, sentence_hash=None): + sentence_hash = sentence_hash or hash_sentence(sentence) + if not phonemes and self.g2p is not None: + try: + phonemes = self.g2p.utterance2arpa(sentence, self.lang) + self.add_metric({"metric_type": "tts.phonemes.g2p"}) + except Exception as e: + self.add_metric({"metric_type": "tts.phonemes.g2p.error", "error": str(e)}) + if phonemes: + phoneme_file = cache.define_phoneme_file(sentence_hash) + phoneme_file.save(phonemes) + return phoneme_file + return None + + def _cache_sentence(self, sentence, audio_file, cache, phonemes=None, sentence_hash=None): + sentence_hash = sentence_hash or hash_sentence(sentence) + # RANT: why do you hate strings ChrisV? + if isinstance(audio_file.path, str): + audio_file.path = Path(audio_file.path) + pho_file = self._cache_phonemes(sentence, cache, phonemes, sentence_hash) + cache.cached_sentences[sentence_hash] = (audio_file, pho_file) + self.add_metric({"metric_type": "tts.synth.cached"}) + + ## shutdown + def stop(self): + if TTS.playback: + try: + TTS.playback.stop() + except Exception as e: + pass + self.add_metric({"metric_type": "tts.stop"}) + + def shutdown(self): + self.stop() + if TTS.playback: + TTS.playback.detach_tts(self) + + def __del__(self): + self.shutdown() + + # below code is all deprecated and marked for removal in next stable release + # TODO - update version number in warnings + @property + @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", + "0.1.0") + def enclosure(self): + if not TTS.playback.enclosure: + bus = TTS.playback.bus or self.bus + TTS.playback.enclosure = EnclosureAPI(bus) + return TTS.playback.enclosure + + @enclosure.setter + @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", + 
"0.1.0") + def enclosure(self, val): + TTS.playback.enclosure = val + + @property + @deprecated("self.filename has been deprecated, unused for a long time now", + "0.1.0") + def filename(self): + cache_dir = get_cache_directory(self.tts_name) + return join(cache_dir, 'tts.' + self.audio_ext) + + @filename.setter + @deprecated("self.filename has been deprecated, unused for a long time now", + "0.1.0") + def filename(self, val): + pass + + @property + @deprecated("self.tts_id has been deprecated, use TTSContext().tts_id", + "0.1.0") + def tts_id(self): + return self._get_ctxt().tts_id + + @property + @deprecated("self.cache has been deprecated, use TTSContext().get_cache", + "0.1.0") + def cache(self): + return TTSContext._caches.get(self.tts_id) or \ + self.get_cache() + + @cache.setter + @deprecated("self.cache has been deprecated, use TTSContext().get_cache", + "0.1.0") + def cache(self, val): + TTSContext._caches[self.tts_id] = val + + @deprecated("get_voice was never formally adopted and is unused, it will be removed", + "0.1.0") + def get_voice(self, gender, lang=None): + """ map a language and gender to a valid voice for this TTS engine """ + lang = lang or self.lang + return gender + + @deprecated("get_cache has been deprecated, use TTSContext().get_cache directly", + "0.1.0") + def get_cache(self, voice=None, lang=None): + return self._get_ctxt().get_cache(self.audio_ext, self.config) + + @deprecated("clear_cache has been deprecated, use TTSContext().get_cache directly", + "0.1.0") def clear_cache(self): """ Remove all cached files. 
""" - self.cache.clear() + cache = self._get_ctxt().get_cache(self.audio_ext, self.config) + cache.clear() + @deprecated("save_phonemes has been deprecated, use TTSContext().get_cache directly", + "0.1.0") def save_phonemes(self, key, phonemes): """Cache phonemes @@ -671,44 +664,27 @@ def save_phonemes(self, key, phonemes): key (str): Hash key for the sentence phonemes (str): phoneme string to save """ - phoneme_file = self.cache.define_phoneme_file(key) + cache = self._get_ctxt().get_cache(self.audio_ext, self.config) + phoneme_file = cache.define_phoneme_file(key) phoneme_file.save(phonemes) return phoneme_file + @deprecated("load_phonemes has been deprecated, use TTSContext().get_cache directly", + "0.1.0") def load_phonemes(self, key): """Load phonemes from cache file. Arguments: key (str): Key identifying phoneme cache """ - phoneme_file = self.cache.define_phoneme_file(key) + cache = self._get_ctxt().get_cache(self.audio_ext, self.config) + phoneme_file = cache.define_phoneme_file(key) return phoneme_file.load() - def stop(self): - if TTS.playback: - try: - TTS.playback.stop() - except Exception as e: - pass - self.add_metric({"metric_type": "tts.stop"}) - - def shutdown(self): - self.stop() - if TTS.playback: - TTS.playback.detach_tts(self) - - def __del__(self): - self.shutdown() - - @property - def available_languages(self) -> set: - """Return languages supported by this TTS implementation in this state - This property should be overridden by the derived class to advertise - what languages that engine supports. 
- Returns: - set: supported languages - """ - return set() + @deprecated("get_from_cache has been deprecated, use TTSContext().get_from_cache directly", + "0.1.0") + def get_from_cache(self, sentence): + return self._get_ctxt().get_from_cache(sentence, self.audio_ext, self.config) class TTSValidator: @@ -815,6 +791,7 @@ class RemoteTTSTimeoutException(RemoteTTSException): class StreamingTTSCallbacks: """handle the playback of streaming TTS, can be overrided in StreamingTTS""" + def __init__(self, bus, play_args=None, tts_config=None): self.bus = bus self.config = tts_config or {} @@ -831,12 +808,12 @@ def stream_start(self, message=None): message = message or \ dig_for_message() or \ Message("speak") - + # we don't use the regular PlaybackThread here, we need to handle recognizer_loop:audio_output_start if not self.config.get("pulse_duck", False): self.bus.emit(message.forward("ovos.common_play.duck")) self.bus.emit(message.forward("recognizer_loop:audio_output_start")) - + if self._process: self.stream_stop() LOG.debug(f"stream playback command: {self.play_args}") @@ -860,7 +837,7 @@ def stream_stop(self, listen=False, message=None): message = message or \ dig_for_message() or \ Message("speak") - + if self._process: self._process.stdin.close() self._process.wait() @@ -869,7 +846,7 @@ def stream_stop(self, listen=False, message=None): # we don't use the regular PlaybackThread here, we need to handle recognizer_loop:audio_output_end and listen flag if not self.config.get("pulse_duck", False): self.bus.emit(message.forward("ovos.common_play.unduck")) - self.bus.emit(message.forward("recognizer_loop:audio_output_end")) + self.bus.emit(message.forward("recognizer_loop:audio_output_end")) if listen: self.bus.emit(message.forward('mycroft.mic.listen')) @@ -878,14 +855,14 @@ class StreamingTTS(TTS): """ Abstract class for a Streaming TTS engine implementation. 
Audio is streamed in chunks as it becomes available instead of waiting the full sentence to be synthesized - + this plugin can be used in a synchronous way like any other plugin via self.get_tts(sentence, wav_file) - + to play audio as it becomes available use self.generate_audio(sentence, wav_file) NOTE: StreamingTTS does not support phonemes """ - + def init(self, bus=None, playback=None, callbacks=None): """ Performs intial setup of TTS object. @@ -902,7 +879,7 @@ def init(self, bus=None, playback=None, callbacks=None): async def stream_tts(self, sentence) -> AsyncIterable[bytes]: """yield chunks of TTS audio as they become available""" raise NotImplementedError - + async def generate_audio(self, sentence, wav_file, play_streaming=True, listen=False, message=None, plugin_kwargs=None): """save streamed TTS to wav file, if configured also play TTS as it becomes available""" @@ -927,7 +904,7 @@ def _execute(self, sentence, ident, listen, **kwargs): sentence_hash = hash_sentence(sentence) # parse requested language for this TTS request - lang, voice = self.context.get(kwargs) + lang, voice = self.context.get(kwargs) kwargs["lang"] = lang kwargs["voice"] = voice @@ -947,8 +924,8 @@ def _execute(self, sentence, ident, listen, **kwargs): # filter kwargs per plugin, different plugins expose different options plugin_kwargs = {k: v for k, v in kwargs.items() - if k in inspect.signature(self.stream_tts).parameters - and k not in ["sentence", "wav_file", "play_streaming"]} + if k in inspect.signature(self.stream_tts).parameters + and k not in ["sentence", "wav_file", "play_streaming"]} # handle streaming TTS loop = asyncio.new_event_loop() @@ -956,7 +933,7 @@ def _execute(self, sentence, ident, listen, **kwargs): try: self.add_metric({"metric_type": "tts.stream.start"}) loop.run_until_complete( - self.generate_audio(sentence, wav_file, + self.generate_audio(sentence, wav_file, play_streaming=True, listen=listen, message=message, @@ -965,7 +942,7 @@ def _execute(self, 
sentence, ident, listen, **kwargs): finally: loop.close() self.add_metric({"metric_type": "tts.stream.end"}) - + def get_tts(self, sentence, wav_file, **kwargs): """wrap streaming TTS into sync usage""" loop = asyncio.new_event_loop() @@ -981,6 +958,8 @@ def get_tts(self, sentence, wav_file, **kwargs): return wav_file, None # No phonemes +# below classes are deprecated and will be removed in 0.1.0 + class RemoteTTS(TTS): """ Abstract class for a Remote TTS engine implementation. @@ -988,6 +967,8 @@ class RemoteTTS(TTS): Usage is discouraged """ + @deprecated("RemoteTTS has been deprecated, please use the regular TTS class", + "0.1.0") def __init__(self, lang, config, url, api_path, validator): super(RemoteTTS, self).__init__(lang, config, validator) self.api_path = api_path @@ -1006,3 +987,20 @@ def get_tts(self, sentence, wav_file, lang=None): with open(wav_file, 'wb') as f: f.write(r.content) return wav_file, None + + +class PlaybackThread(Thread): + """ PlaybackThread moved to ovos_audio.playback + standalone plugin usage should rely on self.get_tts + ovos-audio relies on self.execute and needs this class + + this class was only in ovos-plugin-manager in order to + patch usage of our plugins in mycroft-core""" + + def __new__(self, *args, **kwargs): + LOG.warning("PlaybackThread moved to ovos_audio.playback") + try: + from ovos_audio.playback import PlaybackThread + return PlaybackThread(*args, **kwargs) + except ImportError: + raise ImportError("please install ovos-audio for playback handling") diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index 27fe6e3d..efa01bad 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -1,7 +1,9 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import MagicMock, patch +from unittest.mock import Mock + +from ovos_plugin_manager.templates.tts import TTS, TTSContext from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes -from 
ovos_plugin_manager.templates.tts import TTS class TestTTSTemplate(unittest.TestCase): @@ -114,23 +116,23 @@ def test_format_speak_tags_with_speech_no_tags(self): self.assertEqual(tagged_with_exclusion, valid_output) def test_playback_thread(self): - from ovos_plugin_manager.templates.tts import PlaybackThread + pass # TODO - + def test_tts_context(self): - from ovos_plugin_manager.templates.tts import TTSContext + pass # TODO - + def test_tts_validator(self): - from ovos_plugin_manager.templates.tts import TTSValidator + pass # TODO - + def test_concat_tts(self): - from ovos_plugin_manager.templates.tts import ConcatTTS + pass # TODO - + def test_remote_tt(self): - from ovos_plugin_manager.templates.tts import RemoteTTS + pass # TODO @@ -187,15 +189,15 @@ def test_get_tts_config(self, get_config): self.CONFIG_SECTION, None) def test_get_voice_id(self): - from ovos_plugin_manager.tts import get_voice_id + pass # TODO def test_scan_voices(self): - from ovos_plugin_manager.tts import scan_voices + pass # TODO def test_get_voices(self): - from ovos_plugin_manager.tts import get_voices + pass # TODO @@ -262,3 +264,59 @@ def test_create(self, get_class): get_class.assert_called_with(expected_config) plugin_class.assert_called_with(lang=None, config=expected_config) self.assertEqual(plugin, plugin_class()) + + +class TestTTSContext(unittest.TestCase): + def test_tts_context_init(self): + session_mock = MagicMock() + tts_context = TTSContext(session=session_mock) + self.assertEqual(tts_context.session, session_mock) + self.assertEqual(tts_context.lang, session_mock.lang) + + @patch("ovos_plugin_manager.templates.tts.TextToSpeechCache", autospec=True) + def test_tts_context_get_cache(self, cache_mock): + session_mock = MagicMock() + tts_context = TTSContext(session=session_mock) + + result = tts_context.get_cache() + + self.assertEqual(result, cache_mock.return_value) + self.assertEqual(result, tts_context._caches[tts_context.tts_id]) + + +class 
TestTTSCache(unittest.TestCase): + def setUp(self): + self.tts_mock = TTS(lang="en-us", config={"some_config_key": "some_config_value"}) + self.tts_mock.stopwatch = MagicMock() + self.tts_mock.queue = MagicMock() + self.tts_mock.playback = MagicMock() + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + @patch("ovos_plugin_manager.templates.tts.TTSContext", autospec=True) + def test_tts_synth(self, tts_context_mock, hash_sentence_mock): + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + + sentence = "Hello world!" + result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + def test_tts_synth_cache_enabled(self, hash_sentence_mock): + tts_context_mock = MagicMock() + tts_context_mock.tts_id = "fake_tts_id" + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.cached_sentences = {} + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} + + sentence = "Hello world!" 
+ result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + self.assertIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences) From e8b6066c1edbb070710d35631ce58d40d1e9657f Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 19:41:55 +0100 Subject: [PATCH 02/15] voice from Session --- ovos_plugin_manager/templates/tts.py | 91 +++++++++++++++++++--------- test/unittests/test_tts.py | 31 +++++++--- 2 files changed, 85 insertions(+), 37 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index b7e2b078..abd8b920 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -21,7 +21,7 @@ from ovos_utils.file_utils import get_cache_directory from ovos_utils.file_utils import resolve_resource_file from ovos_utils.lang.visimes import VISIMES -from ovos_utils.log import LOG, deprecated +from ovos_utils.log import LOG, deprecated, log_deprecation from ovos_utils.metrics import Stopwatch from ovos_utils.process_utils import RuntimeRequirements @@ -37,10 +37,11 @@ class TTSContext: _caches = {} - def __init__(self, plugin_id: str, lang: str, voice: str): + def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = None): self.plugin_id = plugin_id self.lang = lang self.voice = voice + self.synth_kwargs = synth_kwargs or {} @property def tts_id(self): @@ -90,6 +91,9 @@ class TTS: def __init__(self, lang=None, config=None, validator=None, audio_ext='wav', phonetic_spelling=True, ssml_tags=None): + if lang is not None: + log_deprecation("lang argument for TTS has been deprecated! 
it will be ignored, " + "pass lang to get_tts directly instead") self.log_timestamps = False self.config = config or get_plugin_config(config, "tts") @@ -117,6 +121,18 @@ def __init__(self, lang=None, config=None, validator=None, # only present for backwards compat reasons self.bus = None + self._plugin_id = "" # the plugin name + + @property + def plugin_id(self) -> str: + if not self._plugin_id: + from ovos_plugin_manager.tts import find_tts_plugins + for tts_id, clazz in find_tts_plugins().items(): + if isinstance(self, clazz): + self._plugin_id = tts_id + break + return self._plugin_id + # methods for individual plugins to override @classproperty def runtime_requirements(self): @@ -183,11 +199,18 @@ def handle_metric(self, metadata=None): # properties that reflect bus message session @property def voice(self): + voice = self.config.get("voice") or "default" message = dig_for_message() if message: - # TODO - get from tts_prefs in session - pass - return self.config.get("voice") or "default" + sess = SessionManager.get() + if sess.tts_preferences["plugin_id"] == self.plugin_id: + v = sess.tts_preferences["config"].get("voice") + if v: + voice = v + else: + # we got a request for a TTS plugin that isn't loaded! + LOG.error("ignoring TTS preferences in Session, plugin does not match!") + return voice @voice.setter def voice(self, val): @@ -432,18 +455,38 @@ def _get_visemes(self, phonemes, sentence, ctxt): LOG.debug(f"no mouth movements available! 
unknown visemes for {sentence}") return viseme - def _get_ctxt(self, kwargs=None): - kwargs = kwargs or {} + def _get_ctxt(self, kwargs=None) -> TTSContext: + """create a TTSContext from arbitrary kwargs passed to synth/execute methods + takes preferences from Session into account if a message is present + """ # get request specific synth params + kwargs = kwargs or {} message = kwargs.get("message") or dig_for_message() - lang = kwargs.get("lang") - voice = kwargs.get("voice") - if message and not lang: - sess = SessionManager.get(message) - lang = lang or sess.lang - return TTSContext(plugin_id=self.tts_name, # TODO this should be the OPM name at some point - lang=lang or self.lang, - voice=voice or self.voice) + + # update kwargs from session + if message: + sess = SessionManager.get() + if sess.tts_preferences["plugin_id"] == self.plugin_id: + for k, v in sess.tts_preferences["config"].items(): + if k not in kwargs: + kwargs[k] = v + else: + # we got a request for a TTS plugin that isn't loaded! 
+ LOG.error("ignoring TTS preferences in Session, plugin does not match!") + + if "lang" not in kwargs: + kwargs["lang"] = sess.lang + + # filter kwargs accepted by this specific plugin + kwargs = {k: v for k, v in kwargs.items() + if k in inspect.signature(self.get_tts).parameters + and k not in ["sentence", "wav_file"]} + + LOG.debug(f"TTS kwargs: {kwargs}") + return TTSContext(plugin_id=self.plugin_id, + lang=kwargs.get("lang") or self.lang, + voice=kwargs.get("voice") or self.voice, + synth_kwargs=kwargs) def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): if preprocess: @@ -470,7 +513,7 @@ def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): # synth -> queue for playback for sentence, l in chunks: # load from cache or synth + cache - audio_file, phonemes = self.synth(sentence, ctxt, **kwargs) + audio_file, phonemes = self.synth(sentence, ctxt) # get visemes/mouth movements viseme = self._get_visemes(phonemes, sentence, ctxt) @@ -490,7 +533,7 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs): self.add_metric({"metric_type": "tts.synth.start"}) sentence_hash = hash_sentence(sentence) - # parse requested language for this TTS request + # parse kwargs for this TTS request ctxt = ctxt or self._get_ctxt(kwargs) cache = ctxt.get_cache(self.audio_ext, self.config) @@ -502,16 +545,8 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs): # synth + cache audio = cache.define_audio_file(sentence_hash) - - # filter kwargs per plugin, different plugins expose different kwargs - # ovos -> lang + voice optional kwargs - # neon-core -> message - kwargs = {k: v for k, v in kwargs.items() - if k in inspect.signature(self.get_tts).parameters - and k not in ["sentence", "wav_file"]} - - # finally do the TTS synth - audio.path, phonemes = self.get_tts(sentence, str(audio), **kwargs) + audio.path, phonemes = self.get_tts(sentence, str(audio), + **ctxt.synth_kwargs) self.add_metric({"metric_type": "tts.synth.finished"}) # 
cache sentence + phonemes @@ -588,7 +623,6 @@ def __del__(self): self.shutdown() # below code is all deprecated and marked for removal in next stable release - # TODO - update version number in warnings @property @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", "0.1.0") @@ -1004,3 +1038,4 @@ def __new__(self, *args, **kwargs): return PlaybackThread(*args, **kwargs) except ImportError: raise ImportError("please install ovos-audio for playback handling") + diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index efa01bad..1a2fa263 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -267,16 +267,10 @@ def test_create(self, get_class): class TestTTSContext(unittest.TestCase): - def test_tts_context_init(self): - session_mock = MagicMock() - tts_context = TTSContext(session=session_mock) - self.assertEqual(tts_context.session, session_mock) - self.assertEqual(tts_context.lang, session_mock.lang) @patch("ovos_plugin_manager.templates.tts.TextToSpeechCache", autospec=True) def test_tts_context_get_cache(self, cache_mock): - session_mock = MagicMock() - tts_context = TTSContext(session=session_mock) + tts_context = TTSContext("plug", "voice", "lang") result = tts_context.get_cache() @@ -286,13 +280,13 @@ def test_tts_context_get_cache(self, cache_mock): class TestTTSCache(unittest.TestCase): def setUp(self): - self.tts_mock = TTS(lang="en-us", config={"some_config_key": "some_config_value"}) + self.tts_mock = TTS(config={"some_config_key": "some_config_value"}) self.tts_mock.stopwatch = MagicMock() self.tts_mock.queue = MagicMock() self.tts_mock.playback = MagicMock() @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") - @patch("ovos_plugin_manager.templates.tts.TTSContext", autospec=True) + @patch("ovos_plugin_manager.templates.tts.TTSContext") def test_tts_synth(self, tts_context_mock, hash_sentence_mock): 
tts_context_mock.get_cache.return_value = MagicMock() tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" @@ -314,9 +308,28 @@ def test_tts_synth_cache_enabled(self, hash_sentence_mock): tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} sentence = "Hello world!" + self.tts_mock.enable_cache = True result = self.tts_mock.synth(sentence, tts_context_mock) tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) self.assertIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences) + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + def test_tts_synth_cache_disabled(self, hash_sentence_mock): + tts_context_mock = MagicMock() + tts_context_mock.tts_id = "fake_tts_id" + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.cached_sentences = {} + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} + + sentence = "Hello world!" 
+ self.tts_mock.enable_cache = False + result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + self.assertNotIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences) From f33daa4f342efe3620a828f0ec0fd3a2686a8e59 Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 21:22:48 +0100 Subject: [PATCH 03/15] save a call to find_tts_plugins --- ovos_plugin_manager/tts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ovos_plugin_manager/tts.py b/ovos_plugin_manager/tts.py index ba1983a2..5bfa8289 100644 --- a/ovos_plugin_manager/tts.py +++ b/ovos_plugin_manager/tts.py @@ -205,6 +205,7 @@ def create(config=None): LOG.info(f'Found plugin {tts_module}') tts = clazz(lang=None, # explicitly read lang from config config=tts_config) + tts._plugin_id = tts_module tts.validator.validate() LOG.info(f'Loaded plugin {tts_module}') else: From 1ef4e9e7b9cabd0ae9a953cc7e04fa98fa9bc848 Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 21:24:25 +0100 Subject: [PATCH 04/15] save a call to find_tts_plugins --- ovos_plugin_manager/tts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ovos_plugin_manager/tts.py b/ovos_plugin_manager/tts.py index 5bfa8289..fe453423 100644 --- a/ovos_plugin_manager/tts.py +++ b/ovos_plugin_manager/tts.py @@ -203,8 +203,7 @@ def create(config=None): clazz = OVOSTTSFactory.get_class(tts_config) if clazz: LOG.info(f'Found plugin {tts_module}') - tts = clazz(lang=None, # explicitly read lang from config - config=tts_config) + tts = clazz(config=tts_config) tts._plugin_id = tts_module tts.validator.validate() LOG.info(f'Loaded plugin {tts_module}') From 0d28777fb22c5804a18d6510b413e4874ecca3bb Mon Sep 17 00:00:00 2001 From: 
miro Date: Sat, 20 Apr 2024 21:36:54 +0100 Subject: [PATCH 05/15] streaming TTS too --- ovos_plugin_manager/templates/tts.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index abd8b920..6ecbbae5 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -910,7 +910,7 @@ def init(self, bus=None, playback=None, callbacks=None): tts_config=self.config) @abc.abstractmethod - async def stream_tts(self, sentence) -> AsyncIterable[bytes]: + async def stream_tts(self, sentence, **kwargs) -> AsyncIterable[bytes]: """yield chunks of TTS audio as they become available""" raise NotImplementedError @@ -938,16 +938,13 @@ def _execute(self, sentence, ident, listen, **kwargs): sentence_hash = hash_sentence(sentence) # parse requested language for this TTS request - lang, voice = self.context.get(kwargs) - kwargs["lang"] = lang - kwargs["voice"] = voice - - # get path to cache final synthesized file - cache = self.get_cache(voice, lang) # cache per tts_id (lang/voice combo) + ctxt = self._get_ctxt(kwargs) + cache = ctxt.get_cache(self.audio_ext, self.config) # if cached, play existing file instead if self.enable_cache and sentence_hash in cache: - super()._execute(sentence, ident, listen, preprocess=False, **kwargs) + super()._execute(sentence, ident, listen, + preprocess=False, **ctxt.synth_kwargs) return wav_file = str(cache.define_audio_file(sentence_hash)) @@ -956,10 +953,10 @@ def _execute(self, sentence, ident, listen, **kwargs): dig_for_message() or \ Message("speak") - # filter kwargs per plugin, different plugins expose different options - plugin_kwargs = {k: v for k, v in kwargs.items() - if k in inspect.signature(self.stream_tts).parameters - and k not in ["sentence", "wav_file", "play_streaming"]} + # filter kwargs accepted by this specific plugin + ctxt.synth_kwargs = {k: v for k, v in kwargs.items() + if k in 
inspect.signature(self.stream_tts).parameters + and k not in ["sentence"]} # handle streaming TTS loop = asyncio.new_event_loop() @@ -971,7 +968,7 @@ def _execute(self, sentence, ident, listen, **kwargs): play_streaming=True, listen=listen, message=message, - plugin_kwargs=plugin_kwargs) + plugin_kwargs=ctxt.synth_kwargs) ) finally: loop.close() @@ -1038,4 +1035,3 @@ def __new__(self, *args, **kwargs): return PlaybackThread(*args, **kwargs) except ImportError: raise ImportError("please install ovos-audio for playback handling") - From 9b8173dca76f40ea9cb77256cdf934808c2ce99e Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 21:38:54 +0100 Subject: [PATCH 06/15] tests --- test/unittests/test_tts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index 1a2fa263..fe9ed18f 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -248,13 +248,13 @@ def test_create(self, get_class): "config": True, "lang": "en-ca"} get_class.assert_called_once_with(expected_config) - plugin_class.assert_called_once_with(lang=None, config=expected_config) + plugin_class.assert_called_once_with(config=expected_config) self.assertEqual(plugin, plugin_class()) # Test create with TTS config and no module config plugin = OVOSTTSFactory.create(tts_config) get_class.assert_called_with(tts_config) - plugin_class.assert_called_with(lang=None, config=tts_config) + plugin_class.assert_called_with(config=tts_config) self.assertEqual(plugin, plugin_class()) # Test create with TTS config with module-specific config @@ -262,7 +262,7 @@ def test_create(self, get_class): expected_config = {"module": "test-tts-plugin-test", "config": True, "lang": "es-mx"} get_class.assert_called_with(expected_config) - plugin_class.assert_called_with(lang=None, config=expected_config) + plugin_class.assert_called_with(config=expected_config) self.assertEqual(plugin, plugin_class()) From 
b121afc447edec83433ababae97c4512800bc06a Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 21:59:55 +0100 Subject: [PATCH 07/15] docstrs --- ovos_plugin_manager/templates/tts.py | 249 ++++++++++++++++++++++++--- 1 file changed, 221 insertions(+), 28 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 6ecbbae5..122e3687 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -7,7 +7,7 @@ from pathlib import Path from queue import Queue from threading import Thread -from typing import AsyncIterable +from typing import AsyncIterable, List import quebra_frases import requests @@ -35,9 +35,31 @@ class TTSContext: + """ + A context manager for handling Text-To-Speech (TTS) operations and caching. + + Attributes: + plugin_id (str): Identifier for the TTS plugin being used. + lang (str): Language code for the TTS operation. + voice (str): Identifier for the voice type in use. + synth_kwargs (dict): Optional dictionary containing additional keyword arguments for the TTS synthesizer. + + Class Attributes: + _caches (dict): A class-level dictionary acting as a cache store for different TTS contexts. + """ + _caches = {} def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = None): + """ + Initializes the TTSContext instance. + + Parameters: + plugin_id (str): The unique identifier for the TTS plugin. + lang (str): The language in which the text will be synthesized. + voice (str): The voice model to be used for text synthesis. + synth_kwargs (dict, optional): Additional keyword arguments for the synthesizer. + """ self.plugin_id = plugin_id self.lang = lang self.voice = voice @@ -45,9 +67,26 @@ def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = N @property def tts_id(self): + """ + Constructs a unique identifier for the TTS context based on plugin, voice, and language. 
+ + Returns: + str: A unique identifier that represents the TTS context. + """ return join(self.plugin_id, self.voice, self.lang) def get_cache(self, audio_ext="wav", cache_config=None): + """ + Retrieves or creates a cache instance for the current TTS context. + + Parameters: + audio_ext (str, optional): The file extension for the audio files (default is 'wav'). + cache_config (dict, optional): Configuration settings for the cache, including parameters like + minimum free percent, persistence settings, and cache directory path. + + Returns: + TextToSpeechCache: The cache instance associated with the current TTS context. + """ cache_config = cache_config or { "min_free_percent": 75, "persist_cache": False, @@ -61,6 +100,20 @@ def get_cache(self, audio_ext="wav", cache_config=None): return self._caches[self.tts_id] def get_from_cache(self, sentence, audio_ext="wav", cache_config=None): + """ + Retrieves an audio file and phoneme data from the cache, based on the input sentence. + + Parameters: + sentence (str): The sentence for which to retrieve audio data. + audio_ext (str, optional): The file extension of the audio file (default is 'wav'). + cache_config (dict, optional): Configuration settings for the cache. + + Returns: + tuple: A tuple containing the path to the cached audio file and optionally the phoneme data. + + Raises: + FileNotFoundError: If the sentence is not found in the cache. + """ sentence_hash = hash_sentence(sentence) phonemes = None cache = self.get_cache(audio_ext, cache_config) @@ -79,18 +132,34 @@ class TTS: It aggregates the minimum required parameters and exposes ``execute(sentence)`` and ``validate_ssml(sentence)`` functions. - Arguments: - lang (str): - config (dict): Configuration for this specific tts engine - validator (TTSValidator): Used to verify proper installation - phonetic_spelling (bool): Whether to spell certain words phonetically - ssml_tags (list): Supported ssml properties. Ex. 
['speak', 'prosody'] + Attributes: + queue (Queue): A queue for managing TTS playback tasks. + playback (PlaybackThread): The playback thread used for TTS audio output. + + Args: + lang (str): The language code for the TTS engine. + config (dict): Configuration settings for the specific TTS engine. + validator (TTSValidator): Validator used to verify proper installation. + audio_ext (str): The default audio file extension (default is 'wav'). + phonetic_spelling (bool): Whether to spell certain words phonetically. + ssml_tags (list): Supported SSML properties (e.g., ['speak', 'prosody']). """ queue = None playback = None def __init__(self, lang=None, config=None, validator=None, audio_ext='wav', phonetic_spelling=True, ssml_tags=None): + """ + Initializes the TTS engine with specified parameters. + + Args: + lang (str): The language code (deprecated). + config (dict): Configuration settings for the TTS engine. + validator (TTSValidator): Validator for verifying installation. + audio_ext (str): Default audio file extension (default is 'wav'). + phonetic_spelling (bool): Whether to use phonetic spelling (default is True). + ssml_tags (list): Supported SSML tags (default is None). + """ if lang is not None: log_deprecation("lang argument for TTS has been deprecated! it will be ignored, " "pass lang to get_tts directly instead") @@ -125,6 +194,12 @@ def __init__(self, lang=None, config=None, validator=None, @property def plugin_id(self) -> str: + """ + Retrieves the plugin ID for the TTS engine. + + Returns: + str: The plugin ID associated with the TTS engine. + """ if not self._plugin_id: from ovos_plugin_manager.tts import find_tts_plugins for tts_id, clazz in find_tts_plugins().items(): @@ -147,7 +222,7 @@ def available_languages(self) -> set: This property should be overridden by the derived class to advertise what languages that engine supports. Returns: - set: supported languages + set: A set of supported language codes. 
""" return set() @@ -155,19 +230,17 @@ def available_languages(self) -> set: def get_tts(self, sentence, wav_file, lang=None): """Abstract method that a tts implementation needs to implement. - Should get data from tts. - - Arguments: - sentence(str): Sentence to synthesize - wav_file(str): output file - lang(str): requested language (optional), defaults to self.lang + Args: + sentence (str): The input sentence to synthesize. + wav_file (str): The output file path for the synthesized audio. + lang (str, optional): The requested language (defaults to self.lang). Returns: tuple: (wav_file, phoneme) """ return "", None - def preprocess_sentence(self, sentence): + def preprocess_sentence(self, sentence: str) -> List[str]: """Default preprocessing is a sentence_tokenizer, ie. splits the utterance into sub-sentences using quebra_frases @@ -317,6 +390,9 @@ def validate_ssml(self, utterance): # init helpers def _init_g2p(self): + """ + Initializes the grapheme-to-phoneme (G2P) conversion for the TTS engine. + """ cfg = Configuration() g2pm = self.config.get("g2p_module") if g2pm: @@ -357,6 +433,13 @@ def init(self, bus=None, playback=None): self.add_metric({"metric_type": "tts.setup"}) def _init_playback(self, playback): + """ + Initializes the playback functionality for the TTS engine. + + Args: + playback: PlaybackThread instance. + """ + TTS.playback = playback TTS.playback.set_bus(self.bus) TTS.playback.attach_tts(self) @@ -367,7 +450,15 @@ def _init_playback(self, playback): TTS.playback.start() def load_spellings(self, config=None): - """Load phonetic spellings of words as dictionary.""" + """ + Loads phonetic spellings of words as a dictionary. + + Args: + config (dict, optional): Configuration settings. + + Returns: + dict: A dictionary of phonetic spellings. 
+ """ path = join('text', self.lang.lower(), 'phonetic_spellings.txt') try: spellings_file = resolve_resource_file(path, config=config or Configuration()) @@ -387,7 +478,12 @@ def load_spellings(self, config=None): ## execution events def add_metric(self, metadata=None): - """ wraps handle_metric to catch exceptions and log timestamps """ + """ + Wraps handle_metric to catch exceptions and log timestamps. + + Args: + metadata (dict, optional): Additional metadata for the metric. + """ try: self.handle_metric(metadata) if self.log_timestamps: @@ -527,9 +623,19 @@ def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): self.add_metric({"metric_type": "tts.queued"}) def synth(self, sentence, ctxt: TTSContext = None, **kwargs): - """ This method wraps get_tts - several optional keyword arguments are supported - sentence will be read/saved to cache""" + """ + Synthesizes speech for the given sentence. wraps get_tts + + sentence will be read/saved to cache + + Args: + sentence (str): The sentence to synthesize. + ctxt (TTSContext): The TTS context. + **kwargs: Additional synth arguments for get_tts. + + Returns: + tuple: A tuple containing the path to the synthesized audio file and phoneme data. + """ self.add_metric({"metric_type": "tts.synth.start"}) sentence_hash = hash_sentence(sentence) @@ -583,6 +689,15 @@ def viseme(self, phonemes): ## cache def _cache_phonemes(self, sentence, cache: TextToSpeechCache = None, phonemes=None, sentence_hash=None): + """ + Caches phonemes for the given sentence. + + Args: + sentence (str): The sentence to cache phonemes for. + cache (TextToSpeechCache): The cache instance. + phonemes (str, optional): The phonemes for the sentence. + sentence_hash (str, optional): The hash of the sentence. 
+ """ sentence_hash = sentence_hash or hash_sentence(sentence) if not phonemes and self.g2p is not None: try: @@ -597,6 +712,16 @@ def _cache_phonemes(self, sentence, cache: TextToSpeechCache = None, phonemes=No return None def _cache_sentence(self, sentence, audio_file, cache, phonemes=None, sentence_hash=None): + """ + Caches the sentence along with associated audio and phonemes. + + Args: + sentence (str): The sentence to cache. + audio_file (AudioFile): The audio file associated with the sentence. + cache (TextToSpeechCache): The cache instance. + phonemes (str, optional): The phonemes for the sentence. + sentence_hash (str, optional): The hash of the sentence. + """ sentence_hash = sentence_hash or hash_sentence(sentence) # RANT: why do you hate strings ChrisV? if isinstance(audio_file.path, str): @@ -607,6 +732,7 @@ def _cache_sentence(self, sentence, audio_file, cache, phonemes=None, sentence_h ## shutdown def stop(self): + """Stops the TTS playback.""" if TTS.playback: try: TTS.playback.stop() @@ -615,11 +741,13 @@ def stop(self): self.add_metric({"metric_type": "tts.stop"}) def shutdown(self): + """Shuts down the TTS engine.""" self.stop() if TTS.playback: TTS.playback.detach_tts(self) def __del__(self): + """Destructor for the TTS object.""" self.shutdown() # below code is all deprecated and marked for removal in next stable release @@ -627,6 +755,11 @@ def __del__(self): @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", "0.1.0") def enclosure(self): + """Deprecated. Accessor for the enclosure property. + + Returns: + EnclosureAPI: The EnclosureAPI instance associated with the TTS playback. 
+ """ if not TTS.playback.enclosure: bus = TTS.playback.bus or self.bus TTS.playback.enclosure = EnclosureAPI(bus) @@ -636,12 +769,22 @@ def enclosure(self): @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", "0.1.0") def enclosure(self, val): + """Deprecated. Setter for the enclosure property. + + Arguments: + val (EnclosureAPI): The EnclosureAPI instance to set. + """ TTS.playback.enclosure = val @property @deprecated("self.filename has been deprecated, unused for a long time now", "0.1.0") def filename(self): + """Deprecated. Accessor for the filename property. + + Returns: + str: The filename for the TTS audio. + """ cache_dir = get_cache_directory(self.tts_name) return join(cache_dir, 'tts.' + self.audio_ext) @@ -649,18 +792,32 @@ def filename(self): @deprecated("self.filename has been deprecated, unused for a long time now", "0.1.0") def filename(self, val): - pass + """Deprecated. Setter for the filename property. + + Arguments: + val (str): The filename to set. + """ @property @deprecated("self.tts_id has been deprecated, use TTSContext().tts_id", "0.1.0") def tts_id(self): + """Deprecated. Accessor for the tts_id property. + + Returns: + str: The ID associated with the TTS context. + """ return self._get_ctxt().tts_id @property @deprecated("self.cache has been deprecated, use TTSContext().get_cache", "0.1.0") def cache(self): + """Deprecated. Accessor for the cache property. + + Returns: + TextToSpeechCache: The cache associated with the TTS context. + """ return TTSContext._caches.get(self.tts_id) or \ self.get_cache() @@ -668,35 +825,60 @@ def cache(self): @deprecated("self.cache has been deprecated, use TTSContext().get_cache", "0.1.0") def cache(self, val): + """Deprecated. Setter for the cache property. + + Arguments: + val (TextToSpeechCache): The cache to set. 
+ """ TTSContext._caches[self.tts_id] = val @deprecated("get_voice was never formally adopted and is unused, it will be removed", "0.1.0") def get_voice(self, gender, lang=None): - """ map a language and gender to a valid voice for this TTS engine """ + """Deprecated. Get a valid voice for the TTS engine. + + Arguments: + gender (str): Gender of the voice. + lang (str, optional): Language for the voice. Defaults to None. + + Returns: + str: The selected voice. + """ lang = lang or self.lang return gender @deprecated("get_cache has been deprecated, use TTSContext().get_cache directly", "0.1.0") def get_cache(self, voice=None, lang=None): + """Deprecated. Get the cache associated with the TTS context. + + Arguments: + voice (str, optional): Voice for the cache. Defaults to None. + lang (str, optional): Language for the cache. Defaults to None. + + Returns: + TextToSpeechCache: The cache associated with the TTS context. + """ return self._get_ctxt().get_cache(self.audio_ext, self.config) @deprecated("clear_cache has been deprecated, use TTSContext().get_cache directly", "0.1.0") def clear_cache(self): - """ Remove all cached files. """ + """Deprecated. Clear all cached files.""" cache = self._get_ctxt().get_cache(self.audio_ext, self.config) cache.clear() @deprecated("save_phonemes has been deprecated, use TTSContext().get_cache directly", "0.1.0") def save_phonemes(self, key, phonemes): - """Cache phonemes + """Deprecated. Cache phonemes. Arguments: - key (str): Hash key for the sentence - phonemes (str): phoneme string to save + key (str): Hash key for the sentence. + phonemes (str): Phoneme string to save. + + Returns: + PhonemeFile: The PhonemeFile instance. 
""" cache = self._get_ctxt().get_cache(self.audio_ext, self.config) phoneme_file = cache.define_phoneme_file(key) @@ -706,10 +888,13 @@ def save_phonemes(self, key, phonemes): @deprecated("load_phonemes has been deprecated, use TTSContext().get_cache directly", "0.1.0") def load_phonemes(self, key): - """Load phonemes from cache file. + """Deprecated. Load phonemes from cache file. Arguments: - key (str): Key identifying phoneme cache + key (str): Key identifying phoneme cache. + + Returns: + str: Phonemes loaded from the cache file. """ cache = self._get_ctxt().get_cache(self.audio_ext, self.config) phoneme_file = cache.define_phoneme_file(key) @@ -718,6 +903,14 @@ def load_phonemes(self, key): @deprecated("get_from_cache has been deprecated, use TTSContext().get_from_cache directly", "0.1.0") def get_from_cache(self, sentence): + """Deprecated. Get data from the cache. + + Arguments: + sentence (str): Sentence used as cache key. + + Returns: + tuple: Tuple containing the audio and phonemes. + """ return self._get_ctxt().get_from_cache(sentence, self.audio_ext, self.config) From 79e608ab143e36f42ed32395ee504f40edef2583 Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 23:42:38 +0100 Subject: [PATCH 08/15] unittests for session --- ovos_plugin_manager/templates/tts.py | 9 +++--- test/unittests/test_tts.py | 46 ++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 122e3687..5108b97f 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -227,13 +227,14 @@ def available_languages(self) -> set: return set() @abc.abstractmethod - def get_tts(self, sentence, wav_file, lang=None): + def get_tts(self, sentence, wav_file, lang=None, voice=None): """Abstract method that a tts implementation needs to implement. Args: sentence (str): The input sentence to synthesize. 
wav_file (str): The output file path for the synthesized audio. lang (str, optional): The requested language (defaults to self.lang). + voice (str, optional): The requested voice (defaults to self.voice). Returns: tuple: (wav_file, phoneme) @@ -275,7 +276,7 @@ def voice(self): voice = self.config.get("voice") or "default" message = dig_for_message() if message: - sess = SessionManager.get() + sess = SessionManager.get(message) if sess.tts_preferences["plugin_id"] == self.plugin_id: v = sess.tts_preferences["config"].get("voice") if v: @@ -293,7 +294,7 @@ def voice(self, val): def lang(self): message = dig_for_message() if message: - sess = SessionManager.get() + sess = SessionManager.get(message) return sess.lang return self.config.get("lang") or 'en-us' @@ -561,7 +562,7 @@ def _get_ctxt(self, kwargs=None) -> TTSContext: # update kwargs from session if message: - sess = SessionManager.get() + sess = SessionManager.get(message) if sess.tts_preferences["plugin_id"] == self.plugin_id: for k, v in sess.tts_preferences["config"].items(): if k not in kwargs: diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index fe9ed18f..cc932ba1 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -1,6 +1,10 @@ import unittest -from unittest.mock import MagicMock, patch -from unittest.mock import Mock +from unittest.mock import MagicMock +from unittest.mock import patch, Mock + +from ovos_bus_client.session import Session +from ovos_config import Configuration +from ovos_utils.fakebus import FakeBus, Message from ovos_plugin_manager.templates.tts import TTS, TTSContext from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes @@ -333,3 +337,41 @@ def test_tts_synth_cache_disabled(self, hash_sentence_mock): tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) self.assertNotIn("fake_hash", 
tts_context_mock.get_cache.return_value.cached_sentences) + + +class TestSession(unittest.TestCase): + def test_tts_session(self): + sess = Session(session_id="123") + m = Message("speak", + context={"session": sess.serialize()}) + + tts = TTS() + tts.init(FakeBus(), Mock()) + self.assertEqual(tts.plugin_id, "ovos-tts-plugin-dummy") + self.assertEqual(tts.voice, "default") # no voice set + self.assertEqual(tts.lang, Configuration()["lang"]) # from config + + kwargs = {"message": m} + tts.execute("test sentence", **kwargs) + path, visemes, listen, tts_id, message = tts.queue.get() + self.assertEqual(message, m) + self.assertEqual(message.context["session"]["session_id"], sess.session_id) + + ctxt = tts._get_ctxt(kwargs) + self.assertEqual(ctxt.plugin_id, tts.plugin_id) + self.assertEqual(ctxt.lang, sess.lang) + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/{sess.lang}") + self.assertEqual(ctxt.synth_kwargs, {'lang': 'en-us'}) + + sess = Session(session_id="123", + lang="klingon", + tts_prefs={"plugin_id": "ovos-tts-plugin-dummy", + "config": {"voice": "A"}}) + m = Message("speak", + context={"session": sess.serialize()}) + kwargs = {"message": m} + ctxt = tts._get_ctxt(kwargs) + self.assertEqual(ctxt.lang, sess.lang) + self.assertEqual(ctxt.voice, sess.tts_preferences["config"]["voice"]) + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/{ctxt.voice}/{sess.lang}") + self.assertEqual(ctxt.synth_kwargs, {'lang': 'klingon', 'voice': 'A'}) From 7b1972560d29401b738efaa0a61e7fbd1e7a7c3f Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 23:51:05 +0100 Subject: [PATCH 09/15] unittests for session --- test/unittests/test_tts.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index cc932ba1..266e8b8d 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -341,28 +341,30 @@ def test_tts_synth_cache_disabled(self, hash_sentence_mock): 
class TestSession(unittest.TestCase): def test_tts_session(self): - sess = Session(session_id="123") + sess = Session(session_id="123", lang="en-us") m = Message("speak", context={"session": sess.serialize()}) tts = TTS() - tts.init(FakeBus(), Mock()) self.assertEqual(tts.plugin_id, "ovos-tts-plugin-dummy") self.assertEqual(tts.voice, "default") # no voice set - self.assertEqual(tts.lang, Configuration()["lang"]) # from config + self.assertEqual(tts.lang, "en-us") # from config + # test that session makes it all the way to the TTS.queue kwargs = {"message": m} tts.execute("test sentence", **kwargs) path, visemes, listen, tts_id, message = tts.queue.get() self.assertEqual(message, m) self.assertEqual(message.context["session"]["session_id"], sess.session_id) + # test that lang from Session is used ctxt = tts._get_ctxt(kwargs) self.assertEqual(ctxt.plugin_id, tts.plugin_id) self.assertEqual(ctxt.lang, sess.lang) - self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/{sess.lang}") + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/en-us") self.assertEqual(ctxt.synth_kwargs, {'lang': 'en-us'}) + # test that tts_prefs are used if plugin_id matches sess = Session(session_id="123", lang="klingon", tts_prefs={"plugin_id": "ovos-tts-plugin-dummy", @@ -373,5 +375,21 @@ def test_tts_session(self): ctxt = tts._get_ctxt(kwargs) self.assertEqual(ctxt.lang, sess.lang) self.assertEqual(ctxt.voice, sess.tts_preferences["config"]["voice"]) - self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/{ctxt.voice}/{sess.lang}") + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/A/klingon") self.assertEqual(ctxt.synth_kwargs, {'lang': 'klingon', 'voice': 'A'}) + + # test that tts_prefs are ignored if plugin_id doesnt match + sess = Session(session_id="123", + lang="klingon", + tts_prefs={"plugin_id": "ovos-tts-plugin-INVALID", + "config": {"voice": "A"}}) + m = Message("speak", + context={"session": sess.serialize()}) + kwargs = {"message": m} + ctxt = tts._get_ctxt(kwargs) + 
self.assertEqual(ctxt.lang, sess.lang) + self.assertEqual(ctxt.voice, "default") + self.assertNotEqual(ctxt.tts_id, f"ovos-tts-plugin-INVALID/A/klingon") + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/klingon") + self.assertEqual(ctxt.synth_kwargs, {'lang': 'klingon'}) + From fef7f72b9ded381eb719fe6ef0ba1ff56d2eeb5a Mon Sep 17 00:00:00 2001 From: miro Date: Tue, 23 Apr 2024 19:50:52 +0100 Subject: [PATCH 10/15] drop tts_prefs from session --- ovos_plugin_manager/templates/tts.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 5108b97f..fb8946d8 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -273,18 +273,7 @@ def handle_metric(self, metadata=None): # properties that reflect bus message session @property def voice(self): - voice = self.config.get("voice") or "default" - message = dig_for_message() - if message: - sess = SessionManager.get(message) - if sess.tts_preferences["plugin_id"] == self.plugin_id: - v = sess.tts_preferences["config"].get("voice") - if v: - voice = v - else: - # we got a request for a TTS plugin that isn't loaded! 
- LOG.error("ignoring TTS preferences in Session, plugin does not match!") - return voice + return self.config.get("voice") or "default" @voice.setter def voice(self, val): @@ -554,25 +543,16 @@ def _get_visemes(self, phonemes, sentence, ctxt): def _get_ctxt(self, kwargs=None) -> TTSContext: """create a TTSContext from arbitrary kwargs passed to synth/execute methods - takes preferences from Session into account if a message is present + takes lang from Session into account if a message is present """ # get request specific synth params kwargs = kwargs or {} message = kwargs.get("message") or dig_for_message() # update kwargs from session - if message: + if message and "lang" not in kwargs: sess = SessionManager.get(message) - if sess.tts_preferences["plugin_id"] == self.plugin_id: - for k, v in sess.tts_preferences["config"].items(): - if k not in kwargs: - kwargs[k] = v - else: - # we got a request for a TTS plugin that isn't loaded! - LOG.error("ignoring TTS preferences in Session, plugin does not match!") - - if "lang" not in kwargs: - kwargs["lang"] = sess.lang + kwargs["lang"] = sess.lang # filter kwargs accepted by this specific plugin kwargs = {k: v for k, v in kwargs.items() From aa12e95641df0cd960fd0ae1e31f7bb79254c856 Mon Sep 17 00:00:00 2001 From: miro Date: Tue, 23 Apr 2024 19:54:36 +0100 Subject: [PATCH 11/15] drop tts_prefs from session --- test/unittests/test_tts.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index 266e8b8d..88b7a29f 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -364,32 +364,12 @@ def test_tts_session(self): self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/en-us") self.assertEqual(ctxt.synth_kwargs, {'lang': 'en-us'}) - # test that tts_prefs are used if plugin_id matches sess = Session(session_id="123", - lang="klingon", - tts_prefs={"plugin_id": "ovos-tts-plugin-dummy", - "config": {"voice": 
"A"}}) + lang="klingon") m = Message("speak", context={"session": sess.serialize()}) kwargs = {"message": m} ctxt = tts._get_ctxt(kwargs) self.assertEqual(ctxt.lang, sess.lang) - self.assertEqual(ctxt.voice, sess.tts_preferences["config"]["voice"]) - self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/A/klingon") - self.assertEqual(ctxt.synth_kwargs, {'lang': 'klingon', 'voice': 'A'}) - - # test that tts_prefs are ignored if plugin_id doesnt match - sess = Session(session_id="123", - lang="klingon", - tts_prefs={"plugin_id": "ovos-tts-plugin-INVALID", - "config": {"voice": "A"}}) - m = Message("speak", - context={"session": sess.serialize()}) - kwargs = {"message": m} - ctxt = tts._get_ctxt(kwargs) - self.assertEqual(ctxt.lang, sess.lang) - self.assertEqual(ctxt.voice, "default") - self.assertNotEqual(ctxt.tts_id, f"ovos-tts-plugin-INVALID/A/klingon") self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/klingon") self.assertEqual(ctxt.synth_kwargs, {'lang': 'klingon'}) - From f093364d562d1c78c64436bb45948682332b7181 Mon Sep 17 00:00:00 2001 From: miro Date: Thu, 25 Apr 2024 20:36:27 +0100 Subject: [PATCH 12/15] support phonetic spellings again per lang used to refer to a hardcoded English file in mycroft-core specific to mimic1 now generalized to be per TTS plugin and live in a "locale" folder like everything else --- ovos_plugin_manager/templates/tts.py | 108 ++++++++++++++------------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index fb8946d8..8ab5ccce 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -1,13 +1,14 @@ import abc import asyncio import inspect +import os.path import re import subprocess from os.path import isfile, join from pathlib import Path from queue import Queue from threading import Thread -from typing import AsyncIterable, List +from typing import AsyncIterable, List, Dict import quebra_frases import 
requests @@ -48,7 +49,7 @@ class TTSContext: _caches (dict): A class-level dictionary acting as a cache store for different TTS contexts. """ - _caches = {} + _caches: Dict[str, TextToSpeechCache] = {} def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = None): """ @@ -181,7 +182,7 @@ def __init__(self, lang=None, config=None, validator=None, if TTS.queue is None: TTS.queue = Queue() - self.spellings = self.load_spellings() + self.spellings: Dict[str, dict] = self.load_spellings() self._init_g2p() self.add_metric({"metric_type": "tts.init"}) @@ -270,7 +271,6 @@ def handle_metric(self, metadata=None): """ receive timing metrics for diagnostics does nothing by default but plugins might use it, eg, NeonCore""" - # properties that reflect bus message session @property def voice(self): return self.config.get("voice") or "default" @@ -279,18 +279,6 @@ def voice(self): def voice(self, val): self.config["voice"] = val - @property - def lang(self): - message = dig_for_message() - if message: - sess = SessionManager.get(message) - return sess.lang - return self.config.get("lang") or 'en-us' - - @lang.setter - def lang(self, val): - LOG.warning("self.lang can not be set! it comes from the bus message") - # SSML helpers @staticmethod def remove_ssml(text): @@ -439,7 +427,7 @@ def _init_playback(self, playback): if not TTS.playback.is_alive(): TTS.playback.start() - def load_spellings(self, config=None): + def load_spellings(self, config=None) -> Dict[str, dict]: """ Loads phonetic spellings of words as a dictionary. @@ -449,22 +437,22 @@ def load_spellings(self, config=None): Returns: dict: A dictionary of phonetic spellings. 
""" - path = join('text', self.lang.lower(), 'phonetic_spellings.txt') - try: - spellings_file = resolve_resource_file(path, config=config or Configuration()) - except: - LOG.debug('Failed to locate phonetic spellings resource file.') - return {} - if not spellings_file: - return {} - try: - with open(spellings_file) as f: - lines = filter(bool, f.read().split('\n')) - lines = [i.split(':') for i in lines] - return {key.strip(): value.strip() for key, value in lines} - except ValueError: - LOG.exception('Failed to load phonetic spellings.') - return {} + if config: + LOG.warning("config argument is deprecated and unused!") + spellings_data = {} + locale = f"{os.path.dirname(__file__)}/locale" + for lang in os.listdir(locale): + spellings_file = f"{locale}/{lang}/phonetic_spellings.txt" + if not os.path.isfile(spellings_file): + continue + try: + with open(spellings_file) as f: + lines = filter(bool, f.read().split('\n')) + lines = [i.split(':') for i in lines] + spellings_data[lang] = {key.strip(): value.strip() for key, value in lines} + except ValueError: + LOG.exception(f'Failed to load {lang} phonetic spellings.') + return spellings_data ## execution events def add_metric(self, metadata=None): @@ -514,11 +502,11 @@ def execute(self, sentence, ident=None, listen=False, **kwargs): self.end_audio() ## synth - def _replace_phonetic_spellings(self, sentence): - if self.phonetic_spelling: + def _replace_phonetic_spellings(self, sentence:str, lang: str) -> str: + if self.phonetic_spelling and lang in self.spellings: for word in re.findall(r"[\w']+", sentence): - if word.lower() in self.spellings: - spelled = self.spellings[word.lower()] + if word.lower() in self.spellings[lang]: + spelled = self.spellings[lang][word.lower()] sentence = sentence.replace(word, spelled) return sentence @@ -561,14 +549,17 @@ def _get_ctxt(self, kwargs=None) -> TTSContext: LOG.debug(f"TTS kwargs: {kwargs}") return TTSContext(plugin_id=self.plugin_id, - lang=kwargs.get("lang") or self.lang, 
+ lang=kwargs.get("lang") or Configuration().get("lang", "en-us"), voice=kwargs.get("voice") or self.voice, synth_kwargs=kwargs) def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): + # get request specific synth params + ctxt = self._get_ctxt(kwargs) + if preprocess: # pre-process - sentence = self._replace_phonetic_spellings(sentence) + sentence = self._replace_phonetic_spellings(sentence, ctxt.lang) chunks = self.preprocess_sentence(sentence) # Apply the listen flag to the last chunk, set the rest to False chunks = [(chunks[i], listen if i == len(chunks) - 1 else False) @@ -580,9 +571,6 @@ def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): else: chunks = [(sentence, listen)] - # get request specific synth params - ctxt = self._get_ctxt(kwargs) - message = kwargs.get("message") or \ dig_for_message() or \ Message("speak", context={"session": {"session_id": ident}}) @@ -638,7 +626,7 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs): # cache sentence + phonemes if self.enable_cache: - self._cache_sentence(sentence, audio, cache, + self._cache_sentence(sentence, ctxt.lang, audio, cache, phonemes, sentence_hash) return audio, phonemes @@ -669,7 +657,7 @@ def viseme(self, phonemes): return visimes or None ## cache - def _cache_phonemes(self, sentence, cache: TextToSpeechCache = None, phonemes=None, sentence_hash=None): + def _cache_phonemes(self, sentence, lang: str, cache: TextToSpeechCache = None, phonemes=None, sentence_hash=None): """ Caches phonemes for the given sentence. 
@@ -682,7 +670,7 @@ def _cache_phonemes(self, sentence, cache: TextToSpeechCache = None, phonemes=No sentence_hash = sentence_hash or hash_sentence(sentence) if not phonemes and self.g2p is not None: try: - phonemes = self.g2p.utterance2arpa(sentence, self.lang) + phonemes = self.g2p.utterance2arpa(sentence, lang) self.add_metric({"metric_type": "tts.phonemes.g2p"}) except Exception as e: self.add_metric({"metric_type": "tts.phonemes.g2p.error", "error": str(e)}) @@ -692,7 +680,7 @@ def _cache_phonemes(self, sentence, cache: TextToSpeechCache = None, phonemes=No return phoneme_file return None - def _cache_sentence(self, sentence, audio_file, cache, phonemes=None, sentence_hash=None): + def _cache_sentence(self, sentence, lang: str, audio_file, cache, phonemes=None, sentence_hash=None): """ Caches the sentence along with associated audio and phonemes. @@ -707,7 +695,7 @@ def _cache_sentence(self, sentence, audio_file, cache, phonemes=None, sentence_h # RANT: why do you hate strings ChrisV? if isinstance(audio_file.path, str): audio_file.path = Path(audio_file.path) - pho_file = self._cache_phonemes(sentence, cache, phonemes, sentence_hash) + pho_file = self._cache_phonemes(sentence, lang, cache, phonemes, sentence_hash) cache.cached_sentences[sentence_hash] = (audio_file, pho_file) self.add_metric({"metric_type": "tts.synth.cached"}) @@ -825,7 +813,6 @@ def get_voice(self, gender, lang=None): Returns: str: The selected voice. 
""" - lang = lang or self.lang return gender @deprecated("get_cache has been deprecated, use TTSContext().get_cache directly", @@ -894,6 +881,22 @@ def get_from_cache(self, sentence): """ return self._get_ctxt().get_from_cache(sentence, self.audio_ext, self.config) + @property + @deprecated("language is defined per request in get_tts, self.lang is not used", + "0.1.0") + def lang(self): + message = dig_for_message() + if message: + sess = SessionManager.get(message) + return sess.lang + return self.config.get("lang") or 'en-us' + + @lang.setter + @deprecated("language is defined per request in get_tts, self.lang is not used", + "0.1.0") + def lang(self, val): + LOG.warning("self.lang can not be set! it comes from the bus message") + class TTSValidator: """TTS Validator abstract class to be implemented by all TTS engines. @@ -1106,15 +1109,16 @@ async def generate_audio(self, sentence, wav_file, play_streaming=True, return wav_file def _execute(self, sentence, ident, listen, **kwargs): - sentence = self._replace_phonetic_spellings(sentence) - self.add_metric({"metric_type": "tts.preprocessed"}) - - sentence_hash = hash_sentence(sentence) # parse requested language for this TTS request ctxt = self._get_ctxt(kwargs) cache = ctxt.get_cache(self.audio_ext, self.config) + sentence = self._replace_phonetic_spellings(sentence, ctxt.lang) + self.add_metric({"metric_type": "tts.preprocessed"}) + + sentence_hash = hash_sentence(sentence) + # if cached, play existing file instead if self.enable_cache and sentence_hash in cache: super()._execute(sentence, ident, listen, From 4c9081938f2d7641a0e430c76dcc9935f8c948a0 Mon Sep 17 00:00:00 2001 From: miro Date: Thu, 25 Apr 2024 21:11:15 +0100 Subject: [PATCH 13/15] fix root_dir --- ovos_plugin_manager/templates/tts.py | 31 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 8ab5ccce..04ef3e4a 100644 --- 
a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -3,6 +3,7 @@ import inspect import os.path import re +import sys import subprocess from os.path import isfile, join from pathlib import Path @@ -165,7 +166,7 @@ def __init__(self, lang=None, config=None, validator=None, log_deprecation("lang argument for TTS has been deprecated! it will be ignored, " "pass lang to get_tts directly instead") self.log_timestamps = False - + self.root_dir = os.path.dirname(os.path.abspath(sys.modules[self.__module__].__file__)) self.config = config or get_plugin_config(config, "tts") self.stopwatch = Stopwatch() @@ -177,7 +178,7 @@ def __init__(self, lang=None, config=None, validator=None, self.ssml_tags = ssml_tags or [] self.log_timestamps = self.config.get("log_timestamps", False) - self.enable_cache = self.config.get("enable_cache", False) + self.enable_cache = self.config.get("enable_cache", True) if TTS.queue is None: TTS.queue = Queue() @@ -420,7 +421,6 @@ def _init_playback(self, playback): TTS.playback = playback TTS.playback.set_bus(self.bus) - TTS.playback.attach_tts(self) if not TTS.playback.enclosure: TTS.playback.enclosure = EnclosureAPI(self.bus) @@ -440,18 +440,19 @@ def load_spellings(self, config=None) -> Dict[str, dict]: if config: LOG.warning("config argument is deprecated and unused!") spellings_data = {} - locale = f"{os.path.dirname(__file__)}/locale" - for lang in os.listdir(locale): - spellings_file = f"{locale}/{lang}/phonetic_spellings.txt" - if not os.path.isfile(spellings_file): - continue - try: - with open(spellings_file) as f: - lines = filter(bool, f.read().split('\n')) - lines = [i.split(':') for i in lines] - spellings_data[lang] = {key.strip(): value.strip() for key, value in lines} - except ValueError: - LOG.exception(f'Failed to load {lang} phonetic spellings.') + locale = f"{self.root_dir}/locale" + if os.path.isdir(locale): + for lang in os.listdir(locale): + spellings_file = 
f"{locale}/{lang}/phonetic_spellings.txt" + if not os.path.isfile(spellings_file): + continue + try: + with open(spellings_file) as f: + lines = filter(bool, f.read().split('\n')) + lines = [i.split(':') for i in lines] + spellings_data[lang] = {key.strip(): value.strip() for key, value in lines} + except ValueError: + LOG.exception(f'Failed to load {lang} phonetic spellings.') return spellings_data ## execution events From 6ead60791a9412e74ac7711498f20e9fff9ba72b Mon Sep 17 00:00:00 2001 From: miro Date: Thu, 25 Apr 2024 21:17:06 +0100 Subject: [PATCH 14/15] avoid reinits of playback thread --- ovos_plugin_manager/templates/tts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 04ef3e4a..b0425105 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -404,10 +404,9 @@ def init(self, bus=None, playback=None): if playback is None: LOG.warning("PlaybackThread should be inited by ovos-audio, initing via plugin has been deprecated, " "please pass playback=PlaybackThread() to TTS.init") - if TTS.playback: - playback.shutdown() - playback = PlaybackThread(TTS.queue, self.bus) # compat - playback.start() + if not TTS.playback: + playback = PlaybackThread(TTS.queue, self.bus) # compat + playback.start() self._init_playback(playback) self.add_metric({"metric_type": "tts.setup"}) From 3f77ce3d4465129af35e8eceaf8c02dd41e18d90 Mon Sep 17 00:00:00 2001 From: miro Date: Thu, 25 Apr 2024 21:26:29 +0100 Subject: [PATCH 15/15] add curate_caches helper for ovos-audio usage --- ovos_plugin_manager/templates/tts.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index b0425105..fe24bf86 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -127,6 +127,11 @@ def get_from_cache(self, sentence, audio_ext="wav", 
cache_config=None): phonemes = pho_file.load() return audio_file, phonemes + @classmethod + def curate_caches(cls): + for cache in TTSContext._caches.values(): + cache.curate() + class TTS: """TTS abstract class to be implemented by all TTS engines.