diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 89dc6e84..69a78d3a 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -33,7 +33,7 @@ jobs: strategy: max-parallel: 2 matrix: - python-version: [ 3.7, 3.8, 3.9, "3.10" ] + python-version: [ 3.8, 3.9, "3.10" ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index 6f91085b..fe24bf86 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -1,49 +1,29 @@ -""" -this module is meant to enable usage of mycroft plugins inside and outside -mycroft, importing from here will make things work as planned in mycroft, -but if outside mycroft things will still work - -The main use case is for plugins to be used across different projects - -## Differences from upstream - -TTS: -- added automatic guessing of phonemes/visime calculation, enabling mouth -movements for all TTS engines (only mimic implements this in upstream) -- playback start call has been omitted and moved to init method -- init is called by mycroft, but non mycroft usage wont call it -- outside mycroft the enclosure is not set, bus is dummy and playback thread is not used - - playback queue is not wanted when some module is calling get_tts - - if playback was started on init then python scripts would never stop - from mycroft.tts import TTSFactory - engine = TTSFactory.create() - engine.get_tts("hello world", "hello_world." 
+ engine.audio_ext) - # would hang here - engine.playback.stop() -""" import abc import asyncio import inspect -import random +import os.path import re +import sys import subprocess from os.path import isfile, join from pathlib import Path from queue import Queue from threading import Thread -from typing import AsyncIterable +from typing import AsyncIterable, List, Dict import quebra_frases import requests from ovos_bus_client.apis.enclosure import EnclosureAPI from ovos_bus_client.message import Message, dig_for_message +from ovos_bus_client.session import SessionManager from ovos_config import Configuration +from ovos_config.locations import get_xdg_cache_save_path from ovos_utils import classproperty from ovos_utils.fakebus import FakeBus from ovos_utils.file_utils import get_cache_directory from ovos_utils.file_utils import resolve_resource_file from ovos_utils.lang.visimes import VISIMES -from ovos_utils.log import LOG +from ovos_utils.log import LOG, deprecated, log_deprecation from ovos_utils.metrics import Stopwatch from ovos_utils.process_utils import RuntimeRequirements @@ -56,83 +36,101 @@ SSML_TAGS = re.compile(r'<[^>]*>') -class PlaybackThread(Thread): - """ PlaybackThread moved to ovos_audio.playback - standalone plugin usage should rely on self.get_tts - ovos-audio relies on self.execute and needs this class +class TTSContext: + """ + A context manager for handling Text-To-Speech (TTS) operations and caching. - this class was only in ovos-plugin-manager in order to - patch usage of our plugins in mycroft-core""" + Attributes: + plugin_id (str): Identifier for the TTS plugin being used. + lang (str): Language code for the TTS operation. + voice (str): Identifier for the voice type in use. + synth_kwargs (dict): Optional dictionary containing additional keyword arguments for the TTS synthesizer. 
- def __new__(self, *args, **kwargs): - LOG.warning("PlaybackThread moved to ovos_audio.playback") - try: - from ovos_audio.playback import PlaybackThread - return PlaybackThread(*args, **kwargs) - except ImportError: - raise ImportError("please install ovos-audio for playback handling") + Class Attributes: + _caches (dict): A class-level dictionary acting as a cache store for different TTS contexts. + """ + _caches: Dict[str, TextToSpeechCache] = {} -class TTSContext: - """ parses kwargs for valid signatures and extracts voice/lang optional parameters - it will look for a requested voice in kwargs and inside the source Message data if available. - voice can also be defined by a combination of language and gender, - in that case the helper method get_voice will be used to resolve the final voice_id - """ + def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = None): + """ + Initializes the TTSContext instance. - def __init__(self, engine): - self.engine = engine + Parameters: + plugin_id (str): The unique identifier for the TTS plugin. + lang (str): The language in which the text will be synthesized. + voice (str): The voice model to be used for text synthesis. + synth_kwargs (dict, optional): Additional keyword arguments for the synthesizer. + """ + self.plugin_id = plugin_id + self.lang = lang + self.voice = voice + self.synth_kwargs = synth_kwargs or {} - def get_message(self, kwargs): - msg = kwargs.get("message") or dig_for_message() - if msg and isinstance(msg, Message): - return msg + @property + def tts_id(self): + """ + Constructs a unique identifier for the TTS context based on plugin, voice, and language. - def get_lang(self, kwargs): - # parse requested language for this TTS request - # NOTE: this is ovos only functionality, not in mycroft-core! 
- lang = kwargs.get("lang") - message = self.get_message(kwargs) - if not lang and message: - # get lang from message object if possible - lang = message.data.get("lang") or \ - message.context.get("lang") - return lang or self.engine.lang - - def get_gender(self, kwargs): - gender = kwargs.get("gender") - message = self.get_message(kwargs) - if not gender and message: - # get gender from message object if possible - gender = message.data.get("gender") or \ - message.context.get("gender") - return gender + Returns: + str: A unique identifier that represents the TTS context. + """ + return join(self.plugin_id, self.voice, self.lang) - def get_voice(self, kwargs): - # parse requested voice for this TTS request - # NOTE: this is ovos only functionality, not in mycroft-core! - voice = kwargs.get("voice") - message = self.get_message(kwargs) - if not voice and message: - # get voice from message object if possible - voice = message.data.get("voice") or \ - message.context.get("voice") - - if not voice: - gender = self.get_gender(kwargs) - if gender: - lang = self.get_lang(kwargs) - voice = self.engine.get_voice(gender, lang) - - return voice or self.engine.voice - - def get(self, kwargs=None): - kwargs = kwargs or {} - return self.get_lang(kwargs), self.get_voice(kwargs) + def get_cache(self, audio_ext="wav", cache_config=None): + """ + Retrieves or creates a cache instance for the current TTS context. - def get_cache(self, kwargs=None): - lang, voice = self.get(kwargs) - return self.engine.get_cache(voice, lang) + Parameters: + audio_ext (str, optional): The file extension for the audio files (default is 'wav'). + cache_config (dict, optional): Configuration settings for the cache, including parameters like + minimum free percent, persistence settings, and cache directory path. + + Returns: + TextToSpeechCache: The cache instance associated with the current TTS context. 
+ """ + cache_config = cache_config or { + "min_free_percent": 75, + "persist_cache": False, + "persist_thresh": 1, + "preloaded_cache": f"{get_xdg_cache_save_path()}/{self.tts_id}" + } + if self.tts_id not in TTSContext._caches: + TTSContext._caches[self.tts_id] = TextToSpeechCache( + cache_config, self.tts_id, audio_ext + ) + return self._caches[self.tts_id] + + def get_from_cache(self, sentence, audio_ext="wav", cache_config=None): + """ + Retrieves an audio file and phoneme data from the cache, based on the input sentence. + + Parameters: + sentence (str): The sentence for which to retrieve audio data. + audio_ext (str, optional): The file extension of the audio file (default is 'wav'). + cache_config (dict, optional): Configuration settings for the cache. + + Returns: + tuple: A tuple containing the path to the cached audio file and optionally the phoneme data. + + Raises: + FileNotFoundError: If the sentence is not found in the cache. + """ + sentence_hash = hash_sentence(sentence) + phonemes = None + cache = self.get_cache(audio_ext, cache_config) + if sentence_hash not in cache: + raise FileNotFoundError(f"sentence is not cached, {sentence_hash}.{audio_ext}") + audio_file, pho_file = cache.cached_sentences[sentence_hash] + LOG.info(f"Found {audio_file.name} in TTS cache") + if pho_file: + phonemes = pho_file.load() + return audio_file, phonemes + + @classmethod + def curate_caches(cls): + for cache in TTSContext._caches.values(): + cache.curate() class TTS: @@ -141,26 +139,44 @@ class TTS: It aggregates the minimum required parameters and exposes ``execute(sentence)`` and ``validate_ssml(sentence)`` functions. - Arguments: - lang (str): - config (dict): Configuration for this specific tts engine - validator (TTSValidator): Used to verify proper installation - phonetic_spelling (bool): Whether to spell certain words phonetically - ssml_tags (list): Supported ssml properties. Ex. 
['speak', 'prosody'] + Attributes: + queue (Queue): A queue for managing TTS playback tasks. + playback (PlaybackThread): The playback thread used for TTS audio output. + + Args: + lang (str): The language code for the TTS engine. + config (dict): Configuration settings for the specific TTS engine. + validator (TTSValidator): Validator used to verify proper installation. + audio_ext (str): The default audio file extension (default is 'wav'). + phonetic_spelling (bool): Whether to spell certain words phonetically. + ssml_tags (list): Supported SSML properties (e.g., ['speak', 'prosody']). """ queue = None playback = None - def __init__(self, lang="en-us", config=None, validator=None, + def __init__(self, lang=None, config=None, validator=None, audio_ext='wav', phonetic_spelling=True, ssml_tags=None): - self.log_timestamps = False + """ + Initializes the TTS engine with specified parameters. + Args: + lang (str): The language code (deprecated). + config (dict): Configuration settings for the TTS engine. + validator (TTSValidator): Validator for verifying installation. + audio_ext (str): Default audio file extension (default is 'wav'). + phonetic_spelling (bool): Whether to use phonetic spelling (default is True). + ssml_tags (list): Supported SSML tags (default is None). + """ + if lang is not None: + log_deprecation("lang argument for TTS has been deprecated! 
it will be ignored, " + "pass lang to get_tts directly instead") + self.log_timestamps = False + self.root_dir = os.path.dirname(os.path.abspath(sys.modules[self.__module__].__file__)) self.config = config or get_plugin_config(config, "tts") self.stopwatch = Stopwatch() self.tts_name = self.__class__.__name__ - self.bus = FakeBus() # initialized in "init" step - self.lang = lang or self.config.get("lang") or 'en-us' + self.validator = validator or TTSValidator(self) self.phonetic_spelling = phonetic_spelling self.audio_ext = audio_ext @@ -169,205 +185,85 @@ def __init__(self, lang="en-us", config=None, validator=None, self.enable_cache = self.config.get("enable_cache", True) - self.voice = self.config.get("voice") or "default" - # TODO can self.filename be deprecated ? is it used anywhere at all? - cache_dir = get_cache_directory(self.tts_name) - self.filename = join(cache_dir, 'tts.' + self.audio_ext) - - random.seed() - if TTS.queue is None: TTS.queue = Queue() - self.context = TTSContext(self) - - # NOTE: self.playback.start() was moved to init method - # playback queue is not wanted if we only care about get_tts - # init is called by mycroft, but non mycroft usage wont call it, - # outside mycroft the enclosure is not set, bus is dummy and - # playback thread is not used - self.spellings = self.load_spellings() - - self.caches = { - self.tts_id: TextToSpeechCache( - self.config, self.tts_id, self.audio_ext - )} + self.spellings: Dict[str, dict] = self.load_spellings() + self._init_g2p() - cfg = Configuration() - g2pm = self.config.get("g2p_module") - if g2pm: - if g2pm in find_g2p_plugins(): - cfg.setdefault("g2p", {}) - globl = cfg["g2p"].get("module") or g2pm - if globl != g2pm: - LOG.info(f"TTS requested {g2pm} explicitly, ignoring global module {globl} ") - cfg["g2p"]["module"] = g2pm - else: - LOG.warning(f"TTS selected {g2pm}, but it is not available!") + self.add_metric({"metric_type": "tts.init"}) - try: - self.g2p = OVOSG2PFactory.create(cfg) - except: 
- LOG.exception("G2P plugin not loaded, there will be no mouth movements") - self.g2p = None + # unused by plugins, assigned in init method by ovos-audio, + # only present for backwards compat reasons + self.bus = None - self.cache.curate() + self._plugin_id = "" # the plugin name - self.add_metric({"metric_type": "tts.init"}) + @property + def plugin_id(self) -> str: + """ + Retrieves the plugin ID for the TTS engine. + Returns: + str: The plugin ID associated with the TTS engine. + """ + if not self._plugin_id: + from ovos_plugin_manager.tts import find_tts_plugins + for tts_id, clazz in find_tts_plugins().items(): + if isinstance(self, clazz): + self._plugin_id = tts_id + break + return self._plugin_id + + # methods for individual plugins to override @classproperty def runtime_requirements(self): - """ skill developers should override this if they do not require connectivity - some examples: - IOT plugin that controls devices via LAN could return: - scans_on_init = True - RuntimeRequirements(internet_before_load=False, - network_before_load=scans_on_init, - requires_internet=False, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=False) - online search plugin with a local cache: - has_cache = False - RuntimeRequirements(internet_before_load=not has_cache, - network_before_load=not has_cache, - requires_internet=True, - requires_network=True, - no_internet_fallback=True, - no_network_fallback=True) - a fully offline plugin: - RuntimeRequirements(internet_before_load=False, - network_before_load=False, - requires_internet=False, - requires_network=False, - no_internet_fallback=True, - no_network_fallback=True) - """ + """ WIP - currently unused, + placeholder to allow plugins to request internet/gui before load + refer to skills to see how it is used""" return RuntimeRequirements() @property - def tts_id(self): - lang, voice = self.context.get() - return join(self.tts_name, voice, lang) - - @property - def cache(self): - return 
self.caches.get(self.tts_id) or \ - self.get_cache() - - @cache.setter - def cache(self, val): - self.caches[self.tts_id] = val - - def get_cache(self, voice=None, lang=None): - lang = lang or self.lang - voice = voice or self.voice or "default" - tts_id = join(self.tts_name, voice, lang) - if tts_id not in self.caches: - self.caches[tts_id] = TextToSpeechCache( - self.config, tts_id, self.audio_ext - ) - return self.caches[tts_id] - - def handle_metric(self, metadata=None): - """ receive timing metrics for diagnostics - does nothing by default but plugins might use it, eg, NeonCore""" - - def add_metric(self, metadata=None): - """ wraps handle_metric to catch exceptions and log timestamps """ - try: - self.handle_metric(metadata) - if self.log_timestamps: - LOG.debug(f"time delta: {self.stopwatch.delta} metric: {metadata}") - except Exception as e: - LOG.exception(e) - - def load_spellings(self, config=None): - """Load phonetic spellings of words as dictionary.""" - path = join('text', self.lang.lower(), 'phonetic_spellings.txt') - try: - spellings_file = resolve_resource_file(path, config=config or Configuration()) - except: - LOG.debug('Failed to locate phonetic spellings resouce file.') - return {} - if not spellings_file: - return {} - try: - with open(spellings_file) as f: - lines = filter(bool, f.read().split('\n')) - lines = [i.split(':') for i in lines] - return {key.strip(): value.strip() for key, value in lines} - except ValueError: - LOG.exception('Failed to load phonetic spellings.') - return {} - - def begin_audio(self): - """Helper function for child classes to call in execute()""" - self.stopwatch.start() - self.add_metric({"metric_type": "tts.start"}) - - def end_audio(self, listen=False): - """Helper cleanup function for child classes to call in execute(). - - Arguments: - listen (bool): DEPRECATED: indication if listening trigger should be sent. 
- """ - self.add_metric({"metric_type": "tts.end"}) - self.stopwatch.stop() - - def init(self, bus=None, playback=None): - """ Performs intial setup of TTS object. - - Arguments: - bus: OpenVoiceOS messagebus connection + def available_languages(self) -> set: + """Return languages supported by this TTS implementation in this state + This property should be overridden by the derived class to advertise + what languages that engine supports. + Returns: + set: A set of supported language codes. """ - self.bus = bus or FakeBus() - if playback is None: - LOG.warning("PlaybackThread should be inited by ovos-audio, initing via plugin has been deprecated, " - "please pass playback=PlaybackThread() to TTS.init") - if TTS.playback: - playback.shutdown() - playback = PlaybackThread(TTS.queue, self.bus) # compat - playback.start() - self._init_playback(playback) - self.add_metric({"metric_type": "tts.setup"}) - - def _init_playback(self, playback): - TTS.playback = playback - TTS.playback.set_bus(self.bus) - TTS.playback.attach_tts(self) - if not TTS.playback.enclosure: - TTS.playback.enclosure = EnclosureAPI(self.bus) + return set() - if not TTS.playback.is_alive(): - TTS.playback.start() + @abc.abstractmethod + def get_tts(self, sentence, wav_file, lang=None, voice=None): + """Abstract method that a tts implementation needs to implement. - @property - def enclosure(self): - if not TTS.playback.enclosure: - bus = TTS.playback.bus or self.bus - TTS.playback.enclosure = EnclosureAPI(bus) - return TTS.playback.enclosure + Args: + sentence (str): The input sentence to synthesize. + wav_file (str): The output file path for the synthesized audio. + lang (str, optional): The requested language (defaults to self.lang). + voice (str, optional): The requested voice (defaults to self.voice). 
- @enclosure.setter - def enclosure(self, val): - TTS.playback.enclosure = val + Returns: + tuple: (wav_file, phoneme) + """ + return "", None - @abc.abstractmethod - def get_tts(self, sentence, wav_file, lang=None): - """Abstract method that a tts implementation needs to implement. + def preprocess_sentence(self, sentence: str) -> List[str]: + """Default preprocessing is a sentence_tokenizer, + ie. splits the utterance into sub-sentences using quebra_frases - Should get data from tts. + This method can be overridden to create chunks suitable to the + TTS engine in question. Arguments: - sentence(str): Sentence to synthesize - wav_file(str): output file - lang(str): requested language (optional), defaults to self.lang + sentence (str): sentence to preprocess Returns: - tuple: (wav_file, phoneme) + list: list of sentence parts """ - return "", None + if self.config.get("sentence_tokenize"): # TODO default to True on next major release + return quebra_frases.sentence_tokenize(sentence) + return [sentence] def modify_tag(self, tag): """Override to modify each supported ssml tag. @@ -377,6 +273,19 @@ def modify_tag(self, tag): """ return tag + def handle_metric(self, metadata=None): + """ receive timing metrics for diagnostics + does nothing by default but plugins might use it, eg, NeonCore""" + + @property + def voice(self): + return self.config.get("voice") or "default" + + @voice.setter + def voice(self, val): + self.config["voice"] = val + + # SSML helpers @staticmethod def remove_ssml(text): """Removes SSML tags from a string. @@ -463,22 +372,121 @@ def validate_ssml(self, utterance): # return text with supported ssml tags only return utterance.replace(" ", " ") - def _preprocess_sentence(self, sentence): - """Default preprocessing is a sentence_tokenizer, - ie. splits the utterance into sub-sentences using quebra_frases + # init helpers + def _init_g2p(self): + """ + Initializes the grapheme-to-phoneme (G2P) conversion for the TTS engine. 
+ """ + cfg = Configuration() + g2pm = self.config.get("g2p_module") + if g2pm: + if g2pm in find_g2p_plugins(): + cfg.setdefault("g2p", {}) + globl = cfg["g2p"].get("module") or g2pm + if globl != g2pm: + LOG.info(f"TTS requested {g2pm} explicitly, ignoring global module {globl} ") + cfg["g2p"]["module"] = g2pm + else: + LOG.warning(f"TTS selected {g2pm}, but it is not available!") - This method can be overridden to create chunks suitable to the - TTS engine in question. + try: + self.g2p = OVOSG2PFactory.create(cfg) + except: + LOG.debug("G2P plugin not loaded, there will be no mouth movements") + self.g2p = None + + def init(self, bus=None, playback=None): + """ Connects TTS object to PlaybackQueue in ovos-audio. + + This method needs to be called in order for self.execute to do anything + + not needed if using get_tts / synth methods directly as intended in standalone usage Arguments: - sentence (str): sentence to preprocess + bus: OpenVoiceOS messagebus connection + """ + self.bus = bus or FakeBus() + if playback is None: + LOG.warning("PlaybackThread should be inited by ovos-audio, initing via plugin has been deprecated, " + "please pass playback=PlaybackThread() to TTS.init") + if not TTS.playback: + playback = PlaybackThread(TTS.queue, self.bus) # compat + playback.start() + self._init_playback(playback) + self.add_metric({"metric_type": "tts.setup"}) + + def _init_playback(self, playback): + """ + Initializes the playback functionality for the TTS engine. + + Args: + playback: PlaybackThread instance. + """ + + TTS.playback = playback + TTS.playback.set_bus(self.bus) + if not TTS.playback.enclosure: + TTS.playback.enclosure = EnclosureAPI(self.bus) + + if not TTS.playback.is_alive(): + TTS.playback.start() + + def load_spellings(self, config=None) -> Dict[str, dict]: + """ + Loads phonetic spellings of words as a dictionary. + + Args: + config (dict, optional): Configuration settings. 
Returns: - list: list of sentence parts + dict: A dictionary of phonetic spellings. """ - if self.config.get("sentence_tokenize"): # TODO default to True on next major release - return quebra_frases.sentence_tokenize(sentence) - return [sentence] + if config: + LOG.warning("config argument is deprecated and unused!") + spellings_data = {} + locale = f"{self.root_dir}/locale" + if os.path.isdir(locale): + for lang in os.listdir(locale): + spellings_file = f"{locale}/{lang}/phonetic_spellings.txt" + if not os.path.isfile(spellings_file): + continue + try: + with open(spellings_file) as f: + lines = filter(bool, f.read().split('\n')) + lines = [i.split(':') for i in lines] + spellings_data[lang] = {key.strip(): value.strip() for key, value in lines} + except ValueError: + LOG.exception(f'Failed to load {lang} phonetic spellings.') + return spellings_data + + ## execution events + def add_metric(self, metadata=None): + """ + Wraps handle_metric to catch exceptions and log timestamps. + + Args: + metadata (dict, optional): Additional metadata for the metric. + """ + try: + self.handle_metric(metadata) + if self.log_timestamps: + LOG.debug(f"time delta: {self.stopwatch.delta} metric: {metadata}") + except Exception as e: + LOG.exception(e) + + def begin_audio(self): + """Helper function for child classes to call in execute()""" + self.stopwatch.start() + self.add_metric({"metric_type": "tts.start"}) + + def end_audio(self, listen=False): + """Helper cleanup function for child classes to call in execute(). + + Arguments: + listen (bool): DEPRECATED: indication if listening trigger should be sent. 
+ """ + self.add_metric({"metric_type": "tts.end"}) + self.stopwatch.stop() def execute(self, sentence, ident=None, listen=False, **kwargs): """Convert sentence to speech, preprocessing out unsupported ssml @@ -498,142 +506,135 @@ def execute(self, sentence, ident=None, listen=False, **kwargs): self._execute(sentence, ident, listen, **kwargs) self.end_audio() - def _replace_phonetic_spellings(self, sentence): - if self.phonetic_spelling: + ## synth + def _replace_phonetic_spellings(self, sentence:str, lang: str) -> str: + if self.phonetic_spelling and lang in self.spellings: for word in re.findall(r"[\w']+", sentence): - if word.lower() in self.spellings: - spelled = self.spellings[word.lower()] + if word.lower() in self.spellings[lang]: + spelled = self.spellings[lang][word.lower()] sentence = sentence.replace(word, spelled) return sentence + def _get_visemes(self, phonemes, sentence, ctxt): + # get visemes/mouth movements + viseme = [] + if phonemes: + viseme = self.viseme(phonemes) + elif self.g2p is not None: + try: + viseme = self.g2p.utterance2visemes(sentence, ctxt.lang) + except OutOfVocabulary: + pass + except: + # this one is unplanned, let devs know all the info so they can fix it + LOG.exception(f"Unexpected failure in G2P plugin: {self.g2p}") + + if not viseme: + # Debug level because this is expected in default installs + LOG.debug(f"no mouth movements available! 
unknown visemes for {sentence}") + return viseme + + def _get_ctxt(self, kwargs=None) -> TTSContext: + """create a TTSContext from arbitrary kwargs passed to synth/execute methods + takes lang from Session into account if a message is present + """ + # get request specific synth params + kwargs = kwargs or {} + message = kwargs.get("message") or dig_for_message() + + # update kwargs from session + if message and "lang" not in kwargs: + sess = SessionManager.get(message) + kwargs["lang"] = sess.lang + + # filter kwargs accepted by this specific plugin + kwargs = {k: v for k, v in kwargs.items() + if k in inspect.signature(self.get_tts).parameters + and k not in ["sentence", "wav_file"]} + + LOG.debug(f"TTS kwargs: {kwargs}") + return TTSContext(plugin_id=self.plugin_id, + lang=kwargs.get("lang") or Configuration().get("lang", "en-us"), + voice=kwargs.get("voice") or self.voice, + synth_kwargs=kwargs) + def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): + # get request specific synth params + ctxt = self._get_ctxt(kwargs) + if preprocess: - sentence = self._replace_phonetic_spellings(sentence) - chunks = self._preprocess_sentence(sentence) + # pre-process + sentence = self._replace_phonetic_spellings(sentence, ctxt.lang) + chunks = self.preprocess_sentence(sentence) # Apply the listen flag to the last chunk, set the rest to False chunks = [(chunks[i], listen if i == len(chunks) - 1 else False) for i in range(len(chunks))] + + # metrics timing callback self.add_metric({"metric_type": "tts.preprocessed", "n_chunks": len(chunks)}) else: chunks = [(sentence, listen)] - lang, voice = self.context.get(kwargs) - tts_id = join(self.tts_name, voice, lang) + message = kwargs.get("message") or \ + dig_for_message() or \ + Message("speak", context={"session": {"session_id": ident}}) # synth -> queue for playback for sentence, l in chunks: # load from cache or synth + cache - audio_file, phonemes = self.synth(sentence, **kwargs) + audio_file, phonemes = 
self.synth(sentence, ctxt) # get visemes/mouth movements - viseme = [] - if phonemes: - viseme = self.viseme(phonemes) - elif self.g2p is not None: - try: - viseme = self.g2p.utterance2visemes(sentence, lang) - except OutOfVocabulary: - pass - except: - # this one is unplanned, let devs know all the info so they can fix it - LOG.exception(f"Unexpected failure in G2P plugin: {self.g2p}") - - if not viseme: - # Debug level because this is expected in default installs - LOG.debug(f"no mouth movements available! unknown visemes for {sentence}") - - message = kwargs.get("message") or \ - dig_for_message() or \ - Message("speak", context={"session": {"session_id": ident}}) + viseme = self._get_visemes(phonemes, sentence, ctxt) + + # queue audio for playback TTS.queue.put( - (str(audio_file), viseme, l, tts_id, message) + (str(audio_file), viseme, l, ctxt.tts_id, message) ) + + # metrics timing callback self.add_metric({"metric_type": "tts.queued"}) - def synth(self, sentence, **kwargs): - """ This method wraps get_tts - several optional keyword arguments are supported - sentence will be read/saved to cache""" + def synth(self, sentence, ctxt: TTSContext = None, **kwargs): + """ + Synthesizes speech for the given sentence. wraps get_tts + + sentence will be read/saved to cache + + Args: + sentence (str): The sentence to synthesize. + ctxt (TTSContext): The TTS context. + **kwargs: Additional synth arguments for get_tts. + + Returns: + tuple: A tuple containing the path to the synthesized audio file and phoneme data. + """ self.add_metric({"metric_type": "tts.synth.start"}) sentence_hash = hash_sentence(sentence) - # parse requested language for this TTS request - # NOTE: this is ovos/neon only functionality, not in mycroft-core! 
-        lang, voice = self.context.get(kwargs)
-        kwargs["lang"] = lang
-        kwargs["voice"] = voice
-
-        cache = self.get_cache(voice, lang)  # cache per tts_id (lang/voice combo)
+        # parse kwargs for this TTS request
+        ctxt = ctxt or self._get_ctxt(kwargs)
+        cache = ctxt.get_cache(self.audio_ext, self.config)

         # load from cache
         if self.enable_cache and sentence_hash in cache:
-            audio, phonemes = self.get_from_cache(sentence, **kwargs)
+            audio, phonemes = ctxt.get_from_cache(sentence, self.audio_ext, self.config)
             self.add_metric({"metric_type": "tts.synth.finished", "cache": True})
             return audio, phonemes

         # synth + cache
         audio = cache.define_audio_file(sentence_hash)
-
-        # filter kwargs per plugin, different plugins expose different options
-        # mycroft-core -> no kwargs
-        # ovos -> lang + voice optional kwargs
-        # neon-core -> message
-        kwargs = {k: v for k, v in kwargs.items()
-                  if k in inspect.signature(self.get_tts).parameters
-                  and k not in ["sentence", "wav_file"]}
-
-        # finally do the TTS synth
-        audio.path, phonemes = self.get_tts(sentence, str(audio), **kwargs)
+        audio.path, phonemes = self.get_tts(sentence, str(audio),
+                                            **ctxt.synth_kwargs)
         self.add_metric({"metric_type": "tts.synth.finished"})

         # cache sentence + phonemes
         if self.enable_cache:
-            self._cache_sentence(sentence, audio, phonemes, sentence_hash,
-                                 voice=voice, lang=lang)
+            self._cache_sentence(sentence, ctxt.lang, audio, cache,
+                                 phonemes, sentence_hash)

         return audio, phonemes

-    def _cache_phonemes(self, sentence, phonemes=None, sentence_hash=None):
-        sentence_hash = sentence_hash or hash_sentence(sentence)
-        if not phonemes and self.g2p is not None:
-            try:
-                phonemes = self.g2p.utterance2arpa(sentence, self.lang)
-                self.add_metric({"metric_type": "tts.phonemes.g2p"})
-            except Exception as e:
-                self.add_metric({"metric_type": "tts.phonemes.g2p.error", "error": str(e)})
-        if phonemes:
-            return self.save_phonemes(sentence_hash, phonemes)
-        return None
-
-    def _cache_sentence(self, sentence, audio_file, phonemes=None,
sentence_hash=None, - voice=None, lang=None): - sentence_hash = sentence_hash or hash_sentence(sentence) - # RANT: why do you hate strings ChrisV? - if isinstance(audio_file.path, str): - audio_file.path = Path(audio_file.path) - pho_file = self._cache_phonemes(sentence, phonemes, sentence_hash) - cache = self.get_cache(voice=voice, lang=lang) - cache.cached_sentences[sentence_hash] = (audio_file, pho_file) - self.add_metric({"metric_type": "tts.synth.cached"}) - - def get_from_cache(self, sentence, **kwargs): - sentence_hash = hash_sentence(sentence) - phonemes = None - cache = self.context.get_cache(kwargs) - audio_file, pho_file = cache.cached_sentences[sentence_hash] - LOG.info(f"Found {audio_file.name} in TTS cache") - if not pho_file: - # guess phonemes from sentence + cache them - pho_file = self._cache_phonemes(sentence, sentence_hash) - if pho_file: - phonemes = pho_file.load() - return audio_file, phonemes - - def get_voice(self, gender, lang=None): - """ map a language and gender to a valid voice for this TTS engine """ - lang = lang or self.lang - return gender - def viseme(self, phonemes): """Create visemes from phonemes. @@ -660,31 +661,52 @@ def viseme(self, phonemes): float(0.2))) return visimes or None - def clear_cache(self): - """ Remove all cached files. """ - self.cache.clear() - - def save_phonemes(self, key, phonemes): - """Cache phonemes + ## cache + def _cache_phonemes(self, sentence, lang: str, cache: TextToSpeechCache = None, phonemes=None, sentence_hash=None): + """ + Caches phonemes for the given sentence. - Arguments: - key (str): Hash key for the sentence - phonemes (str): phoneme string to save + Args: + sentence (str): The sentence to cache phonemes for. + cache (TextToSpeechCache): The cache instance. + phonemes (str, optional): The phonemes for the sentence. + sentence_hash (str, optional): The hash of the sentence. 
""" - phoneme_file = self.cache.define_phoneme_file(key) - phoneme_file.save(phonemes) - return phoneme_file + sentence_hash = sentence_hash or hash_sentence(sentence) + if not phonemes and self.g2p is not None: + try: + phonemes = self.g2p.utterance2arpa(sentence, lang) + self.add_metric({"metric_type": "tts.phonemes.g2p"}) + except Exception as e: + self.add_metric({"metric_type": "tts.phonemes.g2p.error", "error": str(e)}) + if phonemes: + phoneme_file = cache.define_phoneme_file(sentence_hash) + phoneme_file.save(phonemes) + return phoneme_file + return None - def load_phonemes(self, key): - """Load phonemes from cache file. + def _cache_sentence(self, sentence, lang: str, audio_file, cache, phonemes=None, sentence_hash=None): + """ + Caches the sentence along with associated audio and phonemes. - Arguments: - key (str): Key identifying phoneme cache + Args: + sentence (str): The sentence to cache. + audio_file (AudioFile): The audio file associated with the sentence. + cache (TextToSpeechCache): The cache instance. + phonemes (str, optional): The phonemes for the sentence. + sentence_hash (str, optional): The hash of the sentence. """ - phoneme_file = self.cache.define_phoneme_file(key) - return phoneme_file.load() + sentence_hash = sentence_hash or hash_sentence(sentence) + # RANT: why do you hate strings ChrisV? 
+ if isinstance(audio_file.path, str): + audio_file.path = Path(audio_file.path) + pho_file = self._cache_phonemes(sentence, lang, cache, phonemes, sentence_hash) + cache.cached_sentences[sentence_hash] = (audio_file, pho_file) + self.add_metric({"metric_type": "tts.synth.cached"}) + ## shutdown def stop(self): + """Stops the TTS playback.""" if TTS.playback: try: TTS.playback.stop() @@ -693,22 +715,192 @@ def stop(self): self.add_metric({"metric_type": "tts.stop"}) def shutdown(self): + """Shuts down the TTS engine.""" self.stop() if TTS.playback: TTS.playback.detach_tts(self) def __del__(self): + """Destructor for the TTS object.""" self.shutdown() + # below code is all deprecated and marked for removal in next stable release @property - def available_languages(self) -> set: - """Return languages supported by this TTS implementation in this state - This property should be overridden by the derived class to advertise - what languages that engine supports. + @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", + "0.1.0") + def enclosure(self): + """Deprecated. Accessor for the enclosure property. + Returns: - set: supported languages + EnclosureAPI: The EnclosureAPI instance associated with the TTS playback. """ - return set() + if not TTS.playback.enclosure: + bus = TTS.playback.bus or self.bus + TTS.playback.enclosure = EnclosureAPI(bus) + return TTS.playback.enclosure + + @enclosure.setter + @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", + "0.1.0") + def enclosure(self, val): + """Deprecated. Setter for the enclosure property. + + Arguments: + val (EnclosureAPI): The EnclosureAPI instance to set. + """ + TTS.playback.enclosure = val + + @property + @deprecated("self.filename has been deprecated, unused for a long time now", + "0.1.0") + def filename(self): + """Deprecated. Accessor for the filename property. 
+ + Returns: + str: The filename for the TTS audio. + """ + cache_dir = get_cache_directory(self.tts_name) + return join(cache_dir, 'tts.' + self.audio_ext) + + @filename.setter + @deprecated("self.filename has been deprecated, unused for a long time now", + "0.1.0") + def filename(self, val): + """Deprecated. Setter for the filename property. + + Arguments: + val (str): The filename to set. + """ + + @property + @deprecated("self.tts_id has been deprecated, use TTSContext().tts_id", + "0.1.0") + def tts_id(self): + """Deprecated. Accessor for the tts_id property. + + Returns: + str: The ID associated with the TTS context. + """ + return self._get_ctxt().tts_id + + @property + @deprecated("self.cache has been deprecated, use TTSContext().get_cache", + "0.1.0") + def cache(self): + """Deprecated. Accessor for the cache property. + + Returns: + TextToSpeechCache: The cache associated with the TTS context. + """ + return TTSContext._caches.get(self.tts_id) or \ + self.get_cache() + + @cache.setter + @deprecated("self.cache has been deprecated, use TTSContext().get_cache", + "0.1.0") + def cache(self, val): + """Deprecated. Setter for the cache property. + + Arguments: + val (TextToSpeechCache): The cache to set. + """ + TTSContext._caches[self.tts_id] = val + + @deprecated("get_voice was never formally adopted and is unused, it will be removed", + "0.1.0") + def get_voice(self, gender, lang=None): + """Deprecated. Get a valid voice for the TTS engine. + + Arguments: + gender (str): Gender of the voice. + lang (str, optional): Language for the voice. Defaults to None. + + Returns: + str: The selected voice. + """ + return gender + + @deprecated("get_cache has been deprecated, use TTSContext().get_cache directly", + "0.1.0") + def get_cache(self, voice=None, lang=None): + """Deprecated. Get the cache associated with the TTS context. + + Arguments: + voice (str, optional): Voice for the cache. Defaults to None. + lang (str, optional): Language for the cache. 
Defaults to None. + + Returns: + TextToSpeechCache: The cache associated with the TTS context. + """ + return self._get_ctxt().get_cache(self.audio_ext, self.config) + + @deprecated("clear_cache has been deprecated, use TTSContext().get_cache directly", + "0.1.0") + def clear_cache(self): + """Deprecated. Clear all cached files.""" + cache = self._get_ctxt().get_cache(self.audio_ext, self.config) + cache.clear() + + @deprecated("save_phonemes has been deprecated, use TTSContext().get_cache directly", + "0.1.0") + def save_phonemes(self, key, phonemes): + """Deprecated. Cache phonemes. + + Arguments: + key (str): Hash key for the sentence. + phonemes (str): Phoneme string to save. + + Returns: + PhonemeFile: The PhonemeFile instance. + """ + cache = self._get_ctxt().get_cache(self.audio_ext, self.config) + phoneme_file = cache.define_phoneme_file(key) + phoneme_file.save(phonemes) + return phoneme_file + + @deprecated("load_phonemes has been deprecated, use TTSContext().get_cache directly", + "0.1.0") + def load_phonemes(self, key): + """Deprecated. Load phonemes from cache file. + + Arguments: + key (str): Key identifying phoneme cache. + + Returns: + str: Phonemes loaded from the cache file. + """ + cache = self._get_ctxt().get_cache(self.audio_ext, self.config) + phoneme_file = cache.define_phoneme_file(key) + return phoneme_file.load() + + @deprecated("get_from_cache has been deprecated, use TTSContext().get_from_cache directly", + "0.1.0") + def get_from_cache(self, sentence): + """Deprecated. Get data from the cache. + + Arguments: + sentence (str): Sentence used as cache key. + + Returns: + tuple: Tuple containing the audio and phonemes. 
+ """ + return self._get_ctxt().get_from_cache(sentence, self.audio_ext, self.config) + + @property + @deprecated("language is defined per request in get_tts, self.lang is not used", + "0.1.0") + def lang(self): + message = dig_for_message() + if message: + sess = SessionManager.get(message) + return sess.lang + return self.config.get("lang") or 'en-us' + + @lang.setter + @deprecated("language is defined per request in get_tts, self.lang is not used", + "0.1.0") + def lang(self, val): + LOG.warning("self.lang can not be set! it comes from the bus message") class TTSValidator: @@ -815,6 +1007,7 @@ class RemoteTTSTimeoutException(RemoteTTSException): class StreamingTTSCallbacks: """handle the playback of streaming TTS, can be overrided in StreamingTTS""" + def __init__(self, bus, play_args=None, tts_config=None): self.bus = bus self.config = tts_config or {} @@ -831,12 +1024,12 @@ def stream_start(self, message=None): message = message or \ dig_for_message() or \ Message("speak") - + # we don't use the regular PlaybackThread here, we need to handle recognizer_loop:audio_output_start if not self.config.get("pulse_duck", False): self.bus.emit(message.forward("ovos.common_play.duck")) self.bus.emit(message.forward("recognizer_loop:audio_output_start")) - + if self._process: self.stream_stop() LOG.debug(f"stream playback command: {self.play_args}") @@ -860,7 +1053,7 @@ def stream_stop(self, listen=False, message=None): message = message or \ dig_for_message() or \ Message("speak") - + if self._process: self._process.stdin.close() self._process.wait() @@ -869,7 +1062,7 @@ def stream_stop(self, listen=False, message=None): # we don't use the regular PlaybackThread here, we need to handle recognizer_loop:audio_output_end and listen flag if not self.config.get("pulse_duck", False): self.bus.emit(message.forward("ovos.common_play.unduck")) - self.bus.emit(message.forward("recognizer_loop:audio_output_end")) + self.bus.emit(message.forward("recognizer_loop:audio_output_end")) 
if listen: self.bus.emit(message.forward('mycroft.mic.listen')) @@ -878,14 +1071,14 @@ class StreamingTTS(TTS): """ Abstract class for a Streaming TTS engine implementation. Audio is streamed in chunks as it becomes available instead of waiting the full sentence to be synthesized - + this plugin can be used in a synchronous way like any other plugin via self.get_tts(sentence, wav_file) - + to play audio as it becomes available use self.generate_audio(sentence, wav_file) NOTE: StreamingTTS does not support phonemes """ - + def init(self, bus=None, playback=None, callbacks=None): """ Performs intial setup of TTS object. @@ -899,10 +1092,10 @@ def init(self, bus=None, playback=None, callbacks=None): tts_config=self.config) @abc.abstractmethod - async def stream_tts(self, sentence) -> AsyncIterable[bytes]: + async def stream_tts(self, sentence, **kwargs) -> AsyncIterable[bytes]: """yield chunks of TTS audio as they become available""" raise NotImplementedError - + async def generate_audio(self, sentence, wav_file, play_streaming=True, listen=False, message=None, plugin_kwargs=None): """save streamed TTS to wav file, if configured also play TTS as it becomes available""" @@ -921,22 +1114,20 @@ async def generate_audio(self, sentence, wav_file, play_streaming=True, return wav_file def _execute(self, sentence, ident, listen, **kwargs): - sentence = self._replace_phonetic_spellings(sentence) - self.add_metric({"metric_type": "tts.preprocessed"}) - - sentence_hash = hash_sentence(sentence) # parse requested language for this TTS request - lang, voice = self.context.get(kwargs) - kwargs["lang"] = lang - kwargs["voice"] = voice + ctxt = self._get_ctxt(kwargs) + cache = ctxt.get_cache(self.audio_ext, self.config) - # get path to cache final synthesized file - cache = self.get_cache(voice, lang) # cache per tts_id (lang/voice combo) + sentence = self._replace_phonetic_spellings(sentence, ctxt.lang) + self.add_metric({"metric_type": "tts.preprocessed"}) + + sentence_hash = 
hash_sentence(sentence) # if cached, play existing file instead if self.enable_cache and sentence_hash in cache: - super()._execute(sentence, ident, listen, preprocess=False, **kwargs) + super()._execute(sentence, ident, listen, + preprocess=False, **ctxt.synth_kwargs) return wav_file = str(cache.define_audio_file(sentence_hash)) @@ -945,10 +1136,10 @@ def _execute(self, sentence, ident, listen, **kwargs): dig_for_message() or \ Message("speak") - # filter kwargs per plugin, different plugins expose different options - plugin_kwargs = {k: v for k, v in kwargs.items() - if k in inspect.signature(self.stream_tts).parameters - and k not in ["sentence", "wav_file", "play_streaming"]} + # filter kwargs accepted by this specific plugin + ctxt.synth_kwargs = {k: v for k, v in kwargs.items() + if k in inspect.signature(self.stream_tts).parameters + and k not in ["sentence"]} # handle streaming TTS loop = asyncio.new_event_loop() @@ -956,16 +1147,16 @@ def _execute(self, sentence, ident, listen, **kwargs): try: self.add_metric({"metric_type": "tts.stream.start"}) loop.run_until_complete( - self.generate_audio(sentence, wav_file, + self.generate_audio(sentence, wav_file, play_streaming=True, listen=listen, message=message, - plugin_kwargs=plugin_kwargs) + plugin_kwargs=ctxt.synth_kwargs) ) finally: loop.close() self.add_metric({"metric_type": "tts.stream.end"}) - + def get_tts(self, sentence, wav_file, **kwargs): """wrap streaming TTS into sync usage""" loop = asyncio.new_event_loop() @@ -981,6 +1172,8 @@ def get_tts(self, sentence, wav_file, **kwargs): return wav_file, None # No phonemes +# below classes are deprecated and will be removed in 0.1.0 + class RemoteTTS(TTS): """ Abstract class for a Remote TTS engine implementation. 
@@ -988,6 +1181,8 @@ class RemoteTTS(TTS): Usage is discouraged """ + @deprecated("RemoteTTS has been deprecated, please use the regular TTS class", + "0.1.0") def __init__(self, lang, config, url, api_path, validator): super(RemoteTTS, self).__init__(lang, config, validator) self.api_path = api_path @@ -1006,3 +1201,20 @@ def get_tts(self, sentence, wav_file, lang=None): with open(wav_file, 'wb') as f: f.write(r.content) return wav_file, None + + +class PlaybackThread(Thread): + """ PlaybackThread moved to ovos_audio.playback + standalone plugin usage should rely on self.get_tts + ovos-audio relies on self.execute and needs this class + + this class was only in ovos-plugin-manager in order to + patch usage of our plugins in mycroft-core""" + + def __new__(self, *args, **kwargs): + LOG.warning("PlaybackThread moved to ovos_audio.playback") + try: + from ovos_audio.playback import PlaybackThread + return PlaybackThread(*args, **kwargs) + except ImportError: + raise ImportError("please install ovos-audio for playback handling") diff --git a/ovos_plugin_manager/tts.py b/ovos_plugin_manager/tts.py index ba1983a2..fe453423 100644 --- a/ovos_plugin_manager/tts.py +++ b/ovos_plugin_manager/tts.py @@ -203,8 +203,8 @@ def create(config=None): clazz = OVOSTTSFactory.get_class(tts_config) if clazz: LOG.info(f'Found plugin {tts_module}') - tts = clazz(lang=None, # explicitly read lang from config - config=tts_config) + tts = clazz(config=tts_config) + tts._plugin_id = tts_module tts.validator.validate() LOG.info(f'Loaded plugin {tts_module}') else: diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index 27fe6e3d..88b7a29f 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -1,7 +1,13 @@ import unittest +from unittest.mock import MagicMock from unittest.mock import patch, Mock + +from ovos_bus_client.session import Session +from ovos_config import Configuration +from ovos_utils.fakebus import FakeBus, Message + +from 
ovos_plugin_manager.templates.tts import TTS, TTSContext from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes -from ovos_plugin_manager.templates.tts import TTS class TestTTSTemplate(unittest.TestCase): @@ -114,23 +120,23 @@ def test_format_speak_tags_with_speech_no_tags(self): self.assertEqual(tagged_with_exclusion, valid_output) def test_playback_thread(self): - from ovos_plugin_manager.templates.tts import PlaybackThread + pass # TODO - + def test_tts_context(self): - from ovos_plugin_manager.templates.tts import TTSContext + pass # TODO - + def test_tts_validator(self): - from ovos_plugin_manager.templates.tts import TTSValidator + pass # TODO - + def test_concat_tts(self): - from ovos_plugin_manager.templates.tts import ConcatTTS + pass # TODO - + def test_remote_tt(self): - from ovos_plugin_manager.templates.tts import RemoteTTS + pass # TODO @@ -187,15 +193,15 @@ def test_get_tts_config(self, get_config): self.CONFIG_SECTION, None) def test_get_voice_id(self): - from ovos_plugin_manager.tts import get_voice_id + pass # TODO def test_scan_voices(self): - from ovos_plugin_manager.tts import scan_voices + pass # TODO def test_get_voices(self): - from ovos_plugin_manager.tts import get_voices + pass # TODO @@ -246,13 +252,13 @@ def test_create(self, get_class): "config": True, "lang": "en-ca"} get_class.assert_called_once_with(expected_config) - plugin_class.assert_called_once_with(lang=None, config=expected_config) + plugin_class.assert_called_once_with(config=expected_config) self.assertEqual(plugin, plugin_class()) # Test create with TTS config and no module config plugin = OVOSTTSFactory.create(tts_config) get_class.assert_called_with(tts_config) - plugin_class.assert_called_with(lang=None, config=tts_config) + plugin_class.assert_called_with(config=tts_config) self.assertEqual(plugin, plugin_class()) # Test create with TTS config with module-specific config @@ -260,5 +266,110 @@ def test_create(self, get_class): expected_config = {"module": 
"test-tts-plugin-test", "config": True, "lang": "es-mx"} get_class.assert_called_with(expected_config) - plugin_class.assert_called_with(lang=None, config=expected_config) + plugin_class.assert_called_with(config=expected_config) self.assertEqual(plugin, plugin_class()) + + +class TestTTSContext(unittest.TestCase): + + @patch("ovos_plugin_manager.templates.tts.TextToSpeechCache", autospec=True) + def test_tts_context_get_cache(self, cache_mock): + tts_context = TTSContext("plug", "voice", "lang") + + result = tts_context.get_cache() + + self.assertEqual(result, cache_mock.return_value) + self.assertEqual(result, tts_context._caches[tts_context.tts_id]) + + +class TestTTSCache(unittest.TestCase): + def setUp(self): + self.tts_mock = TTS(config={"some_config_key": "some_config_value"}) + self.tts_mock.stopwatch = MagicMock() + self.tts_mock.queue = MagicMock() + self.tts_mock.playback = MagicMock() + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + @patch("ovos_plugin_manager.templates.tts.TTSContext") + def test_tts_synth(self, tts_context_mock, hash_sentence_mock): + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + + sentence = "Hello world!" 
+ result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + def test_tts_synth_cache_enabled(self, hash_sentence_mock): + tts_context_mock = MagicMock() + tts_context_mock.tts_id = "fake_tts_id" + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.cached_sentences = {} + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} + + sentence = "Hello world!" + self.tts_mock.enable_cache = True + result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + self.assertIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences) + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + def test_tts_synth_cache_disabled(self, hash_sentence_mock): + tts_context_mock = MagicMock() + tts_context_mock.tts_id = "fake_tts_id" + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.cached_sentences = {} + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} + + sentence = "Hello world!" 
+ self.tts_mock.enable_cache = False + result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + self.assertNotIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences) + + +class TestSession(unittest.TestCase): + def test_tts_session(self): + sess = Session(session_id="123", lang="en-us") + m = Message("speak", + context={"session": sess.serialize()}) + + tts = TTS() + self.assertEqual(tts.plugin_id, "ovos-tts-plugin-dummy") + self.assertEqual(tts.voice, "default") # no voice set + self.assertEqual(tts.lang, "en-us") # from config + + # test that session makes it all the way to the TTS.queue + kwargs = {"message": m} + tts.execute("test sentence", **kwargs) + path, visemes, listen, tts_id, message = tts.queue.get() + self.assertEqual(message, m) + self.assertEqual(message.context["session"]["session_id"], sess.session_id) + + # test that lang from Session is used + ctxt = tts._get_ctxt(kwargs) + self.assertEqual(ctxt.plugin_id, tts.plugin_id) + self.assertEqual(ctxt.lang, sess.lang) + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/en-us") + self.assertEqual(ctxt.synth_kwargs, {'lang': 'en-us'}) + + sess = Session(session_id="123", + lang="klingon") + m = Message("speak", + context={"session": sess.serialize()}) + kwargs = {"message": m} + ctxt = tts._get_ctxt(kwargs) + self.assertEqual(ctxt.lang, sess.lang) + self.assertEqual(ctxt.tts_id, f"{tts.plugin_id}/default/klingon") + self.assertEqual(ctxt.synth_kwargs, {'lang': 'klingon'})