chat vector db chain (langchain-ai#22)
* cr

* cr
hwchase17 authored Feb 17, 2023
1 parent 0de457d commit 46b1485
Showing 13 changed files with 298 additions and 17 deletions.
42 changes: 42 additions & 0 deletions docs/docs/modules/chains/chat_vector_db_qa.md
@@ -0,0 +1,42 @@
# Chat Vector DB QA Chain

A Chat Vector DB QA chain takes a question and the chat history as input.
It first condenses the chat history and the question into a standalone question, then retrieves the relevant documents from the vector database, and finally passes those documents and the question to a question answering chain to produce a response.

To create one, you will need a vectorstore, which can be created from embeddings.

Below is an end-to-end example of doing question answering over a recent State of the Union address.

```typescript
import { OpenAI } from "langchain/llms";
import { ChatVectorDBQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores";
import { OpenAIEmbeddings } from "langchain/embeddings";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import * as fs from "fs";

/* Initialize the LLM to use to answer the question */
const model = new OpenAI({});
/* Load in the file we want to do question answering over */
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
/* Split the text into chunks */
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = textSplitter.createDocuments([text]);
/* Create the vectorstore */
const vectorStore = await HNSWLib.fromDocuments(
docs,
new OpenAIEmbeddings()
);
/* Create the chain */
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
/* Ask it a question */
const question = "What did the president say about Justice Breyer?"
const res = await chain.call({ question: question, chat_history: [] });
console.log(res);
/* Ask it a follow up question */
const chatHistory = question + res.text;
const followUpRes = await chain.call({ question: "Was that nice?", chat_history: chatHistory });
console.log(followUpRes);

```
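
The `chat_history` value is just a string, so a multi-turn conversation can be supported by accumulating each question and answer into it. Below is a minimal sketch of such a helper, continuing from the example above (`ask` and `history` are illustrative names, not part of the library):

```typescript
/* Illustrative helper, not part of langchain: accumulate the conversation
   as a single string and pass it back in on every call */
let history = "";
const ask = async (question: string) => {
  const res = await chain.call({ question, chat_history: history });
  history += question + res.text;
  return res.text;
};
```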
31 changes: 31 additions & 0 deletions examples/src/chains/chat_vector_db.ts
@@ -0,0 +1,31 @@
import { OpenAI } from "langchain/llms";
import { ChatVectorDBQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores";
import { OpenAIEmbeddings } from "langchain/embeddings";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import * as fs from "fs";

export const run = async () => {
/* Initialize the LLM to use to answer the question */
const model = new OpenAI({});
/* Load in the file we want to do question answering over */
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
/* Split the text into chunks */
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = textSplitter.createDocuments([text]);
/* Create the vectorstore */
const vectorStore = await HNSWLib.fromDocuments(
docs,
new OpenAIEmbeddings()
);
/* Create the chain */
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
/* Ask it a question */
const question = "What did the president say about Justice Breyer?";
const res = await chain.call({ question, chat_history: [] });
console.log(res);
/* Ask it a follow up question */
const chatHistory = question + res.text;
const followUpRes = await chain.call({ question: "Was that nice?", chat_history: chatHistory });
console.log(followUpRes);
};
2 changes: 1 addition & 1 deletion langchain/agents/index.ts
@@ -9,6 +9,6 @@ export { Agent, StaticAgent, staticImplements, AgentInput } from "./agent";
export { AgentExecutor } from "./executor";
export { ZeroShotAgent, SerializedZeroShotAgent } from "./mrkl";
export { Tool } from "./tools";
- export {initializeAgentExecutor} from "./initialize"
+ export {initializeAgentExecutor} from "./initialize";

export { loadAgent } from "./load";
3 changes: 2 additions & 1 deletion langchain/agents/initialize.ts
@@ -2,6 +2,7 @@ import { Tool } from "./tools";
import { BaseLLM } from "../llms";
import { AgentExecutor } from "./executor";
import { ZeroShotAgent } from "./mrkl";

export const initializeAgentExecutor = async (
tools: Tool[],
llm: BaseLLM,
@@ -15,7 +16,7 @@ export const initializeAgentExecutor = async (
tools,
returnIntermediateSteps: true,
});
- return executor
+ return executor;
default:
throw new Error("Unknown agent type");
}
4 changes: 2 additions & 2 deletions langchain/chains/base.ts
@@ -1,12 +1,12 @@
- import { LLMChain, StuffDocumentsChain, VectorDBQAChain } from "./index";
+ import { LLMChain, StuffDocumentsChain, VectorDBQAChain, ChatVectorDBQAChain } from "./index";
import { BaseMemory } from "../memory";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ChainValues = Record<string, any>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type LoadValues = Record<string, any>;

- const chainClasses = [LLMChain, StuffDocumentsChain, VectorDBQAChain];
+ const chainClasses = [LLMChain, StuffDocumentsChain, VectorDBQAChain, ChatVectorDBQAChain];

export type SerializedBaseChain = ReturnType<
InstanceType<(typeof chainClasses)[number]>["serialize"]
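Registering `ChatVectorDBQAChain` in `chainClasses` is what lets a serialized `chat-vector-db` chain be loaded back. Below is a minimal sketch under two stated assumptions: the file name is hypothetical, and `loadChain` forwards its `LoadValues` argument to `deserialize` (which, as the new chain file below shows, needs the vectorstore re-supplied):

```typescript
import { loadChain } from "langchain/chains";

// Hypothetical serialized chain on disk. The vectorstore is not part of the
// serialized form, so it is assumed to be passed back in via LoadValues.
const chain = await loadChain("chat_vector_db_chain.json", { vectorstore });
```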
164 changes: 164 additions & 0 deletions langchain/chains/chat_vector_db_chain.ts
@@ -0,0 +1,164 @@
import {
BaseChain,
ChainValues,
SerializedStuffDocumentsChain,
StuffDocumentsChain,
SerializedLLMChain,
loadQAChain,
LLMChain,
} from "./index";

import { PromptTemplate } from "../prompt";

import { VectorStore } from "../vectorstores/base";
import { BaseLLM } from "../llms";

import { resolveConfigFromFile } from "../util";
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type LoadValues = Record<string, any>;

const question_generator_template = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:`;
const question_generator_prompt = PromptTemplate.fromTemplate(question_generator_template);

const qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:`;
const qa_prompt = PromptTemplate.fromTemplate(qa_template);

export interface ChatVectorDBQAChainInput {
vectorstore: VectorStore;
k: number;
combineDocumentsChain: StuffDocumentsChain;
questionGeneratorChain: LLMChain;
outputKey: string;
inputKey: string;
}

export type SerializedChatVectorDBQAChain = {
_type: "chat-vector-db";
k: number;
combine_documents_chain: SerializedStuffDocumentsChain;
combine_documents_chain_path?: string;
question_generator: SerializedLLMChain;
};

export class ChatVectorDBQAChain extends BaseChain implements ChatVectorDBQAChainInput {
k = 4;

inputKey = "question";

chatHistoryKey = "chat_history";

outputKey = "result";

vectorstore: VectorStore;

combineDocumentsChain: StuffDocumentsChain;

questionGeneratorChain: LLMChain;

constructor(fields: {
vectorstore: VectorStore;
combineDocumentsChain: StuffDocumentsChain;
questionGeneratorChain: LLMChain;
inputKey?: string;
outputKey?: string;
k?: number;
}) {
super();
this.vectorstore = fields.vectorstore;
this.combineDocumentsChain = fields.combineDocumentsChain;
this.questionGeneratorChain = fields.questionGeneratorChain;
this.inputKey = fields.inputKey ?? this.inputKey;
this.outputKey = fields.outputKey ?? this.outputKey;
this.k = fields.k ?? this.k;
}

async _call(values: ChainValues): Promise<ChainValues> {
if (!(this.inputKey in values)) {
throw new Error(`Question key ${this.inputKey} not found.`);
}
if (!(this.chatHistoryKey in values)) {
throw new Error(`Chat history key ${this.chatHistoryKey} not found.`);
}
const question: string = values[this.inputKey];
const chatHistory: string = values[this.chatHistoryKey];
let newQuestion = question;
if (chatHistory.length > 0) {
const result = await this.questionGeneratorChain.call({ question, chat_history: chatHistory });
const keys = Object.keys(result);
if (keys.length === 1) {
newQuestion = result[keys[0]];
} else {
throw new Error(
"Return from LLM chain has multiple values; only single values are supported."
);
}
}
const docs = await this.vectorstore.similaritySearch(newQuestion, this.k);
const inputs = { question, input_documents: docs, chat_history: chatHistory };
const result = await this.combineDocumentsChain.call(inputs);
return result;
}

_chainType() {
return "chat-vector-db" as const;
}

static async deserialize(
data: SerializedChatVectorDBQAChain,
values: LoadValues
) {
if (!("vectorstore" in values)) {
throw new Error(
`Need to pass in a vectorstore to deserialize VectorDBQAChain`
);
}
const { vectorstore } = values;
const serializedCombineDocumentsChain = resolveConfigFromFile<
"combine_documents_chain",
SerializedStuffDocumentsChain
>("combine_documents_chain", data);
const serializedQuestionGeneratorChain = resolveConfigFromFile<
"question_generator",
SerializedLLMChain
>("question_generator", data);

return new ChatVectorDBQAChain({
combineDocumentsChain: await StuffDocumentsChain.deserialize(
serializedCombineDocumentsChain
),
questionGeneratorChain: await LLMChain.deserialize(
serializedQuestionGeneratorChain
),
k: data.k,
vectorstore,
});
}

serialize(): SerializedChatVectorDBQAChain {
return {
_type: this._chainType(),
combine_documents_chain: this.combineDocumentsChain.serialize(),
question_generator: this.questionGeneratorChain.serialize(),
k: this.k,
};
}

static fromLLM(llm: BaseLLM, vectorstore: VectorStore): ChatVectorDBQAChain {
const qaChain = loadQAChain(llm, qa_prompt);
const questionGeneratorChain = new LLMChain({ prompt: question_generator_prompt, llm });
const instance = new this({ vectorstore, combineDocumentsChain: qaChain, questionGeneratorChain });
return instance;
}
}
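
Note that `serialize()` above omits the vectorstore, so it must be re-supplied when deserializing. A minimal round-trip sketch, assuming a `chain` and `vectorstore` like those in the docs example:

```typescript
// Round-trip sketch: serialize the chain, then restore it by passing the
// live vectorstore back in through the values argument of deserialize.
const data = chain.serialize();
const restored = await ChatVectorDBQAChain.deserialize(data, { vectorstore });
```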
1 change: 1 addition & 0 deletions langchain/chains/index.ts
@@ -4,6 +4,7 @@ export {
SerializedStuffDocumentsChain,
StuffDocumentsChain,
} from "./combine_docs_chain";
export { ChatVectorDBQAChain, SerializedChatVectorDBQAChain } from "./chat_vector_db_chain";
export { VectorDBQAChain, SerializedVectorDBQAChain } from "./vector_db_qa";
export { loadChain } from "./load";
export { loadQAChain } from "./question_answering/load";
4 changes: 2 additions & 2 deletions langchain/chains/question_answering/load.ts
@@ -1,10 +1,10 @@
import { BaseLLM } from "../../llms";
import { LLMChain } from "../llm_chain";
import { StuffDocumentsChain } from "../combine_docs_chain";
- import { prompt } from "./stuff_prompts";
+ import { DEFAULT_QA_PROMPT } from "./stuff_prompts";

- export const loadQAChain = (llm: BaseLLM) => {
+ export const loadQAChain = (llm: BaseLLM, prompt = DEFAULT_QA_PROMPT) => {
const llmChain = new LLMChain({ prompt, llm });
const chain = new StuffDocumentsChain({ llmChain });
return chain;
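`loadQAChain` now takes an optional prompt that defaults to `DEFAULT_QA_PROMPT`. A minimal sketch of overriding it, assuming `PromptTemplate` is importable from `langchain/prompt` as in this repo's layout (the template text is illustrative, but it must keep the `{context}` and `{question}` variables):

```typescript
import { OpenAI } from "langchain/llms";
import { loadQAChain } from "langchain/chains";
import { PromptTemplate } from "langchain/prompt";

// Illustrative custom prompt; {context} and {question} are kept because
// the default QA prompt uses exactly those input variables.
const customPrompt = PromptTemplate.fromTemplate(
  "Answer strictly from this context:\n{context}\nQuestion: {question}\nAnswer:"
);
const qaChain = loadQAChain(new OpenAI({}), customPrompt);
```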
2 changes: 1 addition & 1 deletion langchain/chains/question_answering/stuff_prompts.ts
@@ -1,7 +1,7 @@
/* eslint-disable */
import { PromptTemplate } from "../../prompt";

- export const prompt = new PromptTemplate({
+ export const DEFAULT_QA_PROMPT = new PromptTemplate({
template: "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:",
inputVariables: ["context", "question"],
});
Expand Down
42 changes: 42 additions & 0 deletions langchain/chains/tests/chat_vector_db_qa_chain.test.ts
@@ -0,0 +1,42 @@
import { test } from "@jest/globals";
import { OpenAI } from "../../llms/openai";
import { PromptTemplate } from "../../prompt";
import { LLMChain } from "../llm_chain";
import { StuffDocumentsChain } from "../combine_docs_chain";
import { ChatVectorDBQAChain } from "../chat_vector_db_chain";
import { HNSWLib } from "../../vectorstores/hnswlib";
import { OpenAIEmbeddings } from "../../embeddings";

test("Test ChatVectorDBQAChain", async () => {
const model = new OpenAI({});
const prompt = PromptTemplate.fromTemplate("Print {question}, and ignore {chat_history}");
const vectorStore = await HNSWLib.fromTexts(
["Hello world", "Bye bye", "hello nice world", "bye", "hi"],
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);
const llmChain = new LLMChain({ prompt, llm: model });
const combineDocsChain = new StuffDocumentsChain({
llmChain,
documentVariableName: "foo",
});
const chain = new ChatVectorDBQAChain({
combineDocumentsChain: combineDocsChain,
vectorstore: vectorStore,
questionGeneratorChain: llmChain,
});
const res = await chain.call({ question: "foo", chat_history: "bar" });
console.log({ res });
});

test("Test ChatVectorDBQAChain from LLM", async () => {
const model = new OpenAI({});
const vectorStore = await HNSWLib.fromTexts(
["Hello world", "Bye bye", "hello nice world", "bye", "hi"],
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
const res = await chain.call({ question: "foo", chat_history: "bar" });
console.log({ res });
});
2 changes: 1 addition & 1 deletion langchain/vectorstores/base.ts
@@ -19,7 +19,7 @@ export abstract class VectorStore {
): Promise<[Document, number][]>;

async addDocuments(documents: Document[]): Promise<void> {
- const texts = documents.map( ({pageContent}) => (pageContent))
+ const texts = documents.map( ({pageContent}) => (pageContent));
this.addVectors(
await this.embeddings.embedDocuments(texts),
documents
6 changes: 3 additions & 3 deletions langchain/vectorstores/hnswlib.ts
@@ -145,12 +145,12 @@ export class HNSWLib extends SaveableVectorStore {
metadatas: object[],
embeddings: Embeddings
): Promise<HNSWLib> {
- var docs = [];
+ const docs = [];
for (let i = 0; i < texts.length; i++) {
- let newDoc = new Document({pageContent: texts[i], metadata: metadatas[i]});
+ const newDoc = new Document({pageContent: texts[i], metadata: metadatas[i]});
docs.push(newDoc);
}
- return HNSWLib.fromDocuments(docs, embeddings)
+ return HNSWLib.fromDocuments(docs, embeddings);
}

static async fromDocuments(