Skip to content

Commit

Permalink
Merge pull request #109 from fengsh27/main
Browse files Browse the repository at this point in the history
Refactor and Integrate RagAgent to Conversation
  • Loading branch information
slobentanzer authored Jan 30, 2024
2 parents d45292e + 1c4d8cc commit 13945c2
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 172 deletions.
3 changes: 3 additions & 0 deletions biochatter/database_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
schema_config_or_info_dict=schema_config_or_info_dict,
)
self.connection_args = connection_args
self.driver = None

def connect(self) -> None:
"""
Expand All @@ -40,6 +41,8 @@ def connect(self) -> None:
user=user,
password=password,
)
def is_connected(self) -> bool:
    """
    Report whether a database driver has been set on this agent.

    Returns:
        bool: True once ``self.driver`` has been assigned (presumably by
            ``connect()`` — confirm against the full class), False while it
            is still None.
    """
    # PEP 8 (E714): prefer "x is not None" over "not x is None".
    return self.driver is not None

def get_query_results(self, query: str, k: int = 3) -> list[Document]:
"""
Expand Down
116 changes: 53 additions & 63 deletions biochatter/llm_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
st = None

from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, List, Tuple
import openai

from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
Expand All @@ -22,6 +22,7 @@
import json

from .vectorstore import DocumentEmbedder
from .rag_agent import RagAgent
from ._stats import get_stats

