refactor: separated history responsibility from the conversation handler
umbertogriffo committed Jan 27, 2025
1 parent f78a471 commit b68d440
Showing 5 changed files with 59 additions and 62 deletions.
10 changes: 10 additions & 0 deletions chatbot/bot/conversation/chat_history.py
@@ -22,3 +22,13 @@ def append(self, msg: str):
         if len(self) == self.total_length:
             self.pop(0)
         super().append(msg)
+
+    def __str__(self):
+        """
+        Get the chat history as a single string.
+
+        Returns:
+            str: The chat history concatenated into a single string, with each message separated by a newline.
+        """
+        chat_history = "\n".join([msg for msg in self])
+        return chat_history
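With this addition, rendering a `ChatHistory` no longer requires a helper on the conversation class: `str(history)` yields the newline-joined messages. A minimal usage sketch (assuming, as the diff suggests, that `ChatHistory` subclasses `list` and evicts the oldest message once `total_length` is reached):

```python
from bot.conversation.chat_history import ChatHistory

history = ChatHistory(total_length=2)  # keeps at most the 2 most recent messages
history.append("question: Hi, answer: Hello!")
history.append("question: What is RAG?, answer: Retrieval-Augmented Generation.")
history.append("question: Thanks!, answer: Anytime.")  # evicts the oldest message

print(str(history))
# question: What is RAG?, answer: Retrieval-Augmented Generation.
# question: Thanks!, answer: Anytime.
```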
53 changes: 17 additions & 36 deletions chatbot/bot/conversation/{conversation_retrieval.py → conversation_handler.py}

@@ -11,66 +11,43 @@
 logger = get_logger(__name__)
 
 
-class ConversationRetrieval:
+class ConversationHandler:
     """
-    A class for managing conversation retrieval using a language model.
+    A class for managing a conversation using a large language model.
 
     Attributes:
         llm (LlmClient): The language model client for conversation-related tasks.
-        chat_history (List[Tuple[str, str]]): A list to store the conversation
-            history as tuples of questions and answers.
     """
 
-    def __init__(self, llm: LamaCppClient, chat_history: ChatHistory) -> None:
+    def __init__(self, llm: LamaCppClient) -> None:
        """
-        Initializes a new instance of the ConversationRetrieval class.
+        Initializes a new instance of the ConversationHandler class.
 
         Args:
             llm (LlmClient): The language model client for conversation-related tasks.
-            chat_history (ChatHistory): The chat history object to store conversation history.
         """
         self.llm = llm
-        self.chat_history = chat_history
-
-    def get_chat_history(self) -> str:
-        """
-        Retrieves the chat history as a single string.
-
-        Returns:
-            str: The chat history concatenated into a single string, with each message separated by a newline.
-        """
-        chat_history = "\n".join([msg for msg in self.chat_history])
-        return chat_history
-
-    def append_chat_history(self, question: str, answer: str) -> None:
-        """
-        Append a new question and answer to the chat history.
-
-        Args:
-            question (str): The question to add to the chat history.
-            answer (str): The answer to add to the chat history.
-        """
-        self.chat_history.append(f"question: {question}, answer: {answer}")
 
-    def refine_question(self, question: str, max_new_tokens: int = 128) -> str:
+    def refine_question(self, question: str, chat_history: ChatHistory, max_new_tokens: int = 128) -> str:
         """
         Refines the given question based on the chat history.
 
         Args:
             question (str): The original question.
+            chat_history (List[Tuple[str, str]]): A list to store the conversation
+                history as tuples of questions and answers.
             max_new_tokens (int, optional): The maximum number of tokens to generate in the answer.
                 Defaults to 128.
 
         Returns:
             str: The refined question.
         """
-        chat_history = self.get_chat_history()
 
         if chat_history:
             logger.info("--- Refining the question based on the chat history... ---")
 
             conversation_awareness_prompt = self.llm.generate_refined_question_conversation_awareness_prompt(
-                question, chat_history
+                question, str(chat_history)
             )
 
             logger.info(f"--- Prompt:\n {conversation_awareness_prompt} \n---")
@@ -83,12 +60,14 @@ def refine_question(self, question: str, max_new_tokens: int = 128) -> str:
         else:
             return question
 
-    def answer(self, question: str, max_new_tokens: int = 512) -> Any:
+    def answer(self, question: str, chat_history: ChatHistory, max_new_tokens: int = 512) -> Any:
         """
         Generates an answer to the given question based on the chat history or a direct prompt.
 
         Args:
             question (str): The input question for which an answer is generated.
+            chat_history (List[Tuple[str, str]]): A list to store the conversation
+                history as tuples of questions and answers.
             max_new_tokens (int, optional): The maximum number of tokens to generate in the answer.
                 Defaults to 512.
@@ -102,13 +81,12 @@ def answer(self, question: str, max_new_tokens: int = 512) -> Any:
         If no chat history is available, a prompt is generated directly from the input question,
         and the answer is generated accordingly.
         """
-        chat_history = self.get_chat_history()
 
         if chat_history:
             logger.info("--- Answer the question based on the chat history... ---")
 
             conversation_awareness_prompt = self.llm.generate_refined_answer_conversation_awareness_prompt(
-                question, chat_history
+                question, str(chat_history)
             )
 
             logger.debug(f"--- Prompt:\n {conversation_awareness_prompt} \n---")
@@ -128,6 +106,7 @@ def context_aware_answer(
         self,
         ctx_synthesis_strategy: BaseSynthesisStrategy,
         question: str,
+        chat_history: ChatHistory,
         retrieved_contents: list[Document],
         max_new_tokens: int = 512,
     ):
@@ -137,14 +116,16 @@ def context_aware_answer(
         Args:
             ctx_synthesis_strategy (BaseSynthesisStrategy): The strategy to use for context synthesis.
             question (str): The input question for which an answer is generated.
+            chat_history (List[Tuple[str, str]]): A list to store the conversation
+                history as tuples of questions and answers.
             retrieved_contents (list[Document]): A list of documents retrieved for context.
             max_new_tokens (int, optional): The maximum number of tokens to generate in the answer. Defaults to 512.
 
         Returns:
             tuple: A tuple containing the answer streamer and formatted prompts.
         """
         if not retrieved_contents:
-            return self.answer(question, max_new_tokens=max_new_tokens), []
+            return self.answer(question, chat_history, max_new_tokens=max_new_tokens), []
 
         if isinstance(ctx_synthesis_strategy, AsyncTreeSummarizationStrategy):
             loop = get_event_loop()
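After this refactor, `ConversationHandler` is stateless apart from its `llm` client, and the `ChatHistory` travels through every call. A sketch of the new calling convention, pieced together from this file and the call sites below (the helper name `chat_once` and the elided `llm` construction are ours, not part of the commit):

```python
from bot.conversation.chat_history import ChatHistory
from bot.conversation.conversation_handler import ConversationHandler

def chat_once(llm, question: str, chat_history: ChatHistory) -> str:
    handler = ConversationHandler(llm)  # holds only the llm client, no history

    # History is now an explicit argument instead of handler state.
    refined = handler.refine_question(question, chat_history, max_new_tokens=128)

    answer = ""
    for token in handler.answer(refined, chat_history, max_new_tokens=512):
        answer += llm.parse_token(token)

    # Callers persist the turn themselves, as the apps below now do.
    chat_history.append(f"question: {refined}, answer: {answer}")
    return answer
```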
22 changes: 13 additions & 9 deletions chatbot/chatbot_app.py
@@ -6,7 +6,7 @@
 import streamlit as st
 from bot.client.lama_cpp_client import LamaCppClient
 from bot.conversation.chat_history import ChatHistory
-from bot.conversation.conversation_retrieval import ConversationRetrieval
+from bot.conversation.conversation_handler import ConversationHandler
 from bot.model.model_registry import get_model_settings, get_models
 from helpers.log import get_logger

@@ -30,8 +30,10 @@ def init_chat_history(total_length: int = 2) -> ChatHistory:


 @st.cache_resource()
-def load_conversational_retrieval(_llm: LamaCppClient, _chat_history: ChatHistory) -> ConversationRetrieval:
-    conversation_retrieval = ConversationRetrieval(_llm, _chat_history)
+def load_conversational_retrieval(_llm: LamaCppClient) -> ConversationHandler:
+    conversation_retrieval = ConversationHandler(
+        _llm,
+    )
     return conversation_retrieval


@@ -59,14 +61,14 @@ def init_welcome_message() -> None:
     st.write("How can I help you today?")
 
 
-def reset_chat_history(conversational_retrieval: ConversationRetrieval) -> None:
+def reset_chat_history(chat_history: ChatHistory) -> None:
     """
     Initializes the chat history, allowing users to clear the conversation.
     """
     clear_button = st.sidebar.button("🗑️ Clear Conversation", key="clear")
     if clear_button or "messages" not in st.session_state:
         st.session_state.messages = []
-        conversational_retrieval.chat_history.clear()
+        chat_history.clear()
 
 
 def display_messages_from_history():
@@ -87,8 +89,8 @@ def main(parameters) -> None:
     init_page(root_folder)
     llm = load_llm(model, model_folder)
     chat_history = init_chat_history(2)
-    conversational_retrieval = load_conversational_retrieval(_llm=llm, _chat_history=chat_history)
-    reset_chat_history(conversational_retrieval)
+    conversational_retrieval = load_conversational_retrieval(_llm=llm)
+    reset_chat_history(chat_history)
     init_welcome_message()
     display_messages_from_history()

@@ -106,13 +108,15 @@ def main(parameters) -> None:
         with st.chat_message("assistant"):
             message_placeholder = st.empty()
             full_response = ""
-            for token in conversational_retrieval.answer(question=user_input, max_new_tokens=max_new_tokens):
+            for token in conversational_retrieval.answer(
+                question=user_input, chat_history=chat_history, max_new_tokens=max_new_tokens
+            ):
                 full_response += llm.parse_token(token)
                 message_placeholder.markdown(full_response + "▌")
             message_placeholder.markdown(full_response)
 
         # Add assistant response to chat history
-        conversational_retrieval.append_chat_history(user_input, full_response)
+        chat_history.append(f"question: {user_input}, answer: {full_response}")
         st.session_state.messages.append({"role": "assistant", "content": full_response})
 
         took = time.time() - start_time
14 changes: 8 additions & 6 deletions chatbot/cli/rag_chatbot.py
@@ -5,7 +5,7 @@

 from bot.client.lama_cpp_client import LamaCppClient
 from bot.conversation.chat_history import ChatHistory
-from bot.conversation.conversation_retrieval import ConversationRetrieval
+from bot.conversation.conversation_handler import ConversationHandler
 from bot.conversation.ctx_strategy import get_ctx_synthesis_strategies, get_ctx_synthesis_strategy
 from bot.memory.embedder import Embedder
 from bot.memory.vector_database.chroma import Chroma
@@ -70,7 +70,7 @@ def get_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-def loop(conversation, synthesis_strategy, index, parameters) -> None:
+def loop(conversation, chat_history, synthesis_strategy, index, parameters) -> None:
     custom_fig = Figlet(font="graffiti")
     console = Console(color_system="windows")
     console.print(custom_fig.renderText("ChatBot"))
@@ -87,7 +87,7 @@ def loop(conversation, synthesis_strategy, index, parameters) -> None:
         if question.lower() == "exit":
             break
 
-        logger.info(f"--- Question: {question}, Chat_history: {conversation.get_chat_history()} ---")
+        logger.info(f"--- Question: {question}, Chat_history: {chat_history} ---")
 
         start_time = time.time()
         refined_question = conversation.refine_question(question)
@@ -112,7 +112,9 @@ def loop(conversation, synthesis_strategy, index, parameters) -> None:
                 answer += parsed_token
                 print(parsed_token, end="", flush=True)
 
-        conversation.append_chat_history(refined_question, answer)
+        chat_history.append(
+            f"question: {refined_question}, answer: {answer}",
+        )
 
         console.print("\n[bold magenta]Formatted Answer:[/bold magenta]")
         if answer:
@@ -134,12 +136,12 @@ def main(parameters):

     synthesis_strategy = get_ctx_synthesis_strategy(parameters.synthesis_strategy, llm=llm)
     chat_history = ChatHistory(total_length=2)
-    conversation = ConversationRetrieval(llm, chat_history)
+    conversation = ConversationHandler(llm)
 
     embedding = Embedder()
     index = Chroma(persist_directory=str(vector_store_path), embedding=embedding)
 
-    loop(conversation, synthesis_strategy, index, parameters)
+    loop(conversation, chat_history, synthesis_strategy, index, parameters)
 
 
 if __name__ == "__main__":
22 changes: 11 additions & 11 deletions chatbot/rag_chatbot_app.py
@@ -6,7 +6,7 @@
 import streamlit as st
 from bot.client.lama_cpp_client import LamaCppClient
 from bot.conversation.chat_history import ChatHistory
-from bot.conversation.conversation_retrieval import ConversationRetrieval
+from bot.conversation.conversation_handler import ConversationHandler
 from bot.conversation.ctx_strategy import (
     BaseSynthesisStrategy,
     get_ctx_synthesis_strategies,
@@ -36,8 +36,8 @@ def init_chat_history(total_length: int = 2) -> ChatHistory:


 @st.cache_resource()
-def load_conversational_retrieval(_llm: LamaCppClient, _chat_history: ChatHistory) -> ConversationRetrieval:
-    conversation_retrieval = ConversationRetrieval(_llm, _chat_history)
+def load_conversational_retrieval(_llm: LamaCppClient) -> ConversationHandler:
+    conversation_retrieval = ConversationHandler(_llm)
     return conversation_retrieval


@@ -93,14 +93,14 @@ def init_welcome_message() -> None:
     st.write("How can I help you today?")
 
 
-def reset_chat_history(conversational_retrieval: ConversationRetrieval) -> None:
+def reset_chat_history(chat_history: ChatHistory) -> None:
     """
     Initializes the chat history, allowing users to clear the conversation.
     """
-    clear_button = st.sidebar.button("Clear Conversation", key="clear")
+    clear_button = st.sidebar.button("🗑️ Clear Conversation", key="clear")
     if clear_button or "messages" not in st.session_state:
         st.session_state.messages = []
-        conversational_retrieval.chat_history.clear()
+        chat_history.clear()
 
 
 def display_messages_from_history():
@@ -130,10 +130,10 @@ def main(parameters) -> None:
     init_page(root_folder)
     llm = load_llm_client(model_folder, model_name)
     chat_history = init_chat_history(2)
-    conversational_retrieval = load_conversational_retrieval(_llm=llm, _chat_history=chat_history)
+    conversational_retrieval = load_conversational_retrieval(_llm=llm)
     ctx_synthesis_strategy = load_ctx_synthesis_strategy(synthesis_strategy_name, _llm=llm)
     index = load_index(vector_store_path)
-    reset_chat_history(conversational_retrieval)
+    reset_chat_history(chat_history)
     init_welcome_message()
     display_messages_from_history()

@@ -153,7 +153,7 @@ def main(parameters) -> None:
             with st.spinner(
                 text="Refining the question and Retrieving the docs – hang tight! " "This should take seconds."
             ):
-                refined_user_input = conversational_retrieval.refine_question(user_input)
+                refined_user_input = conversational_retrieval.refine_question(user_input, chat_history=chat_history)
                 retrieved_contents, sources = index.similarity_search_with_threshold(
                     query=refined_user_input, k=parameters.k
                 )
@@ -179,15 +179,15 @@ def main(parameters) -> None:
full_response = ""
with st.spinner(text="Refining the context and Generating the answer for each text chunk – hang tight! "):
streamer, fmt_prompts = conversational_retrieval.context_aware_answer(
ctx_synthesis_strategy, user_input, retrieved_contents
ctx_synthesis_strategy, user_input, chat_history, retrieved_contents
)
for token in streamer:
full_response += llm.parse_token(token)
message_placeholder.markdown(full_response + "▌")

message_placeholder.markdown(full_response)

conversational_retrieval.append_chat_history(user_input, full_response)
chat_history.append(f"question: {user_input}, answer: {full_response}")
# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": full_response})
took = time.time() - start_time
Expand Down
