From 5d9a5215b12c85a194514d8cff049e2bbdc2ff6f Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 14 Nov 2023 12:17:28 -0500 Subject: [PATCH] DH-5001/removing the markdowns from queries for gpt-4-turbo --- dataherald/sql_generator/dataherald_sqlagent.py | 7 ++++++- dataherald/utils/models_context_window.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 1de633c4..cf03707e 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -155,6 +155,8 @@ def _run( run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Execute the query, return the results or an error message.""" + if '```sql' in query: + query = query.replace('```sql', '').replace('```', '') return self.db.run_sql(query, top_k=top_k)[0] async def _arun( @@ -690,7 +692,10 @@ def generate_response( for step in result["intermediate_steps"]: action = step[0] if type(action) == AgentAction and action.tool == "sql_db_query": - sql_query_list.append(self.format_sql_query(action.tool_input)) + query = self.format_sql_query(action.tool_input) + if '```sql' in query: + query = query.replace('```sql', '').replace('```', '') + sql_query_list.append(query) intermediate_steps = self.format_intermediate_representations( result["intermediate_steps"] ) diff --git a/dataherald/utils/models_context_window.py b/dataherald/utils/models_context_window.py index 97687971..9304879e 100644 --- a/dataherald/utils/models_context_window.py +++ b/dataherald/utils/models_context_window.py @@ -11,4 +11,6 @@ "gpt-3.5-turbo-0613": 4000, "gpt-3.5-turbo-16k-0613": 16000, "gpt-3.5-turbo-0301": 4000, + "gpt-4-1106-preview": 128000, + "gpt-3.5-turbo-1106": 16000, }