"""
Streamlit application for PDF-based Retrieval-Augmented Generation (RAG) using Ollama + LangChain.
This application allows users to upload a PDF, process it,
and then ask questions about the content using a selected language model.
"""
import logging
import os
import shutil
import tempfile
from typing import Any, List, Optional, Tuple

import ollama
import pdfplumber
import streamlit as st
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_ollama import OllamaEmbeddings
from langchain_ollama.chat_models import ChatOllama
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Force the pure-Python protobuf implementation to avoid protobuf version
# errors; it is slower than the compiled one, but that is the tradeoff.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Streamlit page configuration
st.set_page_config(
    page_title="Ollama PDF RAG Streamlit UI",
    page_icon="🎈",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)

@st.cache_resource(show_spinner=True)
def extract_model_names(
    models_info: Any,
) -> Tuple[str, ...]:
    """
    Extract model names from the provided models information.

    Args:
        models_info: Response from `ollama.list()` describing available models.

    Returns:
        Tuple[str, ...]: A tuple of model names.
    """
    logger.info("Extracting model names from models_info")
    # ollama-python changed its response type over time: older releases return
    # {"models": [{"name": ...}]}, newer ones return a typed object whose
    # entries expose `.model`. Handle both shapes defensively.
    models = (
        models_info["models"] if isinstance(models_info, dict) else models_info.models
    )
    model_names = tuple(
        m["name"] if isinstance(m, dict) else m.model for m in models
    )
    logger.info(f"Extracted model names: {model_names}")
    return model_names
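
# Example (illustrative — the literal dict shape matches older ollama-python
# releases; newer releases return an object with a `.models` attribute):
#
#   >>> extract_model_names({"models": [{"name": "llama3:latest"}]})
#   ('llama3:latest',)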

def create_vector_db(file_upload) -> Chroma:
    """
    Create a vector database from an uploaded PDF file.

    Args:
        file_upload (st.UploadedFile): Streamlit file upload object containing the PDF.

    Returns:
        Chroma: A vector store containing the processed document chunks.
    """
    logger.info(f"Creating vector DB from file upload: {file_upload.name}")
    temp_dir = tempfile.mkdtemp()

    # Persist the upload to disk first; UnstructuredPDFLoader needs a file path.
    path = os.path.join(temp_dir, file_upload.name)
    with open(path, "wb") as f:
        f.write(file_upload.getvalue())
    logger.info(f"File saved to temporary path: {path}")

    loader = UnstructuredPDFLoader(path)
    data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
    chunks = text_splitter.split_documents(data)
    logger.info("Document split into chunks")

    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    vector_db = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        collection_name="myRAG",
    )
    logger.info("Vector DB created")

    shutil.rmtree(temp_dir)
    logger.info(f"Temporary directory {temp_dir} removed")
    return vector_db
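
# Usage sketch (`uploaded_file` is a hypothetical object from
# st.file_uploader; `similarity_search` is Chroma's standard query method):
#
#   vector_db = create_vector_db(uploaded_file)
#   top_chunks = vector_db.similarity_search("What is the main topic?", k=4)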

def process_question(question: str, vector_db: Chroma, selected_model: str) -> str:
    """
    Process a user question using the vector database and selected language model.

    Args:
        question (str): The user's question.
        vector_db (Chroma): The vector database containing document embeddings.
        selected_model (str): The name of the selected language model.

    Returns:
        str: The generated response to the user's question.
    """
    logger.info(f"Processing question: {question} using model: {selected_model}")

    # Initialize LLM
    llm = ChatOllama(model=selected_model)

    # Prompt used by MultiQueryRetriever to rephrase the user's question,
    # broadening recall beyond a single embedding lookup
    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate 2
        different versions of the given user question to retrieve relevant documents from
        a vector database. By generating multiple perspectives on the user question, your
        goal is to help the user overcome some of the limitations of the distance-based
        similarity search. Provide these alternative questions separated by newlines.
        Original question: {question}""",
    )

    # Set up retriever
    retriever = MultiQueryRetriever.from_llm(
        vector_db.as_retriever(),
        llm,
        prompt=QUERY_PROMPT,
    )

    # RAG prompt template
    template = """Answer the question based ONLY on the following context:
    {context}
    Question: {question}
    """

    prompt = ChatPromptTemplate.from_template(template)

    # Create chain: retrieve context, fill the prompt, call the LLM, parse text
    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    response = chain.invoke(question)
    logger.info("Question processed and response generated")
    return response
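
# Usage sketch (the model name is an assumption — pass any name returned by
# `extract_model_names`):
#
#   answer = process_question(
#       "Summarize the key findings.", vector_db, "llama3:latest"
#   )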

@st.cache_data
def extract_all_pages_as_images(file_upload) -> List[Any]:
    """
    Extract all pages from a PDF file as images.

    Args:
        file_upload (st.UploadedFile): Streamlit file upload object containing the PDF.

    Returns:
        List[Any]: A list of image objects representing each page of the PDF.
    """
    logger.info(f"Extracting all pages as images from file: {file_upload.name}")
    with pdfplumber.open(file_upload) as pdf:
        pdf_pages = [page.to_image().original for page in pdf.pages]
    logger.info("PDF pages extracted as images")
    return pdf_pages
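
# Note: `page.to_image()` rasterizes at pdfplumber's default resolution;
# passing e.g. `page.to_image(resolution=150)` gives sharper previews at the
# cost of memory, since every page is held in memory as a PIL image.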

def delete_vector_db(vector_db: Optional[Chroma]) -> None:
    """
    Delete the vector database and clear related session state.

    Args:
        vector_db (Optional[Chroma]): The vector database to be deleted.
    """
    logger.info("Deleting vector DB")
    if vector_db is not None:
        vector_db.delete_collection()
        st.session_state.pop("pdf_pages", None)
        st.session_state.pop("file_upload", None)
        st.session_state.pop("vector_db", None)
        st.success("Collection and temporary files deleted successfully.")
        logger.info("Vector DB and related session state cleared")
        st.rerun()
    else:
        st.error("No vector database found to delete.")
        logger.warning("Attempted to delete vector DB, but none was found")

def main() -> None:
    """
    Main function to run the Streamlit application.
    """
    st.subheader("🧠 Ollama PDF RAG playground", divider="gray", anchor=False)

    # Get available models
    models_info = ollama.list()
    available_models = extract_model_names(models_info)

    # Create layout
    col1, col2 = st.columns([1.5, 2])

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "vector_db" not in st.session_state:
        st.session_state["vector_db"] = None
    if "use_sample" not in st.session_state:
        st.session_state["use_sample"] = False
    # Model selection
    if available_models:
        selected_model = col2.selectbox(
            "Pick a model available locally on your system ↓",
            available_models,
            key="model_select",
        )
    else:
        # Without this branch, `selected_model` would be unbound below and the
        # first chat prompt would raise a NameError.
        selected_model = None
        col2.warning("No Ollama models found. Pull one first, e.g. `ollama pull llama3`.")

    # Toggle between the bundled sample PDF and a user upload
    use_sample = col1.toggle(
        "Use sample PDF (Scammer Agent Paper)",
        key="sample_checkbox",
    )

    # Clear vector DB if switching between sample and upload
    if use_sample != st.session_state.get("use_sample"):
        if st.session_state["vector_db"] is not None:
            st.session_state["vector_db"].delete_collection()
            st.session_state["vector_db"] = None
            st.session_state["pdf_pages"] = None
        st.session_state["use_sample"] = use_sample

    if use_sample:
        # Use the sample PDF
        sample_path = "scammer-agent.pdf"
        if os.path.exists(sample_path):
            if st.session_state["vector_db"] is None:
                with st.spinner("Processing sample PDF..."):
                    loader = UnstructuredPDFLoader(file_path=sample_path)
                    data = loader.load()
                    text_splitter = RecursiveCharacterTextSplitter(
                        chunk_size=7500, chunk_overlap=100
                    )
                    chunks = text_splitter.split_documents(data)
                    st.session_state["vector_db"] = Chroma.from_documents(
                        documents=chunks,
                        embedding=OllamaEmbeddings(model="nomic-embed-text"),
                        collection_name="myRAG",
                    )

                    # Open and display the sample PDF
                    with pdfplumber.open(sample_path) as pdf:
                        pdf_pages = [page.to_image().original for page in pdf.pages]
                        st.session_state["pdf_pages"] = pdf_pages
        else:
            st.error("Sample PDF file not found in the current directory.")
    else:
        # Regular file upload with unique key
        file_upload = col1.file_uploader(
            "Upload a PDF file ↓",
            type="pdf",
            accept_multiple_files=False,
            key="pdf_uploader",
        )

        if file_upload:
            if st.session_state["vector_db"] is None:
                with st.spinner("Processing uploaded PDF..."):
                    st.session_state["vector_db"] = create_vector_db(file_upload)
                    pdf_pages = extract_all_pages_as_images(file_upload)
                    st.session_state["pdf_pages"] = pdf_pages

    # Display PDF if pages are available
    if "pdf_pages" in st.session_state and st.session_state["pdf_pages"]:
        # PDF display controls
        zoom_level = col1.slider(
            "Zoom Level",
            min_value=100,
            max_value=1000,
            value=700,
            step=50,
            key="zoom_slider",
        )

        # Display PDF pages
        with col1:
            with st.container(height=410, border=True):
                for page_image in st.session_state["pdf_pages"]:
                    st.image(page_image, width=zoom_level)

    # Delete collection button
    delete_collection = col1.button(
        "⚠️ Delete collection",
        type="secondary",
        key="delete_button",
    )

    if delete_collection:
        delete_vector_db(st.session_state["vector_db"])

    # Chat interface
    with col2:
        message_container = st.container(height=500, border=True)

        # Display chat history
        for message in st.session_state["messages"]:
            avatar = "🤖" if message["role"] == "assistant" else "😎"
            with message_container.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])

        # Chat input and processing
        if prompt := st.chat_input("Enter a prompt here...", key="chat_input"):
            try:
                # Add user message to chat
                st.session_state["messages"].append({"role": "user", "content": prompt})
                with message_container.chat_message("user", avatar="😎"):
                    st.markdown(prompt)

                # Process and display assistant response
                with message_container.chat_message("assistant", avatar="🤖"):
                    with st.spinner(":green[processing...]"):
                        if st.session_state["vector_db"] is not None:
                            response = process_question(
                                prompt, st.session_state["vector_db"], selected_model
                            )
                            st.markdown(response)
                        else:
                            st.warning("Please upload a PDF file first.")

                # Add assistant response to chat history
                if st.session_state["vector_db"] is not None:
                    st.session_state["messages"].append(
                        {"role": "assistant", "content": response}
                    )
            except Exception as e:
                st.error(e, icon="⛔️")
                logger.error(f"Error processing prompt: {e}")
        else:
            if st.session_state["vector_db"] is None:
                st.warning("Upload a PDF file or use the sample PDF to begin chat...")


if __name__ == "__main__":
    main()