From cf8ab85beea7fcffedca265bcd01abe2250dad00 Mon Sep 17 00:00:00 2001 From: Alex-Karmazin Date: Fri, 7 Jun 2024 02:23:00 +0300 Subject: [PATCH] Refactoring --- just_agents/llm_session.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/just_agents/llm_session.py b/just_agents/llm_session.py index 48dabf3..0745a48 100644 --- a/just_agents/llm_session.py +++ b/just_agents/llm_session.py @@ -51,7 +51,7 @@ def instruct(self, prompt: str): self.memory.add_message(system_instruction, True) return system_instruction - def query(self, prompt: str = None, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str: + def query(self, prompt: str, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str: """ Query large language model :param prompt: @@ -59,9 +59,18 @@ def query(self, prompt: str = None, stream: bool = False, run_callbacks: bool = :param run_callbacks: :return: """ - if prompt is not None: - question = Message(role="user", content=prompt) - self.memory.add_message(question) + + question = Message(role="user", content=prompt) + self.memory.add_message(question, run_callbacks) + return self._query(stream, run_callbacks, output) + + + def query_all(self, messages: list, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str: + self.memory.add_messages(messages, run_callbacks) + return self._query(stream, run_callbacks, output) + + + def _query(self, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str: options: Dict = self.llm_options response: ModelResponse = completion(messages=self.memory.messages, stream=stream, **options) self._process_response(response) @@ -71,7 +80,8 @@ def query(self, prompt: str = None, stream: bool = False, run_callbacks: bool = self._process_response(response) answer = self.message_from_response(response) self.memory.add_message(answer, run_callbacks) - result: str = self.memory.last_message.content if self.memory.last_message is not None and self.memory.last_message.content is not None else str(self.memory.last_message) + result: str = self.memory.last_message.content if self.memory.last_message is not None and self.memory.last_message.content is not None else str( + self.memory.last_message) if output is not None: output.write_text(result) return result