# obsidian_rag.py
import os
import shutil
import argparse
from typing import List

from langchain import hub
from langchain.chains import RetrievalQA
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import ObsidianLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import gradio as gr  # used only by the commented-out Gradio demo below
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="RAG over an Obsidian vault via Ollama")
    parser.add_argument('--notes_dir', help='Path to the Obsidian notes directory')
    parser.add_argument('--filepath', help='Alternative notes path, used if --notes_dir is not given')
    parser.add_argument('--vectorize', default=False, action=argparse.BooleanOptionalAction,
                        help='(Re)build the vectorstore from the notes before answering')
    return parser.parse_args()
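
# Example invocations (assuming an Ollama server is running with the `mistral`
# model pulled; the vault path below is illustrative, not part of this repo):
#   python obsidian_rag.py --notes_dir ~/ObsidianVault --vectorize   # (re)build the index, then chat
#   python obsidian_rag.py --notes_dir ~/ObsidianVault               # reuse the persisted vectorstore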
def format_docs(docs: List[Document]) -> str:
    """Join retrieved documents into one prompt-ready string (currently unused)."""
    return "\n\n".join(doc.page_content for doc in docs)
def remove_all_files_in_folder(directory: str) -> None:
    """Empty `directory`, recreating it afterwards (portable replacement for `rm -rf`)."""
    shutil.rmtree(directory, ignore_errors=True)
    os.makedirs(directory, exist_ok=True)
def main(question: str) -> str:
    args = get_args()
    notes_dir = args.notes_dir or args.filepath  # fall back to --filepath if --notes_dir is not provided
    if args.vectorize:
        loader = ObsidianLoader(path=notes_dir)
        data: List[Document] = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        all_splits: List[Document] = text_splitter.split_documents(data)
        # Hard reset: clear any stale index so old embeddings don't bleed into the new one
        remove_all_files_in_folder("vectorstore")
        vectorstore = Chroma.from_documents(
            documents=all_splits,
            embedding=OllamaEmbeddings(model='mistral'),
            persist_directory="vectorstore",
        )
        print('Vectorized!')
    else:
        vectorstore = Chroma(embedding_function=OllamaEmbeddings(model='mistral'), persist_directory="vectorstore")
        print('Loaded vectorstore!')
    rag_prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOllama(model="mistral", callbacks=[StreamingStdOutCallbackHandler()])
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
        chain_type_kwargs={"prompt": rag_prompt},
    )
    return qa_chain.invoke({"query": question})["result"]
# TODO: separate out a loading function for the local vectorstore
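
# One possible shape for that refactor (a sketch, not wired in yet; the helper
# name `load_vectorstore` is ours, not part of the current script):
def load_vectorstore(persist_directory: str = "vectorstore") -> Chroma:
    """Open the persisted Chroma index with the same embedding model used to build it."""
    return Chroma(
        embedding_function=OllamaEmbeddings(model='mistral'),
        persist_directory=persist_directory,
    )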
# demo = gr.Interface(fn=main, inputs="text", outputs="text")
if __name__ == "__main__":
    while True:
        question = input("Enter your question (or 'quit' to exit): ")
        if question.lower() == 'quit':
            break
        answer = main(question)
        print("\nAnswer:", answer)
        print("\n" + "-" * 50 + "\n")
    # demo.launch(show_api=False)