Skip to content

Commit

Permalink
DH-4952/initial commits
Browse files (browse the repository at this point in the history)
  • Loading branch information
MohammadrezaPourreza committed Nov 6, 2023
1 parent 6abc80a commit 5e5fa52
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def answer_question(
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
generated_answer = sql_generation.generate_response(
user_question, database_connection, context[0]
user_question, database_connection, context
)
logger.info("Starts evaluator...")
if run_evaluator:
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def generate_response(
self,
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
context: Tuple[List[dict] | None, List[dict] | None],
) -> Response:
"""Generates a response to a user question."""
pass
8 changes: 4 additions & 4 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import time
from functools import wraps
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
import openai
Expand Down Expand Up @@ -100,7 +100,7 @@ class BaseSQLDatabaseTool(BaseModel):
"""Base tool for interacting with the SQL database and the context information."""

db: SQLDatabase = Field(exclude=True)
context: List[dict] | None = Field(exclude=True, default=None)
context: Tuple[List[dict] | None, List[dict] | None] = Field(exclude=True, default=None)

class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -464,7 +464,7 @@ class SQLDatabaseToolkit(BaseToolkit):
"""Dataherald toolkit"""

db: SQLDatabase = Field(exclude=True)
context: List[dict] | None = Field(exclude=True, default=None)
context: Tuple[List[dict] | None, List[dict] | None] = Field(exclude=True, default=None)
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)
instructions: List[dict] | None = Field(exclude=True, default=None)
db_scan: List[TableDescription] = Field(exclude=True)
Expand Down Expand Up @@ -612,7 +612,7 @@ def generate_response(
self,
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
context: Tuple[List[dict] | None, List[dict] | None],
) -> Response:
start_time = time.time()
context_store = self.system.instance(ContextStore)
Expand Down
10 changes: 5 additions & 5 deletions dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import time
from typing import Any, List
from typing import Any, List, Tuple

from langchain.agents import initialize_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
Expand All @@ -28,7 +28,7 @@ def generate_response(
self,
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
context: Tuple[List[dict] | None, List[dict] | None],
) -> Response: # type: ignore
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
self.llm = self.model.get_model(
Expand All @@ -48,18 +48,18 @@ def generate_response(
handle_parsing_errors=True,
return_intermediate_steps=True,
)
if context is not None:
if context[0] is not None:
samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \
\n"
for sample in context:
for sample in context[0]:
samples_prompt_string += (
f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n"
)

question_with_context = (
f"{user_question.question} An example of a similar question and the query that was generated \
to answer it is the following {samples_prompt_string}"
if context is not None
if context[0] is not None
else user_question.question
)
with get_openai_callback() as cb:
Expand Down
10 changes: 5 additions & 5 deletions dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import time
from typing import Any, List
from typing import Any, List, Tuple

from langchain import SQLDatabaseChain
from langchain.callbacks import get_openai_callback
Expand Down Expand Up @@ -46,7 +46,7 @@ def generate_response(
self,
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
context: Tuple[List[dict] | None, List[dict] | None],
) -> Response:
start_time = time.time()
self.llm = self.model.get_model(
Expand All @@ -56,12 +56,12 @@ def generate_response(
)
self.database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
f"Generating SQL response to question: {str(user_question.dict())} with passed context {context}"
f"Generating SQL response to question: {str(user_question.dict())} with passed context {context[0]}"
)
if context is not None:
if context[0] is not None:
samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \
\n"
for sample in context:
for sample in context[0]:
samples_prompt_string += (
f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n"
)
Expand Down
10 changes: 5 additions & 5 deletions dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import time
from typing import Any, List
from typing import Any, List, Tuple

import tiktoken
from langchain.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS
Expand Down Expand Up @@ -34,7 +34,7 @@ def generate_response(
self,
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
context: Tuple[List[dict] | None, List[dict] | None],
) -> Response:
start_time = time.time()
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
Expand All @@ -55,17 +55,17 @@ def generate_response(
metadata_obj.reflect(db_engine)
table_schema_objs = []
table_node_mapping = SQLTableNodeMapping(self.database)
if context is not None:
if context[0] is not None:
samples_prompt_string = "The following are some similar previous questions and their correct SQL queries from these databases: \
\n"
for sample in context:
for sample in context[0]:
samples_prompt_string += (
f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n"
)
question_with_context = (
f"{user_question.question} An example of a similar question and the query that was generated to answer it \
is the following {samples_prompt_string}"
if context is not None
if context[0] is not None
else user_question.question
)
for table_name in metadata_obj.tables.keys():
Expand Down

0 comments on commit 5e5fa52

Please sign in to comment.