OPENAI_MODELS = [
Expand Down Expand Up @@ -74,14 +75,13 @@ def __init__(
prompts: dict,
correct: bool = True,
split_correction: bool = False,
rag_agent: DocumentEmbedder = None,
):
super().__init__()
self.model_name = model_name
self.prompts = prompts
self.correct = correct
self.split_correction = split_correction
self.rag_agent = rag_agent
self.rag_agents: List[RagAgent] = []
self.history = []
self.messages = []
self.ca_messages = []
Expand All @@ -90,6 +90,25 @@ def __init__(
def set_user_name(self, user_name: str):
    """Record the name identifying the current user on this conversation."""
    self.user_name = user_name

def set_rag_agent(self, agent: RagAgent):
    """
    Register a RAG agent on the conversation with upsert semantics.

    If an agent with the same ``mode`` is already registered, it is
    replaced by the given one; otherwise the agent is appended to
    ``self.rag_agents``.
    """
    index, _ = self._find_rag_agent(agent.mode)
    if index >= 0:
        # An agent of this mode already exists: replace it in place.
        self.rag_agents[index] = agent
    else:
        # First agent of this mode: add it to the registry.
        self.rag_agents.append(agent)

def _find_rag_agent(self, mode: str) -> Tuple[int, RagAgent]:
    """
    Locate the registered RAG agent with the given mode.

    Returns:
        Tuple[int, RagAgent]: the (index, agent) pair of the first agent
            whose ``mode`` matches, or (-1, None) when none is registered.
    """
    matches = (
        (position, candidate)
        for position, candidate in enumerate(self.rag_agents)
        if candidate.mode == mode
    )
    return next(matches, (-1, None))

@abstractmethod
def set_api_key(self, api_key: str, user: Optional[str] = None):
pass
Expand All @@ -100,9 +119,6 @@ def get_prompts(self):
def set_prompts(self, prompts: dict):
    """Replace the conversation's prompt dictionary wholesale."""
    self.prompts = prompts

def set_rag_agent(self, rag_agent: DocumentEmbedder):
self.rag_agent = rag_agent

def append_ai_message(self, message: str):
self.messages.append(
AIMessage(
Expand Down Expand Up @@ -160,12 +176,10 @@ def setup_data_input_tool(self, df, input_file_name: str):
msg = self.prompts["tool_prompts"][tool_name].format(df=df)
self.append_system_message(msg)

def query(self, text: str, collection_name: Optional[str] = None):
def query(self, text: str):
self.append_user_message(text)

if self.rag_agent:
if self.rag_agent.use_prompt:
self._inject_context(text, collection_name)
self._inject_context(text)

msg, token_usage = self._primary_query()

Expand Down Expand Up @@ -221,10 +235,11 @@ def _primary_query(self, text: str):
def _correct_response(self, msg: str):
pass

def _inject_context(self, text: str, collection_name: Optional[str] = None):
def _inject_context(self, text: str):
"""
Inject the context into the prompt from vector database similarity
search. Finds the most similar n text fragments and adds them to the
Inject the context received from the RAG agent into the prompt. The RAG
agent will find the most similar n text fragments and add them to the
message history object for usage in the next prompt. Uses the document
summarisation prompt set to inject the context. The ultimate prompt
should include the placeholder for the statements, `{statements}` (used
Expand All @@ -233,38 +248,30 @@ def _inject_context(self, text: str, collection_name: Optional[str] = None):
Args:
text (str): The user query to be used for similarity search.
"""
if not self.rag_agent.used:
st.info(
"No document has been analysed yet. To use retrieval augmented "
"generation, please analyse at least one document first."
)
return

sim_msg = (
f"Performing similarity search to inject {self.rag_agent.n_results}"
" fragments ..."
)
sim_msg = f"Performing similarity search to inject fragments ..."

if st:
with st.spinner(sim_msg):
statements = [
doc.page_content
for doc in self.rag_agent.similarity_search(
text,
self.rag_agent.n_results,
)
]
statements = []
for agent in self.rag_agents:
if not agent.use_prompt:
continue
statements = statements + [
doc[0] for doc in agent.generate_responses(text)
]

else:
statements = [
doc.page_content
for doc in self.rag_agent.similarity_search(
text,
self.rag_agent.n_results,
)
]

prompts = self.prompts["rag_agent_prompts"]
if statements:
statements = []
for agent in self.rag_agents:
if not agent.use_prompt:
continue
statements = statements + [
doc[0] for doc in agent.generate_responses(text)
]

if statements and len(statements) > 0:
prompts = self.prompts["rag_agent_prompts"]
self.current_statements = statements
for i, prompt in enumerate(prompts):
# if last prompt, format the statements into the prompt
Expand Down Expand Up @@ -317,7 +324,6 @@ def __init__(
prompts: dict,
correct: bool = True,
split_correction: bool = False,
rag_agent: DocumentEmbedder = None,
):
"""
Expand All @@ -335,10 +341,9 @@ def __init__(
prompts=prompts,
correct=correct,
split_correction=split_correction,
rag_agent=rag_agent,
)

def query(self, text: str, collection_name: Optional[str] = None):
def query(self, text: str):
"""
Return the entire message history as a single string. This is the
message that is sent to the wasm model.
Expand All @@ -355,9 +360,7 @@ def query(self, text: str, collection_name: Optional[str] = None):
"""
self.append_user_message(text)

if self.rag_agent:
if self.rag_agent.use_prompt:
self._inject_context(text, collection_name)
self._inject_context(text)

return (self._primary_query(), None, None)

Expand Down Expand Up @@ -389,7 +392,6 @@ def __init__(
model_name: str = "auto",
correct: bool = True,
split_correction: bool = False,
rag_agent: DocumentEmbedder = None,
):
"""
Expand All @@ -415,17 +417,14 @@ def __init__(
splitting the output into sentences and correcting each sentence
individually.
rag_agent (DocumentEmbedder): A RAG agent to use for retrieval
augmented generation.
"""
from xinference.client import Client

super().__init__(
model_name=model_name,
prompts=prompts,
correct=correct,
split_correction=split_correction,
rag_agent=rag_agent,
)
self.client = Client(base_url=base_url)

Expand Down Expand Up @@ -737,7 +736,6 @@ def __init__(
prompts: dict,
correct: bool = True,
split_correction: bool = False,
rag_agent: DocumentEmbedder = None,
):
"""
Connect to OpenAI's GPT API and set up a conversation with the user.
Expand All @@ -752,16 +750,12 @@ def __init__(
split_correction (bool): Whether to correct the model output by
splitting the output into sentences and correcting each
sentence individually.
rag_agent (DocumentEmbedder): A RAG agent to use for
retrieval augmented generation (RAG).
"""
super().__init__(
model_name=model_name,
prompts=prompts,
correct=correct,
split_correction=split_correction,
rag_agent=rag_agent,
)

self.ca_model_name = "gpt-3.5-turbo"
Expand Down Expand Up @@ -903,7 +897,6 @@ def __init__(
prompts: dict,
correct: bool = True,
split_correction: bool = False,
rag_agent: DocumentEmbedder = None,
version: Optional[str] = None,
base_url: Optional[str] = None,
):
Expand All @@ -923,9 +916,6 @@ def __init__(
splitting the output into sentences and correcting each
sentence individually.
rag_agent (DocumentEmbedder): A vector database connection to use for
retrieval augmented generation (RAG).
version (str): The version of the Azure API to use.
base_url (str): The base URL of the Azure API to use.
Expand All @@ -935,7 +925,6 @@ def __init__(
prompts=prompts,
correct=correct,
split_correction=split_correction,
rag_agent=rag_agent,
)

self.version = version
Expand Down Expand Up @@ -994,13 +983,14 @@ def __init__(
model_name: str,
prompts: dict,
split_correction: bool,
rag_agent: DocumentEmbedder = None,
):
"""
DEPRECATED: Superseded by XinferenceConversation.
"""
super().__init__(
model_name=model_name,
prompts=prompts,
split_correction=split_correction,
rag_agent=rag_agent,
)

self.messages = []
Expand Down
Loading

0 comments on commit 13945c2

Please sign in to comment.