Skip to content

Commit

Permalink
DH-5025/minor changes to the new agent (#262)
Browse files Browse the repository at this point in the history
* DH-5025/minor changes to the new agent

* Dh-5025/reformat with black
  • Loading branch information
MohammadrezaPourreza authored Nov 22, 2023
1 parent d783e76 commit 2e451c5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
30 changes: 30 additions & 0 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,35 @@ async def _arun(
raise NotImplementedError("SystemTime tool does not support async")


class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool which takes in the given question and returns a list of tables with their relevance score to the question"""

name = "get_db_table_names"
description = """
Use this tool to get the list of tables in the database.
"""
db_scan: List[TableDescription]

@catch_exceptions()
def _run(
self,
input: str, # noqa: ARG002
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
"""Use the concatenation of table name, columns names, and the description of the table as the table representation"""
tables = []
for table in self.db_scan:
tables.append(table.table_name)
return f"Tables in the database: {','.join(tables)}"

async def _arun(
self,
input: str = "",
run_manager: AsyncCallbackManagerForToolRun | None = None,
) -> str:
raise NotImplementedError("TablesSQLDatabaseTool does not support async")


class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""

Expand Down Expand Up @@ -347,6 +376,7 @@ def get_tools(self) -> List[BaseTool]:
)
)
tools.append(SchemaSQLDatabaseTool(db=self.db, db_scan=self.db_scan))
tools.append(TablesSQLDatabaseTool(db=self.db, db_scan=self.db_scan))
return tools


Expand Down
8 changes: 4 additions & 4 deletions dataherald/utils/agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@
If the question is complex:
1) Break the question into sub-questions.
2) Find the SQL query for each sub-question by using the generate_sql tool for each sub-question.
3) Combine the SQL queries for each sub-question into a single SQL query.
3) Combine the SQL queries for each sub-question into a single SQL query by using set operations, sub_uqeires, or nested queries.
Using `current_date()` or `current_datetime()` in SQL queries is banned, use system_time tool to get the exact time of the query execution.
If running the SQL query results in an error, rewrite the SQL query and try again. You can use db_schema tool to get the schema of the database.
only rewrite the query when the execution returned an error or it does not follow the instructions provided by the database administrator.
Only rely on generate_sql tool to generate the SQL query.
If running the SQL query results in an error, rewrite the SQL query and try again. You can use db_schema tool to get the schema of the database and get_db_table_names tool to get the names of the tables in the database.
You can only make minor changes to the SQL query generated by the generate_sql tool, when it does not follow the instructions provided by the database administrator or when it results in an error.
If the question does not seem related to the database, explain why you cannot answer the question.
For query editing, you do not need to use generate_sql tool, you can edit the SQL query directly.
Here are the database admin instructions, that all queries must follow:
{admin_instructions}
Expand Down

0 comments on commit 2e451c5

Please sign in to comment.