Skip to content

Commit

Permalink
streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Nov 13, 2024
1 parent 9dd869f commit 48143e6
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions ovos_plugin_manager/templates/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Optional, List, Iterable, Tuple, Dict, Union, Any

from json_database import JsonStorageXDG
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.lang import standardize_lang_tag
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.xdg_utils import xdg_cache_home

from ovos_plugin_manager.templates.language import LanguageTranslator, LanguageDetector
Expand Down Expand Up @@ -407,21 +407,44 @@ class ChatMessageSolver(QuestionSolver):
{"role": "user", "content": "Orange."},
]
"""

@abc.abstractmethod
def continue_chat(self, messages: List[Dict[str, str]],
lang: Optional[str]) -> Optional[str]:
lang: Optional[str],
units: Optional[str] = None) -> Optional[str]:
pass

@auto_detect_lang(text_keys=["messages"])
@auto_translate(translate_keys=["messages"])
def get_chat_completion(self, messages: List[Dict[str, str]],
lang: Optional[str] = None) -> Optional[str]:
return self.continue_chat(messages=messages, lang=lang)
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
return self.continue_chat(messages=messages, lang=lang, units=units)

@auto_detect_lang(text_keys=["query"])
@auto_translate(translate_keys=["query"])
def get_spoken_answer(self, query: str, lang: Optional[str] = None) -> Optional[str]:
return self.continue_chat(messages=[{"role": "user", "content": query}], lang=lang)
@_deprecate_context2lang()
def stream_utterances(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Iterable[str]:
"""
Stream utterances for the given query as they become available.
Args:
messages: The chat messages.
lang (Optional[str]): Optional language code. Defaults to None.
units (Optional[str]): Optional units for the query. Defaults to None.
Returns:
Iterable[str]: An iterable of utterances.
"""
ans = _call_with_sanitized_kwargs(self.get_chat_completion, messages, lang=lang, units=units)
for utt in self.sentence_split(ans):
yield utt

def get_spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
# just for api compat since it's a subclass, shouldn't be directly used
return self.continue_chat(messages=[{"role": "user", "content": query}], lang=lang, units=units)


class CorpusSolver(QuestionSolver):
Expand Down

0 comments on commit 48143e6

Please sign in to comment.