Skip to content

Commit

Permalink
[DH-5019] Always create a response object when a question is created (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 authored Nov 20, 2023
1 parent 8955b29 commit 71ec83d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
38 changes: 25 additions & 13 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bson import json_util
from bson.objectid import InvalidId, ObjectId
from fastapi import BackgroundTasks, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse, JSONResponse
from overrides import override

Expand Down Expand Up @@ -150,14 +151,17 @@ def answer_question(
database_connection = db_connection_repository.find_by_id(
user_question.db_connection_id
)
response_repository = ResponseRepository(self.storage)

if not database_connection:
return JSONResponse(
status_code=404,
content={
"question_id": user_question.id,
"error_message": "Connections doesn't exist",
},
response = response_repository.insert(
Response(
question_id=user_question.id,
error_message="Connections doesn't exist",
sql_query="",
)
)
return JSONResponse(status_code=404, content=jsonable_encoder(response))
try:
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
Expand All @@ -175,10 +179,16 @@ def answer_question(
)
generated_answer.confidence_score = confidence_score
except Exception as e:
response = response_repository.insert(
Response(
question_id=user_question.id, error_message=str(e), sql_query=""
)
)
return JSONResponse(
status_code=400,
content={"question_id": user_question.id, "error_message": str(e)},
content=jsonable_encoder(response),
)

if (
generate_csv
and len(generated_answer.sql_query_result.rows)
Expand Down Expand Up @@ -218,13 +228,15 @@ def run_and_catch_exceptions():
thread.join(timeout=int(os.getenv("DH_ENGINE_TIMEOUT")))
if thread.is_alive():
stop_event.set()
return JSONResponse(
status_code=400,
content={
"question_id": user_question.id,
"error_message": "Timeout Error",
},
response_repository = ResponseRepository(self.storage)
response = response_repository.insert(
Response(
question_id=user_question.id,
error_message="Timeout Error",
sql_query="",
)
)
return JSONResponse(status_code=400, content=jsonable_encoder(response))
return result

@override
Expand Down
3 changes: 2 additions & 1 deletion dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ def create_sql_agent(
suffix: str = FINETUNING_AGENT_SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: List[str] | None = None,
max_iterations: int | None = int(os.getenv("AGENT_MAX_ITERATIONS", "20")),
max_iterations: int
| None = int(os.getenv("AGENT_MAX_ITERATIONS", "20")), # noqa: B008
max_execution_time: float | None = None,
early_stopping_method: str = "force",
verbose: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,8 @@ def create_sql_agent(
input_variables: List[str] | None = None,
max_examples: int = 20,
number_of_instructions: int = 1,
max_iterations: int | None = int(os.getenv("AGENT_MAX_ITERATIONS", "20")),
max_iterations: int
| None = int(os.getenv("AGENT_MAX_ITERATIONS", "20")), # noqa: B008
max_execution_time: float | None = None,
early_stopping_method: str = "force",
verbose: bool = False,
Expand Down

0 comments on commit 71ec83d

Please sign in to comment.