DH-4917/fixing the issue with the embeddings
MohammadrezaPourreza committed Oct 31, 2023
1 parent b828efa commit 4e2f397
Showing 11 changed files with 84 additions and 38 deletions.
1 change: 1 addition & 0 deletions dataherald/context_store/default.py
@@ -77,6 +77,7 @@ def add_golden_records(
golden_record = golden_records_repository.insert(golden_record)
self.vector_store.add_record(
documents=question,
db_connection_id=record.db_connection_id,
collection=self.golden_record_collection,
metadata=[
{
File renamed without changes.
34 changes: 13 additions & 21 deletions dataherald/model/chat_model.py
@@ -1,12 +1,10 @@
import os
from typing import Any

from langchain.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm, ChatOpenAI
from overrides import override

from dataherald.model import LLMModel
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.utils.encrypt import FernetEncrypt


class ChatModel(LLMModel):
@@ -21,23 +19,17 @@ def get_model(
model_name="gpt-4-32k",
**kwargs: Any
) -> Any:
if database_connection.llm_api_key is not None:
fernet_encrypt = FernetEncrypt()
api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
if model_family == "openai":
os.environ["OPENAI_API_KEY"] = api_key
elif model_family == "anthropic":
os.environ["ANTHROPIC_API_KEY"] = api_key
elif model_family == "google":
os.environ["GOOGLE_API_KEY"] = api_key
elif model_family == "cohere":
os.environ["COHERE_API_KEY"] = api_key
if os.environ.get("OPENAI_API_KEY") is not None:
return ChatOpenAI(model_name=model_name, **kwargs)
if os.environ.get("ANTHROPIC_API_KEY") is not None:
return ChatAnthropic(model_name=model_name, **kwargs)
if os.environ.get("GOOGLE_API_KEY") is not None:
return ChatGooglePalm(model_name=model_name, **kwargs)
if os.environ.get("COHERE_API_KEY") is not None:
return ChatCohere(model_name=model_name, **kwargs)
api_key = database_connection.decrypt_api_key()
if model_family == "openai":
return ChatOpenAI(model_name=model_name, openai_api_key=api_key, **kwargs)
if model_family == "anthropic":
return ChatAnthropic(
model_name=model_name, anthropic_api_key=api_key, **kwargs
)
if model_family == "google":
return ChatGooglePalm(
model_name=model_name, google_api_key=api_key, **kwargs
)
if model_family == "cohere":
return ChatCohere(model_name=model_name, cohere_api_key=api_key, **kwargs)
raise ValueError("No valid API key environment variable found")
1 change: 1 addition & 0 deletions dataherald/scripts/delete_and_populate_golden_records.py
@@ -29,6 +29,7 @@
question = golden_record["question"]
vector_store.add_record(
documents=question,
db_connection_id=golden_record["db_connection_id"],
collection=golden_record_collection,
metadata=[
{
1 change: 1 addition & 0 deletions dataherald/scripts/migrate_v001_to_v002.py
@@ -52,6 +52,7 @@ def add_db_connection_id(collection_name: str, storage) -> None:
question = golden_record["question"]
vector_store.add_record(
documents=question,
db_connection_id=golden_record["db_connection_id"],
collection=golden_record_collection,
metadata=[
{
7 changes: 7 additions & 0 deletions dataherald/sql_database/models/types.py
@@ -1,3 +1,4 @@
import os
from typing import Any

from pydantic import BaseModel, BaseSettings, Extra, validator
@@ -76,3 +77,9 @@ def encrypt(cls, value: str):
return value
except Exception:
return fernet_encrypt.encrypt(value)

def decrypt_api_key(self):
if self.llm_api_key is not None and self.llm_api_key != "":
fernet_encrypt = FernetEncrypt()
return fernet_encrypt.decrypt(self.llm_api_key)
return os.environ.get("OPENAI_API_KEY")
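The new helper is the single rule the rest of this commit leans on: prefer the connection's own Fernet-encrypted llm_api_key, and fall back to the OPENAI_API_KEY environment variable only when no per-connection key is stored. A hedged sketch of the intended call pattern (the connection object is a placeholder, and ENCRYPT_KEY is assumed to be configured so FernetEncrypt can decrypt):

from dataherald.sql_database.models.types import DatabaseConnection

connection: DatabaseConnection = ...  # e.g. loaded through DatabaseConnectionRepository

api_key = connection.decrypt_api_key()
# - llm_api_key set        -> the Fernet-decrypted per-connection key
# - llm_api_key empty/None -> os.environ.get("OPENAI_API_KEY"), which may be None
if api_key is None:
    raise ValueError("No per-connection LLM key stored and OPENAI_API_KEY is unset")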
20 changes: 14 additions & 6 deletions dataherald/sql_generator/dataherald_sqlagent.py
@@ -56,6 +56,7 @@


TOP_K = 50
EMBEDDING_MODEL = "text-embedding-ada-002"


def catch_exceptions(): # noqa: C901
@@ -201,15 +202,14 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
Use this tool to identify the relevant tables for the given question.
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings

def get_embedding(
self, text: str, model: str = "text-embedding-ada-002"
self,
text: str,
) -> List[float]:
text = text.replace("\n", " ")
embedding = OpenAIEmbeddings(
openai_api_key=os.environ.get("OPENAI_API_KEY"), model=model
)
return embedding.embed_query(text)
return self.embedding.embed_query(text)

def cosine_similarity(self, a: List[float], b: List[float]) -> float:
return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4)
@@ -465,6 +465,7 @@ class SQLDatabaseToolkit(BaseToolkit):
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)
instructions: List[dict] | None = Field(exclude=True, default=None)
db_scan: List[TableDescription] = Field(exclude=True)
embedding: OpenAIEmbeddings = Field(exclude=True)

@property
def dialect(self) -> str:
@@ -490,7 +491,10 @@ def get_tools(self) -> List[BaseTool]:
get_current_datetime = GetCurrentTimeTool(db=self.db, context=self.context)
tools.append(get_current_datetime)
tables_sql_db_tool = TablesSQLDatabaseTool(
db=self.db, context=self.context, db_scan=self.db_scan
db=self.db,
context=self.context,
db_scan=self.db_scan,
embedding=self.embedding,
)
tools.append(tables_sql_db_tool)
schema_sql_db_tool = SchemaSQLDatabaseTool(
@@ -632,6 +636,10 @@ def generate_response(
few_shot_examples=new_fewshot_examples,
instructions=instructions,
db_scan=db_scan,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
agent_executor = self.create_sql_agent(
toolkit=toolkit,
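The agent change moves embedding construction up to generate_response: one OpenAIEmbeddings client is built with the connection's decrypted key and the module-level EMBEDDING_MODEL, then shared with TablesSQLDatabaseTool through the toolkit, so the tool no longer reads OPENAI_API_KEY from the environment on every call. A simplified stand-in for how that shared client feeds the tool's cosine-similarity ranking (not a copy of the tool's code; table names are illustrative):

from typing import List, Tuple

import numpy as np
from langchain.embeddings import OpenAIEmbeddings


def rank_tables(
    question: str, table_names: List[str], embedding: OpenAIEmbeddings
) -> List[Tuple[str, float]]:
    # Embed the question once, then score each candidate table by cosine
    # similarity, mirroring get_embedding() + cosine_similarity() in the tool.
    question_vec = embedding.embed_query(question.replace("\n", " "))
    scored = []
    for name in table_names:
        table_vec = embedding.embed_query(name)
        score = round(
            np.dot(question_vec, table_vec)
            / (np.linalg.norm(question_vec) * np.linalg.norm(table_vec)),
            4,
        )
        scored.append((name, score))
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# The same instance created in generate_response would be passed in:
# embedding = OpenAIEmbeddings(openai_api_key=database_connection.decrypt_api_key(),
#                              model=EMBEDDING_MODEL)
# rank_tables("total revenue by month", ["orders", "customers", "payments"], embedding)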
7 changes: 6 additions & 1 deletion dataherald/tests/vector_store/test_vector_store.py
@@ -22,7 +22,12 @@ def query(

@override
def add_record(
self, documents: str, collection: str, metadata: Any, ids: List # noqa: ARG002
self,
documents: str,
db_connection_id: str,
collection: str,
metadata: Any,
ids: List, # noqa: ARG002
):
pass

7 changes: 6 additions & 1 deletion dataherald/vector_store/__init__.py
@@ -27,7 +27,12 @@ def create_collection(self, collection: str):

@abstractmethod
def add_record(
self, documents: str, collection: str, metadata: Any, ids: List = None
self,
documents: str,
db_connection_id: str,
collection: str,
metadata: Any,
ids: List = None,
):
pass

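Every implementation and caller now threads db_connection_id through add_record, which is how a store can look up the owning connection and decrypt its embedding key; the Chroma and Pinecone diffs below show the two concrete cases. A hedged example of the new call shape — identifiers and metadata keys are placeholders, and vector_store stands for any concrete VectorStore:

# vector_store: any VectorStore implementation (Chroma ignores db_connection_id,
# Pinecone uses it to fetch the connection and build per-connection embeddings).
vector_store.add_record(
    documents="What was total revenue last quarter?",
    db_connection_id="653a0b9f2f7d4e0012ab34cd",  # owning DatabaseConnection id
    collection="dataherald-golden-records",       # placeholder collection name
    metadata=[{"tables_used": "orders", "db_connection_id": "653a0b9f2f7d4e0012ab34cd"}],
    ids=["golden-record-1"],
)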
9 changes: 8 additions & 1 deletion dataherald/vector_store/chroma.py
@@ -37,7 +37,14 @@ def query(
return self.convert_to_pinecone_object_model(query_results)

@override
def add_record(self, documents: str, collection: str, metadata: Any, ids: List):
def add_record(
self,
documents: str,
db_connection_id: str, # noqa: ARG002
collection: str,
metadata: Any,
ids: List,
):
target_collection = self.chroma_client.get_or_create_collection(collection)
existing_rows = target_collection.get(ids=ids)
if len(existing_rows["documents"]) == 0:
35 changes: 27 additions & 8 deletions dataherald/vector_store/pinecone.py
@@ -1,11 +1,13 @@
import os
from typing import Any, List

import openai
import pinecone
from langchain.embeddings import OpenAIEmbeddings
from overrides import override

from dataherald.config import System
from dataherald.db import DB
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.vector_store import VectorStore

EMBEDDING_MODEL = "text-embedding-ada-002"
@@ -31,9 +33,14 @@ def query(
num_results: int,
) -> list:
index = pinecone.Index(collection)
xq = openai.Embedding.create(input=query_texts[0], engine=EMBEDDING_MODEL)[
"data"
][0]["embedding"]
db_connection_repository = DatabaseConnectionRepository(
self.system.instance(DB)
)
database_connection = db_connection_repository.find_by_id(db_connection_id)
embedding = OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
)
xq = embedding.embed_query(query_texts[0])
query_response = index.query(
queries=[xq],
filter={
@@ -45,13 +52,25 @@ def add_record(self, documents: str, collection: str, metadata: Any, ids: List):
return query_response.to_dict()["results"][0]["matches"]

@override
def add_record(self, documents: str, collection: str, metadata: Any, ids: List):
def add_record(
self,
documents: str,
db_connection_id: str,
collection: str,
metadata: Any,
ids: List,
):
if collection not in pinecone.list_indexes():
self.create_collection(collection)

db_connection_repository = DatabaseConnectionRepository(
self.system.instance(DB)
)
database_connection = db_connection_repository.find_by_id(db_connection_id)
embedding = OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
)
index = pinecone.Index(collection)
res = openai.Embedding.create(input=[documents], engine=EMBEDDING_MODEL)
embeds = [record["embedding"] for record in res["data"]]
embeds = embedding.embed_documents([documents])
record = [(ids[0], embeds, metadata[0])]
index.upsert(vectors=record)

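The query path gets the same treatment as add_record: instead of openai.Embedding.create with a global key, it resolves the owning connection through the repository and embeds the query text with a per-connection OpenAIEmbeddings client. A hedged sketch of retrieving matches from the updated store (keyword names follow the query signature in this file; the id and collection are placeholders, and Pinecone itself is assumed to be initialized separately with PINECONE_API_KEY):

# Retrieve the closest stored golden-record questions for one connection;
# the query text is embedded with OpenAIEmbeddings built from
# database_connection.decrypt_api_key(), so no global OpenAI key is required.
matches = vector_store.query(
    query_texts=["monthly active user signups"],
    db_connection_id="653a0b9f2f7d4e0012ab34cd",
    collection="dataherald-golden-records",
    num_results=3,
)
for match in matches:
    print(match["id"], match["score"])  # Pinecone match dicts carry id / score / metadata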
