Skip to content

Commit

Permalink
DH-4883 save send run refactor (#237)
Browse files Browse the repository at this point in the history
* DH-4883 Add an update-response endpoint, and update the create-response endpoint so it can process only the sql_query

* Add run_evaluator flag for endpoints that create a response

---------

Co-authored-by: Juan Carlos Jose Camacho <[email protected]>
  • Loading branch information
DishenWang2023 and jcjc712 authored Nov 3, 2023
1 parent 844491d commit 6d40e93
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 44 deletions.
17 changes: 14 additions & 3 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ def scan_db(
pass

@abstractmethod
def answer_question(self, question_request: QuestionRequest) -> Response:
def answer_question(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
) -> Response:
pass

@abstractmethod
def answer_question_with_timeout(
self, question_request: QuestionRequest
self, run_evaluator: bool = True, question_request: QuestionRequest = None
) -> Response:
pass

Expand Down Expand Up @@ -100,7 +102,12 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
pass

@abstractmethod
def create_response(self, query_request: CreateResponseRequest) -> Response:
def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
pass

@abstractmethod
Expand All @@ -111,6 +118,10 @@ def get_responses(self, question_id: str | None = None) -> list[Response]:
def get_response(self, response_id: str) -> Response:
pass

@abstractmethod
def update_response(self, response_id: str) -> Response:
pass

@abstractmethod
def delete_golden_record(self, golden_record_id: str) -> dict:
pass
Expand Down
60 changes: 45 additions & 15 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ def scan_db(
return True

@override
def answer_question(self, question_request: QuestionRequest) -> Response:
def answer_question(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
) -> Response:
"""Takes in an English question and answers it based on content from the registered databases"""
logger.info(f"Answer question: {question_request.question}")
sql_generation = self.system.instance(SQLGenerator)
evaluator = self.system.instance(Evaluator)
context_store = self.system.instance(ContextStore)

user_question = Question(
Expand Down Expand Up @@ -152,22 +153,24 @@ def answer_question(self, question_request: QuestionRequest) -> Response:
user_question, database_connection, context[0]
)
logger.info("Starts evaluator...")
confidence_score = evaluator.get_confidence_score(
user_question, generated_answer, database_connection
)
if run_evaluator:
evaluator = self.system.instance(Evaluator)
confidence_score = evaluator.get_confidence_score(
user_question, generated_answer, database_connection
)
generated_answer.confidence_score = confidence_score
except Exception as e:
return JSONResponse(
status_code=400,
content={"question_id": user_question.id, "error_message": str(e)},
)
generated_answer.confidence_score = confidence_score
generated_answer.exec_time = time.time() - start_generated_answer
response_repository = ResponseRepository(self.storage)
return response_repository.insert(generated_answer)

@override
def answer_question_with_timeout(
self, question_request: QuestionRequest
self, run_evaluator: bool = True, question_request: QuestionRequest = None
) -> Response:
result = None
exception = None
Expand All @@ -182,7 +185,7 @@ def answer_question_with_timeout(
def run_and_catch_exceptions():
nonlocal result, exception
if not stop_event.is_set():
result = self.answer_question(question_request)
result = self.answer_question(run_evaluator, question_request)

thread = threading.Thread(target=run_and_catch_exceptions)
thread.start()
Expand Down Expand Up @@ -348,6 +351,29 @@ def get_response(self, response_id: str) -> Response:

return result

@override
def update_response(self, response_id: str) -> Response:
response_repository = ResponseRepository(self.storage)

try:
response = response_repository.find_by_id(response_id)
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e
if not response:
raise HTTPException(status_code=404, detail="Question not found")

start_generated_answer = time.time()
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(response)
response.exec_time = time.time() - start_generated_answer
response_repository.update(response)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLInjectionError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
return response

@override
def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
question_repository = QuestionRepository(self.storage)
Expand Down Expand Up @@ -397,9 +423,11 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:

@override
def create_response(
self, query_request: CreateResponseRequest # noqa: ARG002
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
query_request: CreateResponseRequest = None, # noqa: ARG002
) -> Response:
evaluator = self.system.instance(Evaluator)
question_repository = QuestionRepository(self.storage)
user_question = question_repository.find_by_id(query_request.question_id)
db_connection_repository = DatabaseConnectionRepository(self.storage)
Expand All @@ -417,11 +445,13 @@ def create_response(
start_generated_answer = time.time()
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(response)
confidence_score = evaluator.get_confidence_score(
user_question, response, database_connection
)
response.confidence_score = confidence_score
response = generates_nl_answer.execute(response, sql_response_only)
if run_evaluator:
evaluator = self.system.instance(Evaluator)
confidence_score = evaluator.get_confidence_score(
user_question, response, database_connection
)
response.confidence_score = confidence_score
response.exec_time = time.time() - start_generated_answer
response_repository.update(response)
except ValueError as e:
Expand Down
32 changes: 27 additions & 5 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def __init__(self, settings: Settings):
tags=["Responses"],
)

self.router.add_api_route(
"/api/v1/responses/{response_id}",
self.update_response,
methods=["PATCH"],
tags=["Responses"],
)

self.router.add_api_route(
"/api/v1/sql-query-executions",
self.execute_sql_query,
Expand Down Expand Up @@ -216,10 +223,14 @@ def scan_db(
) -> bool:
return self._api.scan_db(scanner_request, background_tasks)

def answer_question(self, question_request: QuestionRequest) -> Response:
def answer_question(
self, run_evaluator: bool = True, question_request: QuestionRequest = None
) -> Response:
if os.getenv("DH_ENGINE_TIMEOUT", None):
return self._api.answer_question_with_timeout(question_request)
return self._api.answer_question(question_request)
return self._api.answer_question_with_timeout(
run_evaluator, question_request
)
return self._api.answer_question(run_evaluator, question_request)

def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
return self._api.get_questions(db_connection_id)
Expand Down Expand Up @@ -282,13 +293,24 @@ def get_response(self, response_id: str) -> Response:
"""Get a response"""
return self._api.get_response(response_id)

def update_response(self, response_id: str) -> Response:
"""Update a response"""
return self._api.update_response(response_id)

def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a query on the given db_connection_id"""
return self._api.execute_sql_query(query)

def create_response(self, query_request: CreateResponseRequest) -> Response:
def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
"""Executes a query on the given db_connection_id"""
return self._api.create_response(query_request)
return self._api.create_response(
run_evaluator, sql_response_only, query_request
)

def delete_golden_record(self, golden_record_id: str) -> dict:
"""Deletes a golden record"""
Expand Down
50 changes: 29 additions & 21 deletions dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, system, storage):
self.storage = storage
self.model = ChatModel(self.system)

def execute(self, query_response: Response) -> Response:
def execute(
self, query_response: Response, sql_response_only: bool = False
) -> Response:
question_repository = QuestionRepository(self.storage)
question = question_repository.find_by_id(query_response.question_id)

Expand All @@ -45,24 +47,30 @@ def execute(self, query_response: Response) -> Response:
model_name=os.getenv("LLM_MODEL", "gpt-4"),
)
database = SQLDatabase.get_sql_engine(database_connection)
query_response = create_sql_query_status(
database,
query_response.sql_query,
query_response,
top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")),
)
system_message_prompt = SystemMessagePromptTemplate.from_template(
SYSTEM_TEMPLATE
)
human_message_prompt = HumanMessagePromptTemplate.from_template(HUMAN_TEMPLATE)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
chain = LLMChain(llm=self.llm, prompt=chat_prompt)
nl_resp = chain.run(
question=question.question,
sql_query=query_response.sql_query,
sql_query_result=str(query_response.sql_query_result),
)
query_response.response = nl_resp

if not query_response.sql_query_result:
query_response = create_sql_query_status(
database,
query_response.sql_query,
query_response,
top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")),
)

if not sql_response_only:
system_message_prompt = SystemMessagePromptTemplate.from_template(
SYSTEM_TEMPLATE
)
human_message_prompt = HumanMessagePromptTemplate.from_template(
HUMAN_TEMPLATE
)
chat_prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
chain = LLMChain(llm=self.llm, prompt=chat_prompt)
nl_resp = chain.run(
question=question.question,
sql_query=query_response.sql_query,
sql_query_result=str(query_response.sql_query_result),
)
query_response.response = nl_resp
return query_response

0 comments on commit 6d40e93

Please sign in to comment.