Skip to content

Commit

Permalink
Remove URL logic and rename flag to create CSV file
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Nov 7, 2023
1 parent 682a402 commit 029ad1d
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 85 deletions.
10 changes: 5 additions & 5 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List

from fastapi import BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse

from dataherald.api.types import Query
from dataherald.config import Component
Expand Down Expand Up @@ -40,7 +40,7 @@ def scan_db(
def answer_question(
self,
run_evaluator: bool = True,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass
Expand All @@ -49,7 +49,7 @@ def answer_question(
def answer_question_with_timeout(
self,
run_evaluator: bool = True,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass
Expand Down Expand Up @@ -117,7 +117,7 @@ def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
pass
Expand All @@ -131,7 +131,7 @@ def get_response(self, response_id: str) -> Response:
pass

@abstractmethod
def get_response_file(self, response_id: str) -> JSONResponse:
def get_response_file(self, response_id: str) -> FileResponse:
pass

@abstractmethod
Expand Down
31 changes: 14 additions & 17 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bson import json_util
from bson.objectid import InvalidId, ObjectId
from fastapi import BackgroundTasks, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse, JSONResponse
from overrides import override

from dataherald.api import API
Expand Down Expand Up @@ -123,7 +123,7 @@ def scan_db(
def answer_question(
self,
run_evaluator: bool = True,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
"""Takes in an English question and answers it based on content from the registered databases"""
Expand Down Expand Up @@ -158,7 +158,7 @@ def answer_question(
user_question,
database_connection,
context[0],
large_query_result_in_csv,
generate_csv,
)
logger.info("Starts evaluator...")
if run_evaluator:
Expand All @@ -172,7 +172,7 @@ def answer_question(
status_code=400,
content={"question_id": user_question.id, "error_message": str(e)},
)
if generated_answer.csv_download_url:
if generated_answer.csv_file_path:
generated_answer.sql_query_result = None
generated_answer.confidence_score = confidence_score
generated_answer.exec_time = time.time() - start_generated_answer
Expand All @@ -183,7 +183,7 @@ def answer_question(
def answer_question_with_timeout(
self,
run_evaluator: bool = True,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
result = None
Expand All @@ -200,7 +200,7 @@ def run_and_catch_exceptions():
nonlocal result, exception
if not stop_event.is_set():
result = self.answer_question(
run_evaluator, large_query_result_in_csv, question_request
run_evaluator, generate_csv, question_request
)

thread = threading.Thread(target=run_and_catch_exceptions)
Expand Down Expand Up @@ -370,7 +370,7 @@ def get_response(self, response_id: str) -> Response:
return result

@override
def get_response_file(self, response_id: str) -> JSONResponse:
def get_response_file(self, response_id: str) -> FileResponse:
response_repository = ResponseRepository(self.storage)
question_repository = QuestionRepository(self.storage)
db_connection_repository = DatabaseConnectionRepository(self.storage)
Expand All @@ -389,13 +389,10 @@ def get_response_file(self, response_id: str) -> JSONResponse:
)

s3 = S3()
return JSONResponse(
status_code=201,
content={
"csv_download_url": s3.download_url(
result.csv_file_path, db_connection.file_storage
),
},

return FileResponse(
s3.download(result.csv_file_path, db_connection.file_storage),
media_type="text/csv",
)

@override
Expand Down Expand Up @@ -475,7 +472,7 @@ def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
query_request: CreateResponseRequest = None, # noqa: ARG002
) -> Response:
question_repository = QuestionRepository(self.storage)
Expand All @@ -498,7 +495,7 @@ def create_response(
user_question,
database_connection,
context[0],
large_query_result_in_csv,
generate_csv,
)
else:
response = Response(
Expand All @@ -522,7 +519,7 @@ def create_response(
user_question, response, database_connection
)
response.confidence_score = confidence_score
if response.csv_download_url:
if response.csv_file_path:
response.sql_query_result = None
response.exec_time = time.time() - start_generated_answer
response_repository.insert(response)
Expand Down
20 changes: 9 additions & 11 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import fastapi
from fastapi import BackgroundTasks, status
from fastapi import FastAPI as _FastAPI
from fastapi.responses import JSONResponse
from fastapi.responses import FileResponse, JSONResponse
from fastapi.routing import APIRoute

import dataherald
Expand Down Expand Up @@ -165,9 +165,9 @@ def __init__(self, settings: Settings):
)

self.router.add_api_route(
"/api/v1/responses/{response_id}/generate-csv-download-url",
"/api/v1/responses/{response_id}/file",
self.get_response_file,
methods=["POST"],
methods=["GET"],
tags=["Responses"],
)

Expand Down Expand Up @@ -233,16 +233,14 @@ def scan_db(
def answer_question(
self,
run_evaluator: bool = True,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
if os.getenv("DH_ENGINE_TIMEOUT", None):
return self._api.answer_question_with_timeout(
run_evaluator, large_query_result_in_csv, question_request
run_evaluator, generate_csv, question_request
)
return self._api.answer_question(
run_evaluator, large_query_result_in_csv, question_request
)
return self._api.answer_question(run_evaluator, generate_csv, question_request)

def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
return self._api.get_questions(db_connection_id)
Expand Down Expand Up @@ -309,7 +307,7 @@ def update_response(self, response_id: str) -> Response:
"""Update a response"""
return self._api.update_response(response_id)

def get_response_file(self, response_id: str) -> JSONResponse:
def get_response_file(self, response_id: str) -> FileResponse:
"""Get a response file"""
return self._api.get_response_file(response_id)

Expand All @@ -321,12 +319,12 @@ def create_response(
self,
run_evaluator: bool = True,
sql_response_only: bool = False,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
"""Executes a query on the given db_connection_id"""
return self._api.create_response(
run_evaluator, sql_response_only, large_query_result_in_csv, query_request
run_evaluator, sql_response_only, generate_csv, query_request
)

def delete_golden_record(self, golden_record_id: str) -> dict:
Expand Down
6 changes: 3 additions & 3 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ def create_sql_query_status(
query: str,
response: Response,
top_k: int = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
database_connection: DatabaseConnection | None = None,
) -> Response:
return create_sql_query_status(
db,
query,
response,
top_k,
large_query_result_in_csv,
generate_csv,
database_connection=database_connection,
)

Expand Down Expand Up @@ -89,7 +89,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
) -> Response:
"""Generates a response to a user question."""
pass
17 changes: 5 additions & 12 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from dataherald.types import Response, SQLQueryResult
from dataherald.utils.s3 import S3

MAX_ROWS_TO_CREATE_CSV_FILE = 50
MAX_CHARACTERS_TO_CREATE_CSV_FILE = 3_000


def format_error_message(response: Response, error_message: str) -> Response:
# Remove the complete query
Expand All @@ -29,16 +26,13 @@ def format_error_message(response: Response, error_message: str) -> Response:


def create_csv_file(
large_query_result_in_csv: bool,
generate_csv: bool,
columns: list,
rows: list,
response: Response,
database_connection: DatabaseConnection | None = None,
):
if large_query_result_in_csv and (
len(rows) >= MAX_ROWS_TO_CREATE_CSV_FILE
or len(str(rows)) > MAX_CHARACTERS_TO_CREATE_CSV_FILE
):
if generate_csv:
file_location = f"tmp/{str(uuid.uuid4())}.csv"
with open(file_location, "w", newline="") as file:
writer = csv.writer(file)
Expand All @@ -47,10 +41,9 @@ def create_csv_file(
for row in rows:
writer.writerow(row.values())
s3 = S3()
response.csv_download_url = s3.upload(
response.csv_file_path = s3.upload(
file_location, database_connection.file_storage
)
response.csv_file_path = f's3://k2-core/{file_location.split("/")[-1]}'
response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)


Expand All @@ -59,7 +52,7 @@ def create_sql_query_status(
query: str,
response: Response,
top_k: int = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
database_connection: DatabaseConnection | None = None,
) -> Response:
"""Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message"""
Expand Down Expand Up @@ -100,7 +93,7 @@ def create_sql_query_status(
rows.append(modified_row)

create_csv_file(
large_query_result_in_csv,
generate_csv,
columns,
rows,
response,
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
) -> Response:
start_time = time.time()
context_store = self.system.instance(ContextStore)
Expand Down Expand Up @@ -712,6 +712,6 @@ def generate_response(
response.sql_query,
response,
top_k=TOP_K,
large_query_result_in_csv=large_query_result_in_csv,
generate_csv=generate_csv,
database_connection=database_connection,
)
6 changes: 3 additions & 3 deletions dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def execute(
self,
query_response: Response,
sql_response_only: bool = False,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
) -> Response:
question_repository = QuestionRepository(self.storage)
question = question_repository.find_by_id(query_response.question_id)
Expand All @@ -57,10 +57,10 @@ def execute(
query_response.sql_query,
query_response,
top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")),
large_query_result_in_csv=large_query_result_in_csv,
generate_csv=generate_csv,
)

if query_response.csv_download_url:
if query_response.csv_file_path:
query_response.response = None
return query_response

Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
) -> Response: # type: ignore
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
self.llm = self.model.get_model(
Expand Down Expand Up @@ -90,5 +90,5 @@ def generate_response(
self.database,
response.sql_query,
response,
large_query_result_in_csv=large_query_result_in_csv,
generate_csv=generate_csv,
)
4 changes: 2 additions & 2 deletions dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
) -> Response:
start_time = time.time()
self.llm = self.model.get_model(
Expand Down Expand Up @@ -99,5 +99,5 @@ def generate_response(
self.database,
response.sql_query,
response,
large_query_result_in_csv=large_query_result_in_csv,
generate_csv=generate_csv,
)
4 changes: 2 additions & 2 deletions dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
large_query_result_in_csv: bool = False,
generate_csv: bool = False,
) -> Response:
start_time = time.time()
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
Expand Down Expand Up @@ -114,5 +114,5 @@ def generate_response(
self.database,
response.sql_query,
response,
large_query_result_in_csv=large_query_result_in_csv,
generate_csv=generate_csv,
)
Loading

0 comments on commit 029ad1d

Please sign in to comment.