diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index d057805c..596c3010 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -373,7 +373,7 @@ def create_sql_agent( **kwargs: Dict[str, Any], ) -> AgentExecutor: tools = toolkit.get_tools() - prefix = prefix.format(dialect=toolkit.dialect) + prefix = prefix.format(dialect=toolkit.dialect, admin_instructions=toolkit.instructions) prompt = ZeroShotAgent.create_prompt( tools, prefix=prefix, diff --git a/dataherald/utils/agent_prompts.py b/dataherald/utils/agent_prompts.py index fe1f999c..6ea9efdc 100644 --- a/dataherald/utils/agent_prompts.py +++ b/dataherald/utils/agent_prompts.py @@ -114,6 +114,10 @@ Using `current_date()` or `current_datetime()` in SQL queries is banned, use system_time tool to get the exact time of the query execution. If running the SQL query results in an error, rewrite the SQL query and try again. You can use db_schema tool to get the schema of the database. +only rewrite the query when the execution returned an error. Only rely on generate_sql tool to generate the SQL query. If the question does not seem related to the database, explain why you cannot answer the question. + +Here are the database admin instructions, that all queries must follow: +{admin_instructions} """ # noqa: E501