Skip to content

Commit

Permalink
DH-4905 Add flag to create a csv file when the response have many rows
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Nov 1, 2023
1 parent 844491d commit 44b69fd
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 25 deletions.
23 changes: 20 additions & 3 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from fastapi import BackgroundTasks
from fastapi.responses import FileResponse

from dataherald.api.types import Query
from dataherald.config import Component
Expand Down Expand Up @@ -36,12 +37,18 @@ def scan_db(
pass

@abstractmethod
def answer_question(self, question_request: QuestionRequest) -> Response:
def answer_question(
self,
store_substantial_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass

@abstractmethod
def answer_question_with_timeout(
self, question_request: QuestionRequest
self,
store_substantial_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass

Expand Down Expand Up @@ -100,7 +107,11 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
pass

@abstractmethod
def create_response(self, query_request: CreateResponseRequest) -> Response:
def create_response(
self,
store_substantial_query_result_in_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
pass

@abstractmethod
Expand All @@ -111,6 +122,12 @@ def get_responses(self, question_id: str | None = None) -> list[Response]:
def get_response(self, response_id: str) -> Response:
pass

@abstractmethod
def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
pass

@abstractmethod
def delete_golden_record(self, golden_record_id: str) -> dict:
pass
Expand Down
60 changes: 53 additions & 7 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bson import json_util
from bson.objectid import InvalidId, ObjectId
from fastapi import BackgroundTasks, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse, JSONResponse
from overrides import override

from dataherald.api import API
Expand Down Expand Up @@ -50,6 +50,7 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)

Expand All @@ -63,6 +64,10 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_removing_file(file_path: str):
os.remove(file_path)


class FastAPI(API):
def __init__(self, system: System):
super().__init__(system)
Expand Down Expand Up @@ -118,7 +123,11 @@ def scan_db(
return True

@override
def answer_question(self, question_request: QuestionRequest) -> Response:
def answer_question(
self,
store_substantial_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
"""Takes in an English question and answers it based on content from the registered databases"""
logger.info(f"Answer question: {question_request.question}")
sql_generation = self.system.instance(SQLGenerator)
Expand Down Expand Up @@ -149,7 +158,10 @@ def answer_question(self, question_request: QuestionRequest) -> Response:
start_generated_answer = time.time()
try:
generated_answer = sql_generation.generate_response(
user_question, database_connection, context[0]
user_question,
database_connection,
context[0],
store_substantial_query_result_in_csv,
)
logger.info("Starts evaluator...")
confidence_score = evaluator.get_confidence_score(
Expand All @@ -167,7 +179,9 @@ def answer_question(self, question_request: QuestionRequest) -> Response:

@override
def answer_question_with_timeout(
self, question_request: QuestionRequest
self,
store_substantial_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
result = None
exception = None
Expand All @@ -182,7 +196,9 @@ def answer_question_with_timeout(
def run_and_catch_exceptions():
nonlocal result, exception
if not stop_event.is_set():
result = self.answer_question(question_request)
result = self.answer_question(
store_substantial_query_result_in_csv, question_request
)

thread = threading.Thread(target=run_and_catch_exceptions)
thread.start()
Expand Down Expand Up @@ -348,6 +364,32 @@ def get_response(self, response_id: str) -> Response:

return result

@override
def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
response_repository = ResponseRepository(self.storage)

try:
result = response_repository.find_by_id(response_id)
except InvalidId as e:
raise HTTPException(status_code=400, detail=str(e)) from e

if not result:
raise HTTPException(status_code=404, detail="Question not found")

# todo download
s3 = S3()
file_path = s3.download(result.csv_file_path)
background_tasks.add_task(async_removing_file, file_path)
return FileResponse(
file_path,
media_type="text/csv",
headers={
"Content-Disposition": f"attachment; filename={file_path.split('/')[-1]}"
},
)

@override
def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
question_repository = QuestionRepository(self.storage)
Expand Down Expand Up @@ -397,7 +439,9 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:

@override
def create_response(
self, query_request: CreateResponseRequest # noqa: ARG002
self,
store_substantial_query_result_in_csv: bool = False,
query_request: CreateResponseRequest = None, # noqa: ARG002
) -> Response:
evaluator = self.system.instance(Evaluator)
question_repository = QuestionRepository(self.storage)
Expand All @@ -417,7 +461,9 @@ def create_response(
start_generated_answer = time.time()
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(response)
response = generates_nl_answer.execute(
response, store_substantial_query_result_in_csv
)
confidence_score = evaluator.get_confidence_score(
user_question, response, database_connection
)
Expand Down
39 changes: 33 additions & 6 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import fastapi
from fastapi import BackgroundTasks, status
from fastapi import FastAPI as _FastAPI
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRoute

import dataherald
Expand Down Expand Up @@ -164,6 +164,13 @@ def __init__(self, settings: Settings):
tags=["Responses"],
)

self.router.add_api_route(
"/api/v1/responses/{response_id}/file",
self.get_response_file,
methods=["GET"],
tags=["Responses"],
)

self.router.add_api_route(
"/api/v1/sql-query-executions",
self.execute_sql_query,
Expand Down Expand Up @@ -216,10 +223,18 @@ def scan_db(
) -> bool:
return self._api.scan_db(scanner_request, background_tasks)

def answer_question(self, question_request: QuestionRequest) -> Response:
def answer_question(
self,
store_substantial_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
if os.getenv("DH_ENGINE_TIMEOUT", None):
return self._api.answer_question_with_timeout(question_request)
return self._api.answer_question(question_request)
return self._api.answer_question_with_timeout(
store_substantial_query_result_in_csv, question_request
)
return self._api.answer_question(
store_substantial_query_result_in_csv, question_request
)

def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
return self._api.get_questions(db_connection_id)
Expand Down Expand Up @@ -282,13 +297,25 @@ def get_response(self, response_id: str) -> Response:
"""Get a response"""
return self._api.get_response(response_id)

def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
"""Get a response file"""
return self._api.get_response_file(response_id, background_tasks)

def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a query on the given db_connection_id"""
return self._api.execute_sql_query(query)

def create_response(self, query_request: CreateResponseRequest) -> Response:
def create_response(
self,
store_substantial_query_result_in_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
"""Executes a query on the given db_connection_id"""
return self._api.create_response(query_request)
return self._api.create_response(
store_substantial_query_result_in_csv, query_request
)

def delete_golden_record(self, golden_record_id: str) -> dict:
"""Deletes a golden record"""
Expand Down
12 changes: 10 additions & 2 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ def check_for_time_out_or_tool_limit(self, response: dict) -> dict:
return response

def create_sql_query_status(
self, db: SQLDatabase, query: str, response: Response, top_k: int = None
self,
db: SQLDatabase,
query: str,
response: Response,
top_k: int = None,
store_substantial_query_result_in_csv: bool = False,
) -> Response:
return create_sql_query_status(db, query, response, top_k)
return create_sql_query_status(
db, query, response, top_k, store_substantial_query_result_in_csv
)

def format_intermediate_representations(
self, intermediate_representation: List[Tuple[AgentAction, str]]
Expand Down Expand Up @@ -76,6 +83,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
store_substantial_query_result_in_csv: bool = False,
) -> Response:
"""Generates a response to a user question."""
pass
41 changes: 39 additions & 2 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import csv
import uuid
from datetime import date, datetime
from decimal import Decimal

from sqlalchemy import text

from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.types import Response, SQLQueryResult
from dataherald.utils.s3 import S3

MAX_ROWS_TO_CREATE_CSV_FILE = 50
MAX_CHARACTERS_TO_CREATE_CSV_FILE = 3_000


def format_error_message(response: Response, error_message: str) -> Response:
Expand All @@ -21,8 +27,35 @@ def format_error_message(response: Response, error_message: str) -> Response:
return response


def create_csv_file(
store_substantial_query_result_in_csv: bool,
columns: list,
rows: list,
response: Response,
):
if store_substantial_query_result_in_csv and (
len(rows) >= MAX_ROWS_TO_CREATE_CSV_FILE
or len(str(rows)) > MAX_CHARACTERS_TO_CREATE_CSV_FILE
):
file_location = f"tmp/{str(uuid.uuid4())}.csv"
with open(file_location, "w", newline="") as file:
writer = csv.writer(file)

writer.writerow(rows[0].keys())
for row in rows:
writer.writerow(row.values())
s3 = S3()
s3.upload(file_location)
response.csv_file_path = f's3://k2-core/{file_location.split("/")[-1]}'
response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)


def create_sql_query_status(
db: SQLDatabase, query: str, response: Response, top_k: int = None
db: SQLDatabase,
query: str,
response: Response,
top_k: int = None,
store_substantial_query_result_in_csv: bool = False,
) -> Response:
"""Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message"""
if query == "":
Expand Down Expand Up @@ -60,7 +93,11 @@ def create_sql_query_status(
else:
modified_row[key] = value
rows.append(modified_row)
response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)

create_csv_file(
store_substantial_query_result_in_csv, columns, rows, response
)

response.sql_generation_status = "VALID"
response.error_message = None
except SQLInjectionError as e:
Expand Down
7 changes: 6 additions & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
store_substantial_query_result_in_csv: bool = False,
) -> Response:
start_time = time.time()
context_store = self.system.instance(ContextStore)
Expand Down Expand Up @@ -690,5 +691,9 @@ def generate_response(
sql_query=sql_query_list[-1] if len(sql_query_list) > 0 else "",
)
return self.create_sql_query_status(
self.database, response.sql_query, response, top_k=TOP_K
self.database,
response.sql_query,
response,
top_k=TOP_K,
store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
)
7 changes: 6 additions & 1 deletion dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def __init__(self, system, storage):
self.storage = storage
self.model = ChatModel(self.system)

def execute(self, query_response: Response) -> Response:
def execute(
self,
query_response: Response,
store_substantial_query_result_in_csv: bool = False,
) -> Response:
question_repository = QuestionRepository(self.storage)
question = question_repository.find_by_id(query_response.question_id)

Expand All @@ -50,6 +54,7 @@ def execute(self, query_response: Response) -> Response:
query_response.sql_query,
query_response,
top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")),
store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
)
system_message_prompt = SystemMessagePromptTemplate.from_template(
SYSTEM_TEMPLATE
Expand Down
Loading

0 comments on commit 44b69fd

Please sign in to comment.