diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index d057805c..55090f6a 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -373,7 +373,12 @@ def create_sql_agent( **kwargs: Dict[str, Any], ) -> AgentExecutor: tools = toolkit.get_tools() - prefix = prefix.format(dialect=toolkit.dialect) + admin_instructions = "" + for index, instruction in enumerate(toolkit.instructions): + admin_instructions += f"{index+1}) {instruction['instruction']}\n" + prefix = prefix.format( + dialect=toolkit.dialect, admin_instructions=admin_instructions + ) prompt = ZeroShotAgent.create_prompt( tools, prefix=prefix, diff --git a/dataherald/utils/agent_prompts.py b/dataherald/utils/agent_prompts.py index fe1f999c..b3be68a8 100644 --- a/dataherald/utils/agent_prompts.py +++ b/dataherald/utils/agent_prompts.py @@ -114,6 +114,10 @@ Using `current_date()` or `current_datetime()` in SQL queries is banned, use system_time tool to get the exact time of the query execution. If running the SQL query results in an error, rewrite the SQL query and try again. You can use db_schema tool to get the schema of the database. +only rewrite the query when the execution returned an error or it does not follow the instructions provided by the database administrator. Only rely on generate_sql tool to generate the SQL query. If the question does not seem related to the database, explain why you cannot answer the question. + +Here are the database admin instructions, that all queries must follow: +{admin_instructions} """ # noqa: E501