From b81e7e002c9ae5d8bdb23456c1862c53bbf7c419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Carlos=20Jos=C3=A9=20Camacho?= Date: Fri, 17 Nov 2023 09:32:19 -0600 Subject: [PATCH] [DH-5014] Fix duplicated questions (#255) --- dataherald/api/fastapi.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index c520c340..df0a96d7 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -131,23 +131,24 @@ def answer_question( run_evaluator: bool = True, generate_csv: bool = False, question_request: QuestionRequest = None, + user_question: Question | None = None, ) -> Response: """Takes in an English question and answers it based on content from the registered databases""" - logger.info(f"Answer question: {question_request.question}") sql_generation = self.system.instance(SQLGenerator) context_store = self.system.instance(ContextStore) - user_question = Question( - question=question_request.question, - db_connection_id=question_request.db_connection_id, - ) - - question_repository = QuestionRepository(self.storage) - user_question = question_repository.insert(user_question) + if not user_question: + user_question = Question( + question=question_request.question, + db_connection_id=question_request.db_connection_id, + ) + question_repository = QuestionRepository(self.storage) + user_question = question_repository.insert(user_question) + logger.info(f"Answer question: {user_question.question}") db_connection_repository = DatabaseConnectionRepository(self.storage) database_connection = db_connection_repository.find_by_id( - question_request.db_connection_id + user_question.db_connection_id ) if not database_connection: return JSONResponse( @@ -209,7 +210,7 @@ def run_and_catch_exceptions(): nonlocal result, exception if not stop_event.is_set(): result = self.answer_question( - run_evaluator, generate_csv, question_request + run_evaluator, generate_csv, None, user_question ) thread = threading.Thread(target=run_and_catch_exceptions)