diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index c520c340..df0a96d7 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -131,23 +131,24 @@ def answer_question( run_evaluator: bool = True, generate_csv: bool = False, question_request: QuestionRequest = None, + user_question: Question | None = None, ) -> Response: """Takes in an English question and answers it based on content from the registered databases""" - logger.info(f"Answer question: {question_request.question}") sql_generation = self.system.instance(SQLGenerator) context_store = self.system.instance(ContextStore) - user_question = Question( - question=question_request.question, - db_connection_id=question_request.db_connection_id, - ) - - question_repository = QuestionRepository(self.storage) - user_question = question_repository.insert(user_question) + if not user_question: + user_question = Question( + question=question_request.question, + db_connection_id=question_request.db_connection_id, + ) + question_repository = QuestionRepository(self.storage) + user_question = question_repository.insert(user_question) + logger.info(f"Answer question: {user_question.question}") db_connection_repository = DatabaseConnectionRepository(self.storage) database_connection = db_connection_repository.find_by_id( - question_request.db_connection_id + user_question.db_connection_id ) if not database_connection: return JSONResponse( @@ -209,7 +210,7 @@ def run_and_catch_exceptions(): nonlocal result, exception if not stop_event.is_set(): result = self.answer_question( - run_evaluator, generate_csv, question_request + run_evaluator, generate_csv, None, user_question ) thread = threading.Thread(target=run_and_catch_exceptions)