diff --git a/.env.example b/.env.example
index 0f83b133..6245ff2c 100644
--- a/.env.example
+++ b/.env.example
@@ -1,8 +1,7 @@
 # Openai info. All these fields are required for the engine to work.
 OPENAI_API_KEY = #This field is required for the engine to work.
 ORG_ID =
-LLM_MODEL = 'gpt-4' #the openAI llm model that you want to use for evaluation and generating the nl answer. possible values: gpt-4, gpt-3.5-turbo
-AGENT_LLM_MODEL = 'gpt-4-32k' # the llm model that you want to use for the agent, it should have a lrage context window. possible values: gpt-4-32k, gpt-3.5-turbo-16k
+LLM_MODEL = 'gpt-4-1106-preview' #the OpenAI LLM model that you want to use for evaluation and generating the nl answer. possible values: gpt-4-1106-preview, gpt-4, gpt-3.5-turbo
 
 DH_ENGINE_TIMEOUT = #timeout in seconds for the engine to return a response
 UPPER_LIMIT_QUERY_RETURN_ROWS = #The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQlite). Defauls to 50
diff --git a/README.md b/README.md
index 88898907..b2da2ba7 100644
--- a/README.md
+++ b/README.md
@@ -68,15 +68,12 @@ cp .env.example .env
 
 Specifically the following 5 fields must be manually set before the engine is started.
 
-LLM_MODEL is employed by evaluators and natural language generators that do not necessitate an extensive context window.
-
-AGENT_LLM_MODEL, on the other hand, is utilized by the NL-to-SQL generator, which relies on a larger context window.
+LLM_MODEL is the LLM that you want to use.
 
 ```
 #OpenAI credentials and model
 OPENAI_API_KEY =
 LLM_MODEL =
-AGENT_LLM_MODEL =
 ORG_ID =
 
 #Encryption key for storing DB connection data in Mongo
diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py
index 4f18fc2e..d85f3bb9 100644
--- a/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/dataherald/sql_generator/dataherald_sqlagent.py
@@ -10,7 +10,6 @@
 import openai
 import pandas as pd
 import sqlalchemy
-import tiktoken
 from bson.objectid import ObjectId
 from google.api_core.exceptions import GoogleAPIError
 from langchain.agents.agent import AgentExecutor
@@ -41,7 +40,6 @@
     DatabaseConnection,
 )
 from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
-from dataherald.sql_generator.adaptive_agent_executor import AdaptiveAgentExecutor
 from dataherald.types import Question, Response
 from dataherald.utils.agent_prompts import (
     AGENT_PREFIX,
@@ -53,7 +51,6 @@
     SUFFIX_WITH_FEW_SHOT_SAMPLES,
     SUFFIX_WITHOUT_FEW_SHOT_SAMPLES,
 )
-from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES
 
 logger = logging.getLogger(__name__)
 
@@ -155,6 +152,8 @@ def _run(
         run_manager: CallbackManagerForToolRun | None = None,  # noqa: ARG002
     ) -> str:
         """Execute the query, return the results or an error message."""
+        if '```sql' in query:
+            query = query.replace('```sql', '').replace('```', '')
         return self.db.run_sql(query, top_k=top_k)[0]
 
     async def _arun(
@@ -581,24 +580,15 @@ def create_sql_agent(
             input_variables=input_variables,
         )
         llm_chain = LLMChain(
-            llm=self.short_context_llm,
+            llm=self.llm,
             prompt=prompt,
             callback_manager=callback_manager,
         )
         tool_names = [tool.name for tool in tools]
         agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
-        return AdaptiveAgentExecutor.from_agent_and_tools(
+        return AgentExecutor.from_agent_and_tools(
             agent=agent,
             tools=tools,
-            llm_list={
-                "short_context_llm": self.short_context_llm,
-                "long_context_llm": self.long_context_llm,
-            },
-            switch_to_larger_model_threshold=OPENAI_CONTEXT_WIDNOW_SIZES[
-                self.short_context_llm.model_name
-            ]
-            - 500,
-            encoding=tiktoken.encoding_for_model(self.short_context_llm.model_name),
             callback_manager=callback_manager,
             verbose=verbose,
             max_iterations=max_iterations,
@@ -617,15 +607,10 @@ def generate_response(
         start_time = time.time()
         context_store = self.system.instance(ContextStore)
         storage = self.system.instance(DB)
-        self.short_context_llm = self.model.get_model(
+        self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("LLM_MODEL", "gpt-4"),
-        )
-        self.long_context_llm = self.model.get_model(
-            database_connection=database_connection,
-            temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         repository = TableDescriptionRepository(storage)
         db_scan = repository.get_all_tables_by_db(
diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py
index 99201478..3aaebc01 100644
--- a/dataherald/sql_generator/langchain_sqlagent.py
+++ b/dataherald/sql_generator/langchain_sqlagent.py
@@ -34,7 +34,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         self.database = SQLDatabase.get_sql_engine(database_connection)
         tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py
index 48e0bdad..c615d009 100644
--- a/dataherald/sql_generator/langchain_sqlchain.py
+++ b/dataherald/sql_generator/langchain_sqlchain.py
@@ -52,7 +52,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         self.database = SQLDatabase.get_sql_engine(database_connection)
         logger.info(
diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py
index 59de3c90..a673e32a 100644
--- a/dataherald/sql_generator/llamaindex.py
+++ b/dataherald/sql_generator/llamaindex.py
@@ -41,7 +41,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
-            model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
+            model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
         )
         token_counter = TokenCountingHandler(
             tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
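Two short sketches on the changes above. First, with `AGENT_LLM_MODEL` removed, every generator resolves its model through the single `LLM_MODEL` variable. A minimal sketch of the resulting lookup; the `python-dotenv` call is an assumption about how `.env` gets loaded, not something shown in this diff:

```python
import os

from dotenv import load_dotenv  # assumption: .env is loaded via python-dotenv

# Read .env into the process environment, then resolve the single model
# variable that all SQL generators now share, with the new default.
load_dotenv()
model_name = os.getenv("LLM_MODEL", "gpt-4-1106-preview")
```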
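Second, the new guard in `SQLDatabaseTool._run` exists because chat models such as `gpt-4-1106-preview` often wrap generated SQL in markdown fences, which the SQL engine would reject as invalid syntax. A standalone sketch of the same stripping logic; the helper name is hypothetical, since the diff implements it inline in `_run`:

```python
def strip_markdown_fences(query: str) -> str:
    """Drop the ```sql ... ``` fences a chat model may wrap around a query."""
    if '```sql' in query:
        query = query.replace('```sql', '').replace('```', '')
    return query

# Both the fenced and the unfenced form reduce to plain SQL.
assert strip_markdown_fences('```sql\nSELECT 1;\n```').strip() == 'SELECT 1;'
assert strip_markdown_fences('SELECT 1;') == 'SELECT 1;'
```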