-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrag_chain.py
35 lines (29 loc) · 1.75 KB
/
rag_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableMap
from langchain_openai import ChatOpenAI,AzureChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import os
def get_expression_chain(
    retriever
) -> Runnable:
    """Assemble a retrieval-augmented Q&A pipeline with LangChain Expression Language.

    The returned runnable feeds its input to ``retriever`` and, in parallel,
    passes the same input through unchanged as the query. The retrieved
    documents are then joined into one text block, inserted into the prompt,
    sent to an Azure-hosted chat model, and parsed to a plain string.

    The chat model is configured from the environment variables
    ``AZURE_END_POINT``, ``AZURE_OPENAI_GPT_MODEL_NAME`` and
    ``AZURE_OPENAI_API_KEY``, which are read when this function is called.

    Args:
        retriever: Step that produces the context documents for a query.
            Each item it yields must expose a ``page_content`` attribute.

    Returns:
        A runnable whose output mapping holds ``context_str`` (the retriever
        output), ``query_str`` (the original input) and ``answer`` (the
        model's reply as a string).
    """

    def _render_context(inputs):
        # Collapse the retrieved documents into one blank-line separated text.
        return "\n\n".join(doc.page_content for doc in inputs["context_str"])

    system_message = (
        "You are an expert Q&A system that is trusted around the world.\n"
        "Always answer the query using the provided context information, and not prior knowledge.\n"
        "Some rules to follow:\n"
        "1. Never directly reference the given context in your answer.\n"
        "2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."
    )
    human_message = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, answer the query.\n"
        "Query: {query_str}\n"
        "Answer: "
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("human", human_message),
        ]
    )

    chat_model = AzureChatOpenAI(
        azure_endpoint=os.getenv('AZURE_END_POINT'),
        model=os.getenv('AZURE_OPENAI_GPT_MODEL_NAME'),
        temperature=0,
        api_version="2024-02-01",
        api_key=os.getenv('AZURE_OPENAI_API_KEY'),
    )

    # Answering stage: swap the raw documents for their joined text, then
    # prompt -> model -> plain string.
    answer_stage = (
        RunnablePassthrough.assign(context_str=_render_context)
        | qa_prompt
        | chat_model
        | StrOutputParser()
    )

    # Retrieval stage: fetch documents while keeping the query itself, so the
    # final output carries the sources alongside the answer.
    retrieval_stage = RunnableParallel(
        {"context_str": retriever, "query_str": RunnablePassthrough()}
    )
    return retrieval_stage.assign(answer=answer_stage)