forked from langchain-ai/langchainjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chat vector db chain (langchain-ai#22)
* cr * cr
- Loading branch information
Showing
13 changed files
with
298 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Chat Vector DB QA Chain | ||
|
||
A Chat Vector DB QA chain takes as input a question and chat history. | ||
It first combines the chat history and the question into a standalone question, then looks up relevant documents from the vector database, and then passes those documents and the question to a question answering chain to return a response. | ||
|
||
To create one, you will need a vector store, which can be built from your documents and an embeddings model. | ||
|
||
Below is an end-to-end example of question answering over a recent State of the Union address. | ||
|
||
```typescript | ||
import { OpenAI } from "langchain/llms"; | ||
import { ChatVectorDBQAChain } from "langchain/chains"; | ||
import { HNSWLib } from "langchain/vectorstores"; | ||
import { OpenAIEmbeddings } from "langchain/embeddings"; | ||
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; | ||
import * as fs from 'fs'; | ||
|
||
|
||
/* Initialize the LLM to use to answer the question */
const model = new OpenAI({});
/* Load in the file we want to do question answering over */
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
/* Split the text into chunks */
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = textSplitter.createDocuments([text]);
/* Create the vectorstore */
const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings());
/* Create the chain */
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
/* Ask it a question (no chat history yet, so the question is used as-is) */
const question = "What did the president say about Justice Breyer?";
const res = await chain.call({ question, chat_history: [] });
console.log(res);
/* Ask it a follow up question; the history is the previous question and answer joined into one string */
const chatHistory = question + res.text;
const followUpRes = await chain.call({
  question: "Was that nice?",
  chat_history: chatHistory,
});
console.log(followUpRes);
|
||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import { OpenAI } from "langchain/llms"; | ||
import { ChatVectorDBQAChain } from "langchain/chains"; | ||
import { HNSWLib } from "langchain/vectorstores"; | ||
import { OpenAIEmbeddings } from "langchain/embeddings"; | ||
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; | ||
import * as fs from 'fs'; | ||
|
||
/**
 * Runnable demo of ChatVectorDBQAChain: indexes `state_of_the_union.txt`
 * into an HNSWLib vector store, asks one question, then asks a follow-up
 * that relies on the first exchange as chat history. Both results are
 * printed with `console.log`.
 */
export const run = async () => {
  // LLM shared by the question-rephrasing step and the answering step.
  const llm = new OpenAI({});

  // Read the source document and cut it into chunks.
  const rawText = fs.readFileSync('state_of_the_union.txt','utf8');
  const splitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
  const chunks = splitter.createDocuments([rawText]);

  // Embed the chunks and build the vector store the chain will search.
  const embeddings = new OpenAIEmbeddings();
  const store = await HNSWLib.fromDocuments(chunks, embeddings);

  const qaChain = ChatVectorDBQAChain.fromLLM(llm, store);

  // First turn: empty history.
  const firstQuestion = "What did the president say about Justice Breyer?";
  const firstAnswer = await qaChain.call({
    question: firstQuestion,
    chat_history: [],
  });
  console.log(firstAnswer);

  // Second turn: history is the first question immediately followed by its answer.
  const history = `${firstQuestion}${firstAnswer.text}`;
  const secondAnswer = await qaChain.call({
    question: "Was that nice?",
    chat_history: history,
  });
  console.log(secondAnswer);
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import { | ||
BaseChain, | ||
ChainValues, | ||
SerializedStuffDocumentsChain, | ||
StuffDocumentsChain, | ||
SerializedLLMChain, | ||
loadQAChain, | ||
LLMChain, | ||
} from "./index"; | ||
|
||
import { PromptTemplate } from "../prompt"; | ||
|
||
import { VectorStore } from "../vectorstores/base"; | ||
import { BaseLLM } from "../llms"; | ||
|
||
import { resolveConfigFromFile } from "../util"; | ||
// Free-form bag of extra values passed as the second argument of
// `ChatVectorDBQAChain.deserialize`, which reads its `vectorstore` entry.
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
export type LoadValues = Record<string, any>; | ||
|
||
// Prompt used to rewrite a follow-up question as a standalone question.
// Template variables: {chat_history} and {question}.
const question_generator_template = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. | ||
Chat History: | ||
{chat_history} | ||
Follow Up Input: {question} | ||
Standalone question:`; | ||
// Compiled form of the template above; `fromLLM` hands it to the question-generator LLMChain.
const question_generator_prompt = PromptTemplate.fromTemplate(question_generator_template); | ||
|
||
// Prompt used for the answering step. Template variables: {context} and {question}.
// Note that it has no {chat_history} slot.
const qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. | ||
{context} | ||
Question: {question} | ||
Helpful Answer:`; | ||
// Compiled form of the template above; `fromLLM` passes it to `loadQAChain`.
const qa_prompt = PromptTemplate.fromTemplate(qa_template); | ||
|
||
|
||
// Shape of the configurable state of a ChatVectorDBQAChain (the class declares
// `implements ChatVectorDBQAChainInput`).
export interface ChatVectorDBQAChainInput { | ||
// Store that `_call` queries with `similaritySearch`.
vectorstore: VectorStore; | ||
// Number of documents requested from the vector store per question.
k: number; | ||
// Chain that receives the retrieved documents together with the question and chat history.
combineDocumentsChain: StuffDocumentsChain; | ||
// Chain invoked (only when the chat history is non-empty) to rephrase the question.
questionGeneratorChain: LLMChain; | ||
// NOTE(review): stored on the chain but not read by `_call` in the visible code — confirm intended use.
outputKey: string; | ||
// Key under which `_call` looks up the user's question.
inputKey: string; | ||
} | ||
|
||
// Plain-object form produced by `ChatVectorDBQAChain.serialize` and consumed by
// `ChatVectorDBQAChain.deserialize`. The vector store is not part of it; the
// caller supplies it separately when deserializing.
export type SerializedChatVectorDBQAChain = { | ||
// Discriminator; equals the value returned by `_chainType()`.
_type: "chat-vector-db"; | ||
k: number; | ||
combine_documents_chain: SerializedStuffDocumentsChain; | ||
// presumably an alternative on-disk location resolved by `resolveConfigFromFile` — TODO confirm
combine_documents_chain_path?: string; | ||
question_generator: SerializedLLMChain; | ||
}; | ||
|
||
/**
 * Question-answering chain over a vector store that takes prior chat
 * history into account.
 *
 * A call proceeds in three steps:
 *  1. If the chat history is non-empty, `questionGeneratorChain` rewrites the
 *     incoming question into a standalone question.
 *  2. The (possibly rewritten) question is used to fetch `k` documents from
 *     `vectorstore`.
 *  3. `combineDocumentsChain` is called with those documents, the same
 *     (possibly rewritten) question and the chat history, and its result is
 *     returned unchanged.
 */
export class ChatVectorDBQAChain
  extends BaseChain
  implements ChatVectorDBQAChainInput
{
  /** Number of documents requested from the vector store per question. */
  k = 4;

  /** Key in the call values that holds the user's question. */
  inputKey = "question";

  /** Key in the call values that holds the chat history (a string). */
  chatHistoryKey = "chat_history";

  // NOTE(review): kept for interface compatibility; `_call` returns the
  // combine-documents chain's result as-is and does not re-key it.
  outputKey = "result";

  vectorstore: VectorStore;

  combineDocumentsChain: StuffDocumentsChain;

  questionGeneratorChain: LLMChain;

  /**
   * @param fields.vectorstore Store to search for relevant documents.
   * @param fields.combineDocumentsChain Chain that answers from the retrieved documents.
   * @param fields.questionGeneratorChain Chain that rephrases follow-up questions.
   * @param fields.inputKey Optional override for the question key (default "question").
   * @param fields.outputKey Optional override for the output key (default "result").
   * @param fields.k Optional override for the number of documents fetched (default 4).
   */
  constructor(fields: {
    vectorstore: VectorStore;
    combineDocumentsChain: StuffDocumentsChain;
    questionGeneratorChain: LLMChain;
    inputKey?: string;
    outputKey?: string;
    k?: number;
  }) {
    super();
    this.vectorstore = fields.vectorstore;
    this.combineDocumentsChain = fields.combineDocumentsChain;
    this.questionGeneratorChain = fields.questionGeneratorChain;
    this.inputKey = fields.inputKey ?? this.inputKey;
    this.outputKey = fields.outputKey ?? this.outputKey;
    this.k = fields.k ?? this.k;
  }

  /**
   * Runs the chain.
   *
   * @param values Must contain `inputKey` (the question) and `chatHistoryKey`
   *   (the chat history; an empty value skips the rephrasing step).
   * @returns Whatever `combineDocumentsChain.call` returns.
   * @throws Error if either required key is missing, or if the question
   *   generator does not return exactly one value.
   */
  async _call(values: ChainValues): Promise<ChainValues> {
    if (!(this.inputKey in values)) {
      throw new Error(`Question key ${this.inputKey} not found.`);
    }
    if (!(this.chatHistoryKey in values)) {
      // Report the key that is actually missing (previously printed inputKey).
      throw new Error(`chat history key ${this.chatHistoryKey} not found.`);
    }
    const question: string = values[this.inputKey];
    const chatHistory: string = values[this.chatHistoryKey];

    // Without history there is nothing to condense; use the question as given.
    let newQuestion = question;
    if (chatHistory.length > 0) {
      const result = await this.questionGeneratorChain.call({
        question,
        chat_history: chatHistory,
      });
      const keys = Object.keys(result);
      if (keys.length !== 1) {
        throw new Error(
          "Return from llm chain has multiple values, only single values supported."
        );
      }
      newQuestion = result[keys[0]];
    }

    const docs = await this.vectorstore.similaritySearch(newQuestion, this.k);
    // Answer the standalone question: the default QA prompt has no
    // {chat_history} slot, so a raw follow-up such as "Was that nice?" would
    // reach the answering LLM without the context needed to interpret it.
    const inputs = {
      question: newQuestion,
      input_documents: docs,
      chat_history: chatHistory,
    };
    return this.combineDocumentsChain.call(inputs);
  }

  /** Type tag used as `_type` in the serialized form. */
  _chainType() {
    return "chat-vector-db" as const;
  }

  /**
   * Rebuilds a chain from its serialized form.
   *
   * @param data Serialized chain; both sub-chains are resolved through
   *   `resolveConfigFromFile` before being deserialized.
   * @param values Must contain a `vectorstore` entry, since the store itself
   *   is not part of the serialized data.
   * @throws Error if `values` has no `vectorstore` key.
   */
  static async deserialize(
    data: SerializedChatVectorDBQAChain,
    values: LoadValues
  ) {
    if (!("vectorstore" in values)) {
      throw new Error(
        `Need to pass in a vectorstore to deserialize ChatVectorDBQAChain`
      );
    }
    const { vectorstore } = values;
    const serializedCombineDocumentsChain = resolveConfigFromFile<
      "combine_documents_chain",
      SerializedStuffDocumentsChain
    >("combine_documents_chain", data);
    const serializedQuestionGeneratorChain = resolveConfigFromFile<
      "question_generator",
      SerializedLLMChain
    >("question_generator", data);

    return new ChatVectorDBQAChain({
      combineDocumentsChain: await StuffDocumentsChain.deserialize(
        serializedCombineDocumentsChain
      ),
      questionGeneratorChain: await LLMChain.deserialize(
        serializedQuestionGeneratorChain
      ),
      k: data.k,
      vectorstore,
    });
  }

  /**
   * Serializes the chain: type tag, both sub-chains and `k`.
   * The vector store and the key names are not included.
   */
  serialize(): SerializedChatVectorDBQAChain {
    return {
      _type: this._chainType(),
      combine_documents_chain: this.combineDocumentsChain.serialize(),
      question_generator: this.questionGeneratorChain.serialize(),
      k: this.k,
    };
  }

  /**
   * Convenience factory: builds the QA chain with `loadQAChain` and the
   * module's QA prompt, and the question generator as an `LLMChain` over the
   * module's question-generator prompt, both backed by the same LLM.
   *
   * @param llm LLM used for both rephrasing and answering.
   * @param vectorstore Store to search for relevant documents.
   */
  static fromLLM(llm: BaseLLM, vectorstore: VectorStore): ChatVectorDBQAChain {
    const qaChain = loadQAChain(llm, qa_prompt);
    const questionGeneratorChain = new LLMChain({
      prompt: question_generator_prompt,
      llm,
    });
    return new this({
      vectorstore,
      combineDocumentsChain: qaChain,
      questionGeneratorChain,
    });
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { test } from "@jest/globals"; | ||
import { OpenAI } from "../../llms/openai"; | ||
import { PromptTemplate } from "../../prompt"; | ||
import { LLMChain } from "../llm_chain"; | ||
import { StuffDocumentsChain } from "../combine_docs_chain"; | ||
import { ChatVectorDBQAChain } from "../chat_vector_db_chain"; | ||
import { HNSWLib } from "../../vectorstores/hnswlib"; | ||
import { OpenAIEmbeddings } from "../../embeddings"; | ||
|
||
// Smoke test: wires the chain by hand, reusing one LLMChain for both the
// question-generation and the combine-documents steps, and logs the result.
test("Test ChatVectorDBQAChain", async () => {
  const llm = new OpenAI({});
  const template = PromptTemplate.fromTemplate(
    "Print {question}, and ignore {chat_history}"
  );

  const texts = ["Hello world", "Bye bye", "hello nice world", "bye", "hi"];
  const metadatas = [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }];
  const store = await HNSWLib.fromTexts(texts, metadatas, new OpenAIEmbeddings());

  const sharedLlmChain = new LLMChain({ prompt: template, llm });
  const stuffChain = new StuffDocumentsChain({
    llmChain: sharedLlmChain,
    documentVariableName: "foo",
  });

  const chatChain = new ChatVectorDBQAChain({
    combineDocumentsChain: stuffChain,
    vectorstore: store,
    questionGeneratorChain: sharedLlmChain,
  });

  // Non-empty history, so the question-generation step is exercised too.
  const answer = await chatChain.call({ question: "foo", chat_history: "bar" });
  console.log({ res: answer });
});
|
||
// Smoke test: builds the chain through the `fromLLM` factory and logs the result.
test("Test ChatVectorDBQAChain from LLM", async () => {
  const llm = new OpenAI({});

  const texts = ["Hello world", "Bye bye", "hello nice world", "bye", "hi"];
  const metadatas = [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }];
  const store = await HNSWLib.fromTexts(texts, metadatas, new OpenAIEmbeddings());

  const chatChain = ChatVectorDBQAChain.fromLLM(llm, store);
  const answer = await chatChain.call({ question: "foo", chat_history: "bar" });
  console.log({ res: answer });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.