From f08225d2225b5011e6dcad1190024255c2655abf Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Mon, 30 Oct 2023 17:58:54 -0600 Subject: [PATCH 1/4] DH-4917 Encrypt llm_api_key --- dataherald/sql_database/models/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 43973baf..956460de 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -68,7 +68,7 @@ def set_uri_without_ssh(cls, v, values): raise ValueError("When use_ssh is False set uri") return v - @validator("uri", pre=True, always=True) + @validator("uri", "llm_api_key", pre=True, always=True) def encrypt(cls, value: str): fernet_encrypt = FernetEncrypt() try: From b828efa10944f24cb451ea1ed1791c13e1c52245 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 31 Oct 2023 09:33:02 -0400 Subject: [PATCH 2/4] DH-4917/fix the embedding issue --- dataherald/sql_generator/dataherald_sqlagent.py | 8 +++++--- requirements.txt | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 7d6a6cac..75ef484a 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -22,6 +22,7 @@ CallbackManagerForToolRun, ) from langchain.chains.llm import LLMChain +from langchain.embeddings import OpenAIEmbeddings from langchain.schema import AgentAction from langchain.tools.base import BaseTool from overrides import override @@ -205,9 +206,10 @@ def get_embedding( self, text: str, model: str = "text-embedding-ada-002" ) -> List[float]: text = text.replace("\n", " ") - return openai.Embedding.create(input=[text], model=model)["data"][0][ - "embedding" - ] + embedding = OpenAIEmbeddings( + openai_api_key=os.environ.get("OPENAI_API_KEY"), model=model + ) + return embedding.embed_query(text) def cosine_similarity(self, a: List[float], b: List[float]) -> float: return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4) diff --git a/requirements.txt b/requirements.txt index efc79b52..f9561df6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,4 @@ sphinx-book-theme==1.0.1 boto3==1.28.38 botocore==1.31.38 PyAthena==3.0.6 +tiktoken==0.5.1 From 4e2f397cedd2d7ccefec9211b925d5dcc020bbd2 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 31 Oct 2023 14:57:17 -0400 Subject: [PATCH 3/4] DH-4917/fixing the issue with the embeddings --- dataherald/context_store/default.py | 1 + .../model/{base_models.py => base_model.py} | 0 dataherald/model/chat_model.py | 34 +++++++----------- .../delete_and_populate_golden_records.py | 1 + dataherald/scripts/migrate_v001_to_v002.py | 1 + dataherald/sql_database/models/types.py | 7 ++++ .../sql_generator/dataherald_sqlagent.py | 20 +++++++---- .../tests/vector_store/test_vector_store.py | 7 +++- dataherald/vector_store/__init__.py | 7 +++- dataherald/vector_store/chroma.py | 9 ++++- dataherald/vector_store/pinecone.py | 35 ++++++++++++++----- 11 files changed, 84 insertions(+), 38 deletions(-) rename dataherald/model/{base_models.py => base_model.py} (100%) diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index bb910f1c..0e9cbc3b 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -77,6 +77,7 @@ def add_golden_records( golden_record = golden_records_repository.insert(golden_record) self.vector_store.add_record( documents=question, + db_connection_id=record.db_connection_id, collection=self.golden_record_collection, metadata=[ { diff --git a/dataherald/model/base_models.py b/dataherald/model/base_model.py similarity index 100% rename from dataherald/model/base_models.py rename to dataherald/model/base_model.py diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index 2b19ebbb..4540a2bd 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -1,4 +1,3 @@ -import os from typing import Any from langchain.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm, ChatOpenAI @@ -6,7 +5,6 @@ from dataherald.model import LLMModel from dataherald.sql_database.models.types import DatabaseConnection -from dataherald.utils.encrypt import FernetEncrypt class ChatModel(LLMModel): @@ -21,23 +19,17 @@ def get_model( model_name="gpt-4-32k", **kwargs: Any ) -> Any: - if database_connection.llm_api_key is not None: - fernet_encrypt = FernetEncrypt() - api_key = fernet_encrypt.decrypt(database_connection.llm_api_key) - if model_family == "openai": - os.environ["OPENAI_API_KEY"] = api_key - elif model_family == "anthropic": - os.environ["ANTHROPIC_API_KEY"] = api_key - elif model_family == "google": - os.environ["GOOGLE_API_KEY"] = api_key - elif model_family == "cohere": - os.environ["COHERE_API_KEY"] = api_key - if os.environ.get("OPENAI_API_KEY") is not None: - return ChatOpenAI(model_name=model_name, **kwargs) - if os.environ.get("ANTHROPIC_API_KEY") is not None: - return ChatAnthropic(model_name=model_name, **kwargs) - if os.environ.get("GOOGLE_API_KEY") is not None: - return ChatGooglePalm(model_name=model_name, **kwargs) - if os.environ.get("COHERE_API_KEY") is not None: - return ChatCohere(model_name=model_name, **kwargs) + api_key = database_connection.decrypt_api_key() + if model_family == "openai": + return ChatOpenAI(model_name=model_name, openai_api_key=api_key, **kwargs) + if model_family == "anthropic": + return ChatAnthropic( + model_name=model_name, anthropic_api_key=api_key, **kwargs + ) + if model_family == "google": + return ChatGooglePalm( + model_name=model_name, google_api_key=api_key, **kwargs + ) + if model_family == "cohere": + return ChatCohere(model_name=model_name, cohere_api_key=api_key, **kwargs) raise ValueError("No valid API key environment variable found") diff --git a/dataherald/scripts/delete_and_populate_golden_records.py b/dataherald/scripts/delete_and_populate_golden_records.py index c088c46b..fe7ceeab 100644 --- a/dataherald/scripts/delete_and_populate_golden_records.py +++ b/dataherald/scripts/delete_and_populate_golden_records.py @@ -29,6 +29,7 @@ question = golden_record["question"] vector_store.add_record( documents=question, + db_connection_id=golden_record["db_connection_id"], collection=golden_record_collection, metadata=[ { diff --git a/dataherald/scripts/migrate_v001_to_v002.py b/dataherald/scripts/migrate_v001_to_v002.py index e1d9fd6f..03098397 100644 --- a/dataherald/scripts/migrate_v001_to_v002.py +++ b/dataherald/scripts/migrate_v001_to_v002.py @@ -52,6 +52,7 @@ def add_db_connection_id(collection_name: str, storage) -> None: question = golden_record["question"] vector_store.add_record( documents=question, + db_connection_id=golden_record["db_connection_id"], collection=golden_record_collection, metadata=[ { diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 956460de..719f1d9e 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -1,3 +1,4 @@ +import os from typing import Any from pydantic import BaseModel, BaseSettings, Extra, validator @@ -76,3 +77,9 @@ def encrypt(cls, value: str): return value except Exception: return fernet_encrypt.encrypt(value) + + def decrypt_api_key(self): + if self.llm_api_key is not None and self.llm_api_key != "": + fernet_encrypt = FernetEncrypt() + return fernet_encrypt.decrypt(self.llm_api_key) + return os.environ.get("OPENAI_API_KEY") diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 75ef484a..0752a9c4 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -56,6 +56,7 @@ TOP_K = 50 +EMBEDDING_MODEL = "text-embedding-ada-002" def catch_exceptions(): # noqa: C901 @@ -201,15 +202,14 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Use this tool to identify the relevant tables for the given question. """ db_scan: List[TableDescription] + embedding: OpenAIEmbeddings def get_embedding( - self, text: str, model: str = "text-embedding-ada-002" + self, + text: str, ) -> List[float]: text = text.replace("\n", " ") - embedding = OpenAIEmbeddings( - openai_api_key=os.environ.get("OPENAI_API_KEY"), model=model - ) - return embedding.embed_query(text) + return self.embedding.embed_query(text) def cosine_similarity(self, a: List[float], b: List[float]) -> float: return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4) @@ -465,6 +465,7 @@ class SQLDatabaseToolkit(BaseToolkit): few_shot_examples: List[dict] | None = Field(exclude=True, default=None) instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableDescription] = Field(exclude=True) + embedding: OpenAIEmbeddings = Field(exclude=True) @property def dialect(self) -> str: @@ -490,7 +491,10 @@ def get_tools(self) -> List[BaseTool]: get_current_datetime = GetCurrentTimeTool(db=self.db, context=self.context) tools.append(get_current_datetime) tables_sql_db_tool = TablesSQLDatabaseTool( - db=self.db, context=self.context, db_scan=self.db_scan + db=self.db, + context=self.context, + db_scan=self.db_scan, + embedding=self.embedding, ) tools.append(tables_sql_db_tool) schema_sql_db_tool = SchemaSQLDatabaseTool( @@ -632,6 +636,10 @@ def generate_response( few_shot_examples=new_fewshot_examples, instructions=instructions, db_scan=db_scan, + embedding=OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ), ) agent_executor = self.create_sql_agent( toolkit=toolkit, diff --git a/dataherald/tests/vector_store/test_vector_store.py b/dataherald/tests/vector_store/test_vector_store.py index f2e18bf3..3e96b606 100644 --- a/dataherald/tests/vector_store/test_vector_store.py +++ b/dataherald/tests/vector_store/test_vector_store.py @@ -22,7 +22,12 @@ def query( @override def add_record( - self, documents: str, collection: str, metadata: Any, ids: List # noqa: ARG002 + self, + documents: str, + db_connection_id: str, + collection: str, + metadata: Any, + ids: List, # noqa: ARG002 ): pass diff --git a/dataherald/vector_store/__init__.py b/dataherald/vector_store/__init__.py index 4f9b3c12..59e4ebde 100644 --- a/dataherald/vector_store/__init__.py +++ b/dataherald/vector_store/__init__.py @@ -27,7 +27,12 @@ def create_collection(self, collection: str): @abstractmethod def add_record( - self, documents: str, collection: str, metadata: Any, ids: List = None + self, + documents: str, + db_connection_id: str, + collection: str, + metadata: Any, + ids: List = None, ): pass diff --git a/dataherald/vector_store/chroma.py b/dataherald/vector_store/chroma.py index 5f79607f..284ca5d5 100644 --- a/dataherald/vector_store/chroma.py +++ b/dataherald/vector_store/chroma.py @@ -37,7 +37,14 @@ def query( return self.convert_to_pinecone_object_model(query_results) @override - def add_record(self, documents: str, collection: str, metadata: Any, ids: List): + def add_record( + self, + documents: str, + db_connection_id: str, # noqa: ARG002 + collection: str, + metadata: Any, + ids: List, + ): target_collection = self.chroma_client.get_or_create_collection(collection) existing_rows = target_collection.get(ids=ids) if len(existing_rows["documents"]) == 0: diff --git a/dataherald/vector_store/pinecone.py b/dataherald/vector_store/pinecone.py index e3ee324f..492dae7a 100644 --- a/dataherald/vector_store/pinecone.py +++ b/dataherald/vector_store/pinecone.py @@ -1,11 +1,13 @@ import os from typing import Any, List -import openai import pinecone +from langchain.embeddings import OpenAIEmbeddings from overrides import override from dataherald.config import System +from dataherald.db import DB +from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.vector_store import VectorStore EMBEDDING_MODEL = "text-embedding-ada-002" @@ -31,9 +33,14 @@ def query( num_results: int, ) -> list: index = pinecone.Index(collection) - xq = openai.Embedding.create(input=query_texts[0], engine=EMBEDDING_MODEL)[ - "data" - ][0]["embedding"] + db_connection_repository = DatabaseConnectionRepository( + self.system.instance(DB) + ) + database_connection = db_connection_repository.find_by_id(db_connection_id) + embedding = OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL + ) + xq = embedding.embed_query(query_texts[0]) query_response = index.query( queries=[xq], filter={ @@ -45,13 +52,25 @@ def query( return query_response.to_dict()["results"][0]["matches"] @override - def add_record(self, documents: str, collection: str, metadata: Any, ids: List): + def add_record( + self, + documents: str, + db_connection_id: str, + collection: str, + metadata: Any, + ids: List, + ): if collection not in pinecone.list_indexes(): self.create_collection(collection) - + db_connection_repository = DatabaseConnectionRepository( + self.system.instance(DB) + ) + database_connection = db_connection_repository.find_by_id(db_connection_id) + embedding = OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL + ) index = pinecone.Index(collection) - res = openai.Embedding.create(input=[documents], engine=EMBEDDING_MODEL) - embeds = [record["embedding"] for record in res["data"]] + embeds = embedding.embed_documents([documents]) record = [(ids[0], embeds, metadata[0])] index.upsert(vectors=record) From 14284048e0995a9ae4513f3d2776499bf8350b6a Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 31 Oct 2023 15:43:04 -0400 Subject: [PATCH 4/4] DH-4917/reformat with black --- dataherald/sql_generator/dataherald_sqlagent.py | 1 - dataherald/sql_generator/generates_nl_answer.py | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index ad35b072..70f70dfa 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -55,7 +55,6 @@ logger = logging.getLogger(__name__) - TOP_K = int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")) EMBEDDING_MODEL = "text-embedding-ada-002" diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index 54720cb0..d412a4d0 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -46,7 +46,10 @@ def execute(self, query_response: Response) -> Response: ) database = SQLDatabase.get_sql_engine(database_connection) query_response = create_sql_query_status( - database, query_response.sql_query, query_response, top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")) + database, + query_response.sql_query, + query_response, + top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")), ) system_message_prompt = SystemMessagePromptTemplate.from_template( SYSTEM_TEMPLATE