Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-4917 Encrypt llm_api_key #235

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def add_golden_records(
golden_record = golden_records_repository.insert(golden_record)
self.vector_store.add_record(
documents=question,
db_connection_id=record.db_connection_id,
collection=self.golden_record_collection,
metadata=[
{
Expand Down
File renamed without changes.
34 changes: 13 additions & 21 deletions dataherald/model/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
from typing import Any

from langchain.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm, ChatOpenAI
from overrides import override

from dataherald.model import LLMModel
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.utils.encrypt import FernetEncrypt


class ChatModel(LLMModel):
Expand All @@ -21,23 +19,17 @@ def get_model(
model_name="gpt-4-32k",
**kwargs: Any
) -> Any:
if database_connection.llm_api_key is not None:
fernet_encrypt = FernetEncrypt()
api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
if model_family == "openai":
os.environ["OPENAI_API_KEY"] = api_key
elif model_family == "anthropic":
os.environ["ANTHROPIC_API_KEY"] = api_key
elif model_family == "google":
os.environ["GOOGLE_API_KEY"] = api_key
elif model_family == "cohere":
os.environ["COHERE_API_KEY"] = api_key
if os.environ.get("OPENAI_API_KEY") is not None:
return ChatOpenAI(model_name=model_name, **kwargs)
if os.environ.get("ANTHROPIC_API_KEY") is not None:
return ChatAnthropic(model_name=model_name, **kwargs)
if os.environ.get("GOOGLE_API_KEY") is not None:
return ChatGooglePalm(model_name=model_name, **kwargs)
if os.environ.get("COHERE_API_KEY") is not None:
return ChatCohere(model_name=model_name, **kwargs)
api_key = database_connection.decrypt_api_key()
if model_family == "openai":
return ChatOpenAI(model_name=model_name, openai_api_key=api_key, **kwargs)
if model_family == "anthropic":
return ChatAnthropic(
model_name=model_name, anthropic_api_key=api_key, **kwargs
)
if model_family == "google":
return ChatGooglePalm(
model_name=model_name, google_api_key=api_key, **kwargs
)
if model_family == "cohere":
return ChatCohere(model_name=model_name, cohere_api_key=api_key, **kwargs)
raise ValueError("No valid API key environment variable found")
1 change: 1 addition & 0 deletions dataherald/scripts/delete_and_populate_golden_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
question = golden_record["question"]
vector_store.add_record(
documents=question,
db_connection_id=golden_record["db_connection_id"],
collection=golden_record_collection,
metadata=[
{
Expand Down
1 change: 1 addition & 0 deletions dataherald/scripts/migrate_v001_to_v002.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def add_db_connection_id(collection_name: str, storage) -> None:
question = golden_record["question"]
vector_store.add_record(
documents=question,
db_connection_id=golden_record["db_connection_id"],
collection=golden_record_collection,
metadata=[
{
Expand Down
9 changes: 8 additions & 1 deletion dataherald/sql_database/models/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Any

from pydantic import BaseModel, BaseSettings, Extra, validator
Expand Down Expand Up @@ -68,11 +69,17 @@ def set_uri_without_ssh(cls, v, values):
raise ValueError("When use_ssh is False set uri")
return v

@validator("uri", pre=True, always=True)
@validator("uri", "llm_api_key", pre=True, always=True)
def encrypt(cls, value: str):
    """Return ``value`` in encrypted form, encrypting it only when needed.

    The value is first run through ``FernetEncrypt.decrypt``: if that
    succeeds the value is treated as already encrypted and returned
    untouched; if it raises, the value is encrypted and the result of
    ``FernetEncrypt.encrypt`` is returned instead.
    """
    cipher = FernetEncrypt()
    try:
        # Probe only: the decrypted result is intentionally discarded.
        cipher.decrypt(value)
    except Exception:
        return cipher.encrypt(value)
    return value

def decrypt_api_key(self, fallback_env_var: str = "OPENAI_API_KEY"):
    """Return the plain-text LLM API key for this connection.

    Args:
        fallback_env_var: Name of the environment variable to read when the
            connection has no stored key. Defaults to ``"OPENAI_API_KEY"``,
            which preserves the previous behavior; callers working with
            another provider can pass that provider's variable name.

    Returns:
        The decrypted ``llm_api_key`` when one is stored (neither ``None``
        nor empty); otherwise the value of ``fallback_env_var`` from the
        environment, or ``None`` if that variable is not set.
    """
    if self.llm_api_key:
        return FernetEncrypt().decrypt(self.llm_api_key)
    # No per-connection key: fall back to the process-wide configuration.
    return os.environ.get(fallback_env_var)
20 changes: 15 additions & 5 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import AgentAction
from langchain.tools.base import BaseTool
from overrides import override
Expand Down Expand Up @@ -55,6 +56,7 @@


TOP_K = int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50"))
EMBEDDING_MODEL = "text-embedding-ada-002"


def catch_exceptions(): # noqa: C901
Expand Down Expand Up @@ -200,14 +202,14 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
Use this tool to identify the relevant tables for the given question.
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings

def get_embedding(
self, text: str, model: str = "text-embedding-ada-002"
self,
text: str,
) -> List[float]:
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], model=model)["data"][0][
"embedding"
]
return self.embedding.embed_query(text)

def cosine_similarity(self, a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of two vectors, rounded to 4 decimals.

    Args:
        a: First vector.
        b: Second vector, of the same length as ``a``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))`` rounded to four decimal places,
        or ``0.0`` when either vector has zero magnitude (including empty
        vectors), where the ratio is undefined.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        # Dividing by zero would yield NaN (plus a RuntimeWarning), and NaN
        # compares false against everything, which corrupts any ranking
        # built on these scores.
        return 0.0
    return round(np.dot(a, b) / norm_product, 4)
Expand Down Expand Up @@ -463,6 +465,7 @@ class SQLDatabaseToolkit(BaseToolkit):
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)
instructions: List[dict] | None = Field(exclude=True, default=None)
db_scan: List[TableDescription] = Field(exclude=True)
embedding: OpenAIEmbeddings = Field(exclude=True)

@property
def dialect(self) -> str:
Expand All @@ -488,7 +491,10 @@ def get_tools(self) -> List[BaseTool]:
get_current_datetime = GetCurrentTimeTool(db=self.db, context=self.context)
tools.append(get_current_datetime)
tables_sql_db_tool = TablesSQLDatabaseTool(
db=self.db, context=self.context, db_scan=self.db_scan
db=self.db,
context=self.context,
db_scan=self.db_scan,
embedding=self.embedding,
)
tools.append(tables_sql_db_tool)
schema_sql_db_tool = SchemaSQLDatabaseTool(
Expand Down Expand Up @@ -630,6 +636,10 @@ def generate_response(
few_shot_examples=new_fewshot_examples,
instructions=instructions,
db_scan=db_scan,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
agent_executor = self.create_sql_agent(
toolkit=toolkit,
Expand Down
5 changes: 4 additions & 1 deletion dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def execute(self, query_response: Response) -> Response:
)
database = SQLDatabase.get_sql_engine(database_connection)
query_response = create_sql_query_status(
database, query_response.sql_query, query_response, top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50"))
database,
query_response.sql_query,
query_response,
top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")),
)
system_message_prompt = SystemMessagePromptTemplate.from_template(
SYSTEM_TEMPLATE
Expand Down
7 changes: 6 additions & 1 deletion dataherald/tests/vector_store/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ def query(

@override
def add_record(
self, documents: str, collection: str, metadata: Any, ids: List # noqa: ARG002
self,
documents: str,
db_connection_id: str,
collection: str,
metadata: Any,
ids: List, # noqa: ARG002
):
pass

Expand Down
7 changes: 6 additions & 1 deletion dataherald/vector_store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def create_collection(self, collection: str):

@abstractmethod
def add_record(
self, documents: str, collection: str, metadata: Any, ids: List = None
self,
documents: str,
db_connection_id: str,
collection: str,
metadata: Any,
ids: List = None,
):
pass

Expand Down
9 changes: 8 additions & 1 deletion dataherald/vector_store/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ def query(
return self.convert_to_pinecone_object_model(query_results)

@override
def add_record(self, documents: str, collection: str, metadata: Any, ids: List):
def add_record(
self,
documents: str,
db_connection_id: str, # noqa: ARG002
collection: str,
metadata: Any,
ids: List,
):
target_collection = self.chroma_client.get_or_create_collection(collection)
existing_rows = target_collection.get(ids=ids)
if len(existing_rows["documents"]) == 0:
Expand Down
35 changes: 27 additions & 8 deletions dataherald/vector_store/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from typing import Any, List

import openai
import pinecone
from langchain.embeddings import OpenAIEmbeddings
from overrides import override

from dataherald.config import System
from dataherald.db import DB
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.vector_store import VectorStore

EMBEDDING_MODEL = "text-embedding-ada-002"
Expand All @@ -31,9 +33,14 @@ def query(
num_results: int,
) -> list:
index = pinecone.Index(collection)
xq = openai.Embedding.create(input=query_texts[0], engine=EMBEDDING_MODEL)[
"data"
][0]["embedding"]
db_connection_repository = DatabaseConnectionRepository(
self.system.instance(DB)
)
database_connection = db_connection_repository.find_by_id(db_connection_id)
embedding = OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
)
xq = embedding.embed_query(query_texts[0])
query_response = index.query(
queries=[xq],
filter={
Expand All @@ -45,13 +52,25 @@ def query(
return query_response.to_dict()["results"][0]["matches"]

@override
def add_record(self, documents: str, collection: str, metadata: Any, ids: List):
def add_record(
self,
documents: str,
db_connection_id: str,
collection: str,
metadata: Any,
ids: List,
):
if collection not in pinecone.list_indexes():
self.create_collection(collection)

db_connection_repository = DatabaseConnectionRepository(
self.system.instance(DB)
)
database_connection = db_connection_repository.find_by_id(db_connection_id)
embedding = OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL
)
index = pinecone.Index(collection)
res = openai.Embedding.create(input=[documents], engine=EMBEDDING_MODEL)
embeds = [record["embedding"] for record in res["data"]]
embeds = embedding.embed_documents([documents])
record = [(ids[0], embeds, metadata[0])]
index.upsert(vectors=record)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ sphinx-book-theme==1.0.1
boto3==1.28.38
botocore==1.31.38
PyAthena==3.0.6
tiktoken==0.5.1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we using this package?