From 6d40e930bc15dd9ac584caaba657f36b2ef7894e Mon Sep 17 00:00:00 2001 From: Dishen <44216194+DishenWang2023@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:30:25 -0500 Subject: [PATCH] DH-4883 save send run refactor (#237) * DH-4883 Add and update responses endpoint to only process the sql_query * Add run_evaluator flag for endpoints that create a response --------- Co-authored-by: Juan Carlos Jose Camacho --- dataherald/api/__init__.py | 17 +++++- dataherald/api/fastapi.py | 60 ++++++++++++++----- dataherald/server/fastapi/__init__.py | 32 ++++++++-- .../sql_generator/generates_nl_answer.py | 50 +++++++++------- 4 files changed, 115 insertions(+), 44 deletions(-) diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index 21ef7df7..628225de 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -36,12 +36,14 @@ def scan_db( pass @abstractmethod - def answer_question(self, question_request: QuestionRequest) -> Response: + def answer_question( + self, run_evaluator: bool = True, question_request: QuestionRequest = None + ) -> Response: pass @abstractmethod def answer_question_with_timeout( - self, question_request: QuestionRequest + self, run_evaluator: bool = True, question_request: QuestionRequest = None ) -> Response: pass @@ -100,7 +102,12 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]: pass @abstractmethod - def create_response(self, query_request: CreateResponseRequest) -> Response: + def create_response( + self, + run_evaluator: bool = True, + sql_response_only: bool = False, + query_request: CreateResponseRequest = None, + ) -> Response: pass @abstractmethod @@ -111,6 +118,10 @@ def get_responses(self, question_id: str | None = None) -> list[Response]: def get_response(self, response_id: str) -> Response: pass + @abstractmethod + def update_response(self, response_id: str) -> Response: + pass + @abstractmethod def delete_golden_record(self, golden_record_id: str) -> dict: pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index d4ed1c50..f98cda7f 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -118,11 +118,12 @@ def scan_db( return True @override - def answer_question(self, question_request: QuestionRequest) -> Response: + def answer_question( + self, run_evaluator: bool = True, question_request: QuestionRequest = None + ) -> Response: """Takes in an English question and answers it based on content from the registered databases""" logger.info(f"Answer question: {question_request.question}") sql_generation = self.system.instance(SQLGenerator) - evaluator = self.system.instance(Evaluator) context_store = self.system.instance(ContextStore) user_question = Question( @@ -152,22 +153,24 @@ def answer_question(self, question_request: QuestionRequest) -> Response: user_question, database_connection, context[0] ) logger.info("Starts evaluator...") - confidence_score = evaluator.get_confidence_score( - user_question, generated_answer, database_connection - ) + if run_evaluator: + evaluator = self.system.instance(Evaluator) + confidence_score = evaluator.get_confidence_score( + user_question, generated_answer, database_connection + ) + generated_answer.confidence_score = confidence_score except Exception as e: return JSONResponse( status_code=400, content={"question_id": user_question.id, "error_message": str(e)}, ) - generated_answer.confidence_score = confidence_score generated_answer.exec_time = time.time() - start_generated_answer response_repository = ResponseRepository(self.storage) return response_repository.insert(generated_answer) @override def answer_question_with_timeout( - self, question_request: QuestionRequest + self, run_evaluator: bool = True, question_request: QuestionRequest = None ) -> Response: result = None exception = None @@ -182,7 +185,7 @@ def answer_question_with_timeout( def run_and_catch_exceptions(): nonlocal result, exception if not stop_event.is_set(): - result = self.answer_question(question_request) + result = self.answer_question(run_evaluator, question_request) thread = threading.Thread(target=run_and_catch_exceptions) thread.start() @@ -348,6 +351,29 @@ def get_response(self, response_id: str) -> Response: return result + @override + def update_response(self, response_id: str) -> Response: + response_repository = ResponseRepository(self.storage) + + try: + response = response_repository.find_by_id(response_id) + except InvalidId as e: + raise HTTPException(status_code=400, detail=str(e)) from e + if not response: + raise HTTPException(status_code=404, detail="Question not found") + + start_generated_answer = time.time() + try: + generates_nl_answer = GeneratesNlAnswer(self.system, self.storage) + response = generates_nl_answer.execute(response) + response.exec_time = time.time() - start_generated_answer + response_repository.update(response) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except SQLInjectionError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + return response + @override def get_questions(self, db_connection_id: str | None = None) -> list[Question]: question_repository = QuestionRepository(self.storage) @@ -397,9 +423,11 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]: @override def create_response( - self, query_request: CreateResponseRequest # noqa: ARG002 + self, + run_evaluator: bool = True, + sql_response_only: bool = False, + query_request: CreateResponseRequest = None, # noqa: ARG002 ) -> Response: - evaluator = self.system.instance(Evaluator) question_repository = QuestionRepository(self.storage) user_question = question_repository.find_by_id(query_request.question_id) db_connection_repository = DatabaseConnectionRepository(self.storage) @@ -417,11 +445,13 @@ def create_response( start_generated_answer = time.time() try: generates_nl_answer = GeneratesNlAnswer(self.system, self.storage) - response = generates_nl_answer.execute(response) - confidence_score = evaluator.get_confidence_score( - user_question, response, database_connection - ) - response.confidence_score = confidence_score + response = generates_nl_answer.execute(response, sql_response_only) + if run_evaluator: + evaluator = self.system.instance(Evaluator) + confidence_score = evaluator.get_confidence_score( + user_question, response, database_connection + ) + response.confidence_score = confidence_score response.exec_time = time.time() - start_generated_answer response_repository.update(response) except ValueError as e: diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 89da742e..97f31702 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -164,6 +164,13 @@ def __init__(self, settings: Settings): tags=["Responses"], ) + self.router.add_api_route( + "/api/v1/responses/{response_id}", + self.update_response, + methods=["PATCH"], + tags=["Responses"], + ) + self.router.add_api_route( "/api/v1/sql-query-executions", self.execute_sql_query, @@ -216,10 +223,14 @@ def scan_db( ) -> bool: return self._api.scan_db(scanner_request, background_tasks) - def answer_question(self, question_request: QuestionRequest) -> Response: + def answer_question( + self, run_evaluator: bool = True, question_request: QuestionRequest = None + ) -> Response: if os.getenv("DH_ENGINE_TIMEOUT", None): - return self._api.answer_question_with_timeout(question_request) - return self._api.answer_question(question_request) + return self._api.answer_question_with_timeout( + run_evaluator, question_request + ) + return self._api.answer_question(run_evaluator, question_request) def get_questions(self, db_connection_id: str | None = None) -> list[Question]: return self._api.get_questions(db_connection_id) @@ -282,13 +293,24 @@ def get_response(self, response_id: str) -> Response: """Get a response""" return self._api.get_response(response_id) + def update_response(self, response_id: str) -> Response: + """Update a response""" + return self._api.update_response(response_id) + def execute_sql_query(self, query: Query) -> tuple[str, dict]: """Executes a query on the given db_connection_id""" return self._api.execute_sql_query(query) - def create_response(self, query_request: CreateResponseRequest) -> Response: + def create_response( + self, + run_evaluator: bool = True, + sql_response_only: bool = False, + query_request: CreateResponseRequest = None, + ) -> Response: """Executes a query on the given db_connection_id""" - return self._api.create_response(query_request) + return self._api.create_response( + run_evaluator, sql_response_only, query_request + ) def delete_golden_record(self, golden_record_id: str) -> dict: """Deletes a golden record""" diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index d412a4d0..4fa38e9d 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -31,7 +31,9 @@ def __init__(self, system, storage): self.storage = storage self.model = ChatModel(self.system) - def execute(self, query_response: Response) -> Response: + def execute( + self, query_response: Response, sql_response_only: bool = False + ) -> Response: question_repository = QuestionRepository(self.storage) question = question_repository.find_by_id(query_response.question_id) @@ -45,24 +47,30 @@ def execute(self, query_response: Response) -> Response: model_name=os.getenv("LLM_MODEL", "gpt-4"), ) database = SQLDatabase.get_sql_engine(database_connection) - query_response = create_sql_query_status( - database, - query_response.sql_query, - query_response, - top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")), - ) - system_message_prompt = SystemMessagePromptTemplate.from_template( - SYSTEM_TEMPLATE - ) - human_message_prompt = HumanMessagePromptTemplate.from_template(HUMAN_TEMPLATE) - chat_prompt = ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - chain = LLMChain(llm=self.llm, prompt=chat_prompt) - nl_resp = chain.run( - question=question.question, - sql_query=query_response.sql_query, - sql_query_result=str(query_response.sql_query_result), - ) - query_response.response = nl_resp + + if not query_response.sql_query_result: + query_response = create_sql_query_status( + database, + query_response.sql_query, + query_response, + top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")), + ) + + if not sql_response_only: + system_message_prompt = SystemMessagePromptTemplate.from_template( + SYSTEM_TEMPLATE + ) + human_message_prompt = HumanMessagePromptTemplate.from_template( + HUMAN_TEMPLATE + ) + chat_prompt = ChatPromptTemplate.from_messages( + [system_message_prompt, human_message_prompt] + ) + chain = LLMChain(llm=self.llm, prompt=chat_prompt) + nl_resp = chain.run( + question=question.question, + sql_query=query_response.sql_query, + sql_query_result=str(query_response.sql_query_result), + ) + query_response.response = nl_resp return query_response