From 5e5fa52a664b90629ad4eb1b0b368114cde4a67e Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Mon, 6 Nov 2023 17:31:53 -0500 Subject: [PATCH] DH-4952/initial commits --- dataherald/api/fastapi.py | 2 +- dataherald/sql_generator/__init__.py | 2 +- dataherald/sql_generator/dataherald_sqlagent.py | 8 ++++---- dataherald/sql_generator/langchain_sqlagent.py | 10 +++++----- dataherald/sql_generator/langchain_sqlchain.py | 10 +++++----- dataherald/sql_generator/llamaindex.py | 10 +++++----- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 0e6c870f..c51b4048 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -151,7 +151,7 @@ def answer_question( context = context_store.retrieve_context_for_question(user_question) start_generated_answer = time.time() generated_answer = sql_generation.generate_response( - user_question, database_connection, context[0] + user_question, database_connection, context ) logger.info("Starts evaluator...") if run_evaluator: diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 9d5673c4..805cb42b 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -75,7 +75,7 @@ def generate_response( self, user_question: Question, database_connection: DatabaseConnection, - context: List[dict] = None, + context: Tuple[List[dict] | None, List[dict] | None], ) -> Response: """Generates a response to a user question.""" pass diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 4f18fc2e..55f36674 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -4,7 +4,7 @@ import os import time from functools import wraps -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Tuple import numpy as np import openai @@ -100,7 +100,7 @@ class BaseSQLDatabaseTool(BaseModel): """Base tool for interacting with the SQL database and the context information.""" db: SQLDatabase = Field(exclude=True) - context: List[dict] | None = Field(exclude=True, default=None) + context: Tuple[List[dict] | None, List[dict] | None] = Field(exclude=True, default=None) class Config(BaseTool.Config): """Configuration for this pydantic object.""" @@ -464,7 +464,7 @@ class SQLDatabaseToolkit(BaseToolkit): """Dataherald toolkit""" db: SQLDatabase = Field(exclude=True) - context: List[dict] | None = Field(exclude=True, default=None) + context: Tuple[List[dict] | None, List[dict] | None] = Field(exclude=True, default=None) few_shot_examples: List[dict] | None = Field(exclude=True, default=None) instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableDescription] = Field(exclude=True) @@ -612,7 +612,7 @@ def generate_response( self, user_question: Question, database_connection: DatabaseConnection, - context: List[dict] = None, + context: Tuple[List[dict] | None, List[dict] | None], ) -> Response: start_time = time.time() context_store = self.system.instance(ContextStore) diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py index 99201478..0f86e97a 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any, List +from typing import Any, List, Tuple from langchain.agents import initialize_agent from langchain.agents.agent_toolkits import SQLDatabaseToolkit @@ -28,7 +28,7 @@ def generate_response( self, user_question: Question, database_connection: DatabaseConnection, - context: List[dict] = None, + context: Tuple[List[dict] | None, List[dict] | None], ) -> Response: # type: ignore logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( @@ -48,10 +48,10 @@ def generate_response( handle_parsing_errors=True, return_intermediate_steps=True, ) - if context is not None: + if context[0] is not None: samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \ \n" - for sample in context: + for sample in context[0]: samples_prompt_string += ( f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n" ) @@ -59,7 +59,7 @@ def generate_response( question_with_context = ( f"{user_question.question} An example of a similar question and the query that was generated \ to answer it is the following {samples_prompt_string}" - if context is not None + if context[0] is not None else user_question.question ) with get_openai_callback() as cb: diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index 48e0bdad..0a3f7d5d 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any, List +from typing import Any, List, Tuple from langchain import SQLDatabaseChain from langchain.callbacks import get_openai_callback @@ -46,7 +46,7 @@ def generate_response( self, user_question: Question, database_connection: DatabaseConnection, - context: List[dict] = None, + context: Tuple[List[dict] | None, List[dict] | None], ) -> Response: start_time = time.time() self.llm = self.model.get_model( @@ -56,12 +56,12 @@ def generate_response( ) self.database = SQLDatabase.get_sql_engine(database_connection) logger.info( - f"Generating SQL response to question: {str(user_question.dict())} with passed context {context}" + f"Generating SQL response to question: {str(user_question.dict())} with passed context {context[0]}" ) - if context is not None: + if context[0] is not None: samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \ \n" - for sample in context: + for sample in context[0]: samples_prompt_string += ( f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n" ) diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py index 59de3c90..ec87be71 100644 --- a/dataherald/sql_generator/llamaindex.py +++ b/dataherald/sql_generator/llamaindex.py @@ -3,7 +3,7 @@ import logging import os import time -from typing import Any, List +from typing import Any, List, Tuple import tiktoken from langchain.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS @@ -34,7 +34,7 @@ def generate_response( self, user_question: Question, database_connection: DatabaseConnection, - context: List[dict] = None, + context: Tuple[List[dict] | None, List[dict] | None], ) -> Response: start_time = time.time() logger.info(f"Generating SQL response to question: {str(user_question.dict())}") @@ -55,17 +55,17 @@ def generate_response( metadata_obj.reflect(db_engine) table_schema_objs = [] table_node_mapping = SQLTableNodeMapping(self.database) - if context is not None: + if context[0] is not None: samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \ \n" - for sample in context: + for sample in context[0]: samples_prompt_string += ( f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n" ) question_with_context = ( f"{user_question.question} An example of a similar question and the query that was generated to answer it \ is the following {samples_prompt_string}" - if context is not None + if context[0] is not None else user_question.question ) for table_name in metadata_obj.tables.keys():