webui.py
import os
import time
import json
import uuid
import streamlit as st
from indexer import split_docs
from embedder import call_embed_model
from retriever import retrieve_docs
from chain_handler import setup_chain
from docs_db_handler import init_db, add_db_docs, load_docs
from session_handler import get_session_history, save_session_history
from langchain_core.runnables.history import RunnableWithMessageHistory
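
# Folders for saved sessions, uploaded documents, and the vector DB, resolved relative to this file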
current_directory = os.path.dirname(os.path.abspath(__file__))
sessions_folder = os.path.join(current_directory, "sessions")
data_folder = os.path.join(current_directory, "data")
db_path = os.path.join(current_directory, "db")
#----- SIDEBAR - FILE UPLOAD - PREVIOUS CONVERSATIONS -----#
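# Load any previously saved conversations from the sessions folder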
jsons = [f for f in os.listdir(sessions_folder) if f.endswith('.json')]
json_datas = []
for file in jsons:
    with open(os.path.join(sessions_folder, file), 'r') as json_file:
        data = json.load(json_file)
        json_datas.append(data)
with st.sidebar:
    st.title("RAG App Web UI")

    with st.container():
        st.header("File Upload")

        # Check if 'uploaded_files' is already in session state
        if 'uploaded_files' not in st.session_state:
            st.session_state.uploaded_files = None

        uploaded_files = st.file_uploader(
            "Upload your documents from here:", accept_multiple_files=True, key="file_uploader"
        )

        if uploaded_files is not None and uploaded_files != st.session_state.uploaded_files:
            st.session_state.uploaded_files = uploaded_files
            with st.spinner("Loading..."):
                # Iterate over the uploaded files and save them to the data folder
                for uploaded_file in uploaded_files:
                    file_path = os.path.join(data_folder, uploaded_file.name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())
                time.sleep(3)
            st.success("Files uploaded and saved successfully!")

    # Start a new conversation
    if st.button("New Conversation"):
        st.session_state.session_id = str(uuid.uuid4())
        session_id = st.session_state.session_id
        st.session_state.conversation = []
        st.success("New conversation started!")
#----- CHAT INTERFACE -----#
if 'session_id' not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
session_id = st.session_state.session_id

# Use session state to store conversation
if 'conversation' not in st.session_state:
    st.session_state.conversation = []
# Load documents only once
if 'docs' not in st.session_state:
    st.session_state.docs = load_docs('data')  # store loaded docs in session state
# Split and embed docs only once
if 'vectorstore' not in st.session_state:
    chunks = split_docs(st.session_state.docs)
    embeddings_model = call_embed_model("sentence-transformers/all-MiniLM-L12-v2")
    st.session_state.vectorstore = init_db(chunks, embeddings_model, 'db', embeddings_model)
    add_db_docs(st.session_state.vectorstore, 'data', 'db', embeddings_model)
# Display the conversation history
for message in st.session_state.conversation:
    with st.chat_message(message["role"]):
        st.write(message["message"])
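
# Message history object for this session, consumed by RunnableWithMessageHistory below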
chat_history = get_session_history(st.session_state.session_id)
print("prompt go")
prompt = st.chat_input("Say something")
if prompt:
print("prompt given")
with st.chat_message("human"):
st.write(prompt)
st.session_state.conversation.append({"role": "human", "message": prompt})
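
    # Retrieve the most relevant chunks for this prompt and wrap the RAG chain with message history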
    retriever = retrieve_docs(prompt, st.session_state.vectorstore, similar_docs_count=5, see_content=False)
    rag_chain = setup_chain("llama3", retriever)
    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        lambda _: chat_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
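
    # Stream the answer into a placeholder so the response renders incrementally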
with st.chat_message("ai"):
answer = ""
placeholder = st.empty()
for response_chunk in conversational_rag_chain.stream(
{"input": prompt},
config={
"configurable": {"session_id": session_id}
},
):
if 'answer' in response_chunk:
answer += response_chunk["answer"]
placeholder.write(answer)
st.session_state.conversation.append({"role": "ai", "message": answer})
save_session_history(st.session_state.session_id)