# doc-qa.py: Document Q&A with RAG (Streamlit + ChromaDB)
import hashlib

import streamlit as st
import chromadb
from PyPDF2 import PdfReader
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.embeddings import OllamaEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.chat_models import ChatOllama
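
# The imports above assume these distributions are installed (the usual PyPI
# package names; pin versions to taste):
#   pip install streamlit chromadb PyPDF2 langchain langchain-community \
#       langchain-google-genai langchain-openai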
# Streamlit App Setup
st.title("Document Q&A with RAG Setup")
st.write("Upload PDF or TXT files to create embeddings and ask questions.")
# Ask user to choose between Google Generative AI, OpenAI, or Local LLM
model_choice = st.selectbox("Choose an AI model:", ["Google Generative AI", "OpenAI", "Local LLM (Ollama)"])
# API Key inputs based on model choice
chat_model = None
embeddings_model = None
if model_choice == "Google Generative AI":
    api_key = st.text_input("Enter your Google Generative AI API key:", type="password")
    if not api_key:
        st.warning("Please enter your Google Generative AI API key to proceed.")
        st.stop()
    embeddings_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    chat_model = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", google_api_key=api_key, temperature=0.8)
    print("Using Google Generative AI model")
elif model_choice == "OpenAI":
    api_key = st.text_input("Enter your OpenAI API key:", type="password")
    if not api_key:
        st.warning("Please enter your OpenAI API key to proceed.")
        st.stop()
    embeddings_model = OpenAIEmbeddings(api_key=api_key, model="text-embedding-ada-002")
    chat_model = ChatOpenAI(api_key=api_key, model_name="gpt-3.5-turbo", temperature=0.8, max_tokens=None, timeout=None, max_retries=2)
    print("Using OpenAI model")
elif model_choice == "Local LLM (Ollama)":
    embeddings_model = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
    chat_model = ChatOllama(model="mistral")
    print("Using Local LLM (Ollama) model")

# Set up ChromaDB client if not already initialized
if 'chroma_client' not in st.session_state:
    try:
        print("Initializing ChromaDB client")
        st.session_state['chroma_client'] = chromadb.Client(settings=chromadb.config.Settings(persist_directory="./chroma_db", anonymized_telemetry=False))
    except ValueError as e:
        st.error(f"An error occurred while setting up ChromaDB client: {e}")
        st.stop()
chroma_client = st.session_state['chroma_client']
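# Caveat: on chromadb >= 0.4, a Settings(persist_directory=...) client is
# still in-memory unless is_persistent is set; chromadb.PersistentClient(
# path="./chroma_db") is the documented way to get an on-disk store if
# persistence actually matters here.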

# Set up ChromaDB collection
if 'chroma_collection' not in st.session_state:
    try:
        print("Creating ChromaDB collection: document_embeddings")
        st.session_state['chroma_collection'] = chroma_client.create_collection(name="document_embeddings")
    except Exception as e:
        if 'already exists' in str(e).lower():
            print("ChromaDB collection already exists. Retrieving collection: document_embeddings")
            st.session_state['chroma_collection'] = chroma_client.get_collection(name="document_embeddings")
        else:
            st.error(f"An error occurred while creating or accessing the collection: {e}")
            st.stop()
chroma_collection = st.session_state['chroma_collection']
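# chroma_client.get_or_create_collection(name="document_embeddings") would
# replace the try/except dance above with a single idempotent call.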
uploaded_files = st.file_uploader("Upload your PDF or TXT files", type=["pdf", "txt"], accept_multiple_files=True)
# Helper Functions
def extract_text_from_pdf(file):
    print(f"Extracting text from PDF: {file.name}")
    pdf_reader = PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        try:
            page_text = page.extract_text()
            if page_text:
                text += page_text
            else:
                st.warning(f"Warning: Could not extract text from one of the pages in {file.name}")
        except Exception as e:
            st.warning(f"Warning: An error occurred while extracting text from a page in {file.name}: {e}")
    print(f"Extracted text length from {file.name}: {len(text)} characters")
    return text
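
# PyPDF2's maintainers have folded the project back into `pypdf`; with the
# newer package the import would become `from pypdf import PdfReader` and
# this function would work unchanged.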

def extract_text_from_txt(file):
    print(f"Extracting text from TXT: {file.name}")
    text = file.read().decode("utf-8")
    print(f"Extracted text length from {file.name}: {len(text)} characters")
    return text

def create_embeddings(content):
    print("Creating embeddings for content")
    # Use a stable digest as the ID: Python's built-in hash() is randomized
    # per process, so it would mint a new ID for the same file on every run.
    embedding_id = hashlib.sha256(content.encode("utf-8")).hexdigest()
    embedding = embeddings_model.embed_documents([content])[0]
    # upsert() keeps re-uploads idempotent; Streamlit reruns the script on
    # every interaction, and add() complains about duplicate IDs.
    chroma_collection.upsert(ids=[embedding_id], documents=[content], embeddings=[embedding])
    print("Embeddings created and added to ChromaDB collection")

def retrieve_relevant_context(query):
    print(f"Retrieving relevant context for query: {query}")
    query_embedding = embeddings_model.embed_query(query)
    results = chroma_collection.query(query_embeddings=[query_embedding], n_results=5)
    # results['documents'] is a list of lists (one inner list per query
    # embedding), so index [0] yields every match for this single query
    # rather than just the first document.
    documents = results['documents'][0] if results['documents'] else []
    print(f"Number of relevant documents retrieved: {len(documents)}")
    return documents

def answer_query_with_context(query, context):
    if model_choice in ["Google Generative AI", "OpenAI"]:
        messages = [
            SystemMessage(content="You are an assistant that answers questions based on the provided context."),
            HumanMessage(content=f"Context: {context}\nQuestion: {query}\nAnswer:")
        ]
        try:
            print("Querying model with provided context and question")
            # invoke() replaces the deprecated chat_model(messages) call style.
            response = chat_model.invoke(messages)
        except Exception as e:
            st.error(f"An error occurred while querying the model: {e}")
            return "Error: Unable to get response."
        return response.content.strip()
    elif model_choice == "Local LLM (Ollama)":
        prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
        try:
            print("Querying local LLM with provided context and question")
            response = chat_model.invoke(prompt)
        except Exception as e:
            st.error(f"An error occurred while querying the local LLM: {e}")
            return "Error: Unable to get response."
        return response.content.strip()

# Process Uploaded Files
if uploaded_files:
    for uploaded_file in uploaded_files:
        if uploaded_file.type == "application/pdf":
            text_content = extract_text_from_pdf(uploaded_file)
        elif uploaded_file.type == "text/plain":
            text_content = extract_text_from_txt(uploaded_file)
        else:
            st.error("Unsupported file format.")
            continue
        create_embeddings(text_content)
        st.success(f"Embeddings created for {uploaded_file.name}")

# Ask Questions - with Chat History
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
user_query = st.text_input("Ask a question based on the uploaded documents:")
if user_query:
    print(f"User query received: {user_query}")
    relevant_contexts = retrieve_relevant_context(user_query)
    combined_context = "\n".join(relevant_contexts)
    answer = answer_query_with_context(user_query, combined_context)
    st.session_state['chat_history'].append((user_query, answer))
    print(f"Answer generated: {answer}")

# Display Chat History
for question, response in st.session_state['chat_history']:
    st.write(f"**You:** {question}")
    st.write(f"**Assistant:** {response}")
    print(f"Chat history - Question: {question}, Response: {response}")