From 029ad1d6fef6eae95b7974e129d7b21c17f6b5e7 Mon Sep 17 00:00:00 2001 From: Juan Carlos Jose Camacho Date: Tue, 7 Nov 2023 12:52:59 -0600 Subject: [PATCH] Remove URL logic and rename flag to create CSV file --- dataherald/api/__init__.py | 10 +++--- dataherald/api/fastapi.py | 31 +++++++++---------- dataherald/server/fastapi/__init__.py | 20 ++++++------ dataherald/sql_generator/__init__.py | 6 ++-- .../sql_generator/create_sql_query_status.py | 17 +++------- .../sql_generator/dataherald_sqlagent.py | 4 +-- .../sql_generator/generates_nl_answer.py | 6 ++-- .../sql_generator/langchain_sqlagent.py | 4 +-- .../sql_generator/langchain_sqlchain.py | 4 +-- dataherald/sql_generator/llamaindex.py | 4 +-- .../tests/sql_generator/test_generator.py | 4 +-- dataherald/types.py | 1 - dataherald/utils/s3.py | 26 ++-------------- 13 files changed, 52 insertions(+), 85 deletions(-) diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index e6a3a81b..87b5fc2c 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -2,7 +2,7 @@ from typing import List from fastapi import BackgroundTasks -from fastapi.responses import JSONResponse +from fastapi.responses import FileResponse from dataherald.api.types import Query from dataherald.config import Component @@ -40,7 +40,7 @@ def scan_db( def answer_question( self, run_evaluator: bool = True, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, question_request: QuestionRequest = None, ) -> Response: pass @@ -49,7 +49,7 @@ def answer_question( def answer_question_with_timeout( self, run_evaluator: bool = True, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, question_request: QuestionRequest = None, ) -> Response: pass @@ -117,7 +117,7 @@ def create_response( self, run_evaluator: bool = True, sql_response_only: bool = False, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, query_request: CreateResponseRequest = None, ) -> Response: pass @@ -131,7 +131,7 @@ def get_response(self, response_id: str) -> Response: pass @abstractmethod - def get_response_file(self, response_id: str) -> JSONResponse: + def get_response_file(self, response_id: str) -> FileResponse: pass @abstractmethod diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 50a3b1ba..0d048bcd 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -9,7 +9,7 @@ from bson import json_util from bson.objectid import InvalidId, ObjectId from fastapi import BackgroundTasks, HTTPException -from fastapi.responses import JSONResponse +from fastapi.responses import FileResponse, JSONResponse from overrides import override from dataherald.api import API @@ -123,7 +123,7 @@ def scan_db( def answer_question( self, run_evaluator: bool = True, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, question_request: QuestionRequest = None, ) -> Response: """Takes in an English question and answers it based on content from the registered databases""" @@ -158,7 +158,7 @@ def answer_question( user_question, database_connection, context[0], - large_query_result_in_csv, + generate_csv, ) logger.info("Starts evaluator...") if run_evaluator: @@ -172,7 +172,7 @@ def answer_question( status_code=400, content={"question_id": user_question.id, "error_message": str(e)}, ) - if generated_answer.csv_download_url: + if generated_answer.csv_file_path: generated_answer.sql_query_result = None generated_answer.confidence_score = confidence_score generated_answer.exec_time = 
time.time() - start_generated_answer @@ -183,7 +183,7 @@ def answer_question( def answer_question_with_timeout( self, run_evaluator: bool = True, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, question_request: QuestionRequest = None, ) -> Response: result = None @@ -200,7 +200,7 @@ def run_and_catch_exceptions(): nonlocal result, exception if not stop_event.is_set(): result = self.answer_question( - run_evaluator, large_query_result_in_csv, question_request + run_evaluator, generate_csv, question_request ) thread = threading.Thread(target=run_and_catch_exceptions) @@ -370,7 +370,7 @@ def get_response(self, response_id: str) -> Response: return result @override - def get_response_file(self, response_id: str) -> JSONResponse: + def get_response_file(self, response_id: str) -> FileResponse: response_repository = ResponseRepository(self.storage) question_repository = QuestionRepository(self.storage) db_connection_repository = DatabaseConnectionRepository(self.storage) @@ -389,13 +389,10 @@ def get_response_file(self, response_id: str) -> JSONResponse: ) s3 = S3() - return JSONResponse( - status_code=201, - content={ - "csv_download_url": s3.download_url( - result.csv_file_path, db_connection.file_storage - ), - }, + + return FileResponse( + s3.download(result.csv_file_path, db_connection.file_storage), + media_type="text/csv", ) @override @@ -475,7 +472,7 @@ def create_response( self, run_evaluator: bool = True, sql_response_only: bool = False, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, query_request: CreateResponseRequest = None, # noqa: ARG002 ) -> Response: question_repository = QuestionRepository(self.storage) @@ -498,7 +495,7 @@ def create_response( user_question, database_connection, context[0], - large_query_result_in_csv, + generate_csv, ) else: response = Response( @@ -522,7 +519,7 @@ def create_response( user_question, response, database_connection ) response.confidence_score = confidence_score - if response.csv_download_url: + if response.csv_file_path: response.sql_query_result = None response.exec_time = time.time() - start_generated_answer response_repository.insert(response) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 788176e0..89574187 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -4,7 +4,7 @@ import fastapi from fastapi import BackgroundTasks, status from fastapi import FastAPI as _FastAPI -from fastapi.responses import JSONResponse +from fastapi.responses import FileResponse, JSONResponse from fastapi.routing import APIRoute import dataherald @@ -165,9 +165,9 @@ def __init__(self, settings: Settings): ) self.router.add_api_route( - "/api/v1/responses/{response_id}/generate-csv-download-url", + "/api/v1/responses/{response_id}/file", self.get_response_file, - methods=["POST"], + methods=["GET"], tags=["Responses"], ) @@ -233,16 +233,14 @@ def scan_db( def answer_question( self, run_evaluator: bool = True, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, question_request: QuestionRequest = None, ) -> Response: if os.getenv("DH_ENGINE_TIMEOUT", None): return self._api.answer_question_with_timeout( - run_evaluator, large_query_result_in_csv, question_request + run_evaluator, generate_csv, question_request ) - return self._api.answer_question( - run_evaluator, large_query_result_in_csv, question_request - ) + return self._api.answer_question(run_evaluator, generate_csv, question_request) def 
get_questions(self, db_connection_id: str | None = None) -> list[Question]: return self._api.get_questions(db_connection_id) @@ -309,7 +307,7 @@ def update_response(self, response_id: str) -> Response: """Update a response""" return self._api.update_response(response_id) - def get_response_file(self, response_id: str) -> JSONResponse: + def get_response_file(self, response_id: str) -> FileResponse: """Get a response file""" return self._api.get_response_file(response_id) @@ -321,12 +319,12 @@ def create_response( self, run_evaluator: bool = True, sql_response_only: bool = False, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, query_request: CreateResponseRequest = None, ) -> Response: """Executes a query on the given db_connection_id""" return self._api.create_response( - run_evaluator, sql_response_only, large_query_result_in_csv, query_request + run_evaluator, sql_response_only, generate_csv, query_request ) def delete_golden_record(self, golden_record_id: str) -> dict: diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index b18816d3..1823b33b 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -44,7 +44,7 @@ def create_sql_query_status( query: str, response: Response, top_k: int = None, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, database_connection: DatabaseConnection | None = None, ) -> Response: return create_sql_query_status( @@ -52,7 +52,7 @@ def create_sql_query_status( query, response, top_k, - large_query_result_in_csv, + generate_csv, database_connection=database_connection, ) @@ -89,7 +89,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, ) -> Response: """Generates a response to a user question.""" pass diff --git a/dataherald/sql_generator/create_sql_query_status.py b/dataherald/sql_generator/create_sql_query_status.py index 925c90a5..05b4bc6b 100644 --- a/dataherald/sql_generator/create_sql_query_status.py +++ b/dataherald/sql_generator/create_sql_query_status.py @@ -10,9 +10,6 @@ from dataherald.types import Response, SQLQueryResult from dataherald.utils.s3 import S3 -MAX_ROWS_TO_CREATE_CSV_FILE = 50 -MAX_CHARACTERS_TO_CREATE_CSV_FILE = 3_000 - def format_error_message(response: Response, error_message: str) -> Response: # Remove the complete query @@ -29,16 +26,13 @@ def format_error_message(response: Response, error_message: str) -> Response: def create_csv_file( - large_query_result_in_csv: bool, + generate_csv: bool, columns: list, rows: list, response: Response, database_connection: DatabaseConnection | None = None, ): - if large_query_result_in_csv and ( - len(rows) >= MAX_ROWS_TO_CREATE_CSV_FILE - or len(str(rows)) > MAX_CHARACTERS_TO_CREATE_CSV_FILE - ): + if generate_csv: file_location = f"tmp/{str(uuid.uuid4())}.csv" with open(file_location, "w", newline="") as file: writer = csv.writer(file) @@ -47,10 +41,9 @@ def create_csv_file( for row in rows: writer.writerow(row.values()) s3 = S3() - response.csv_download_url = s3.upload( + response.csv_file_path = s3.upload( file_location, database_connection.file_storage ) - response.csv_file_path = f's3://k2-core/{file_location.split("/")[-1]}' response.sql_query_result = SQLQueryResult(columns=columns, rows=rows) @@ -59,7 +52,7 @@ def create_sql_query_status( query: str, response: Response, top_k: int = None, - large_query_result_in_csv: bool = 
False, + generate_csv: bool = False, database_connection: DatabaseConnection | None = None, ) -> Response: """Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message""" @@ -100,7 +93,7 @@ def create_sql_query_status( rows.append(modified_row) create_csv_file( - large_query_result_in_csv, + generate_csv, columns, rows, response, diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index a183355c..1de633c4 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -613,7 +613,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, ) -> Response: start_time = time.time() context_store = self.system.instance(ContextStore) @@ -712,6 +712,6 @@ def generate_response( response.sql_query, response, top_k=TOP_K, - large_query_result_in_csv=large_query_result_in_csv, + generate_csv=generate_csv, database_connection=database_connection, ) diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index ed78bb88..de82b69e 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -35,7 +35,7 @@ def execute( self, query_response: Response, sql_response_only: bool = False, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, ) -> Response: question_repository = QuestionRepository(self.storage) question = question_repository.find_by_id(query_response.question_id) @@ -57,10 +57,10 @@ def execute( query_response.sql_query, query_response, top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")), - large_query_result_in_csv=large_query_result_in_csv, + generate_csv=generate_csv, ) - if query_response.csv_download_url: + if query_response.csv_file_path: query_response.response = None return query_response diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py index e40fc376..5e6813f4 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -29,7 +29,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, ) -> Response: # type: ignore logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( @@ -90,5 +90,5 @@ def generate_response( self.database, response.sql_query, response, - large_query_result_in_csv=large_query_result_in_csv, + generate_csv=generate_csv, ) diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index f4ab1dda..90516080 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -47,7 +47,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, ) -> Response: start_time = time.time() self.llm = self.model.get_model( @@ -99,5 +99,5 @@ def generate_response( self.database, response.sql_query, response, - large_query_result_in_csv=large_query_result_in_csv, + generate_csv=generate_csv, ) diff --git 
a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py index 3032f2b4..3d5ee4fb 100644 --- a/dataherald/sql_generator/llamaindex.py +++ b/dataherald/sql_generator/llamaindex.py @@ -35,7 +35,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, - large_query_result_in_csv: bool = False, + generate_csv: bool = False, ) -> Response: start_time = time.time() logger.info(f"Generating SQL response to question: {str(user_question.dict())}") @@ -114,5 +114,5 @@ def generate_response( self.database, response.sql_query, response, - large_query_result_in_csv=large_query_result_in_csv, + generate_csv=generate_csv, ) diff --git a/dataherald/tests/sql_generator/test_generator.py b/dataherald/tests/sql_generator/test_generator.py index 53b39b0f..4b198170 100644 --- a/dataherald/tests/sql_generator/test_generator.py +++ b/dataherald/tests/sql_generator/test_generator.py @@ -18,12 +18,12 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, # noqa: ARG002 - large_query_result_in_csv: bool = None, + generate_csv: bool = None, ) -> Response: return Response( question_id="651f2d76275132d5b65175eb", response="Foo response", intermediate_steps=["foo"], sql_query="bar", - large_query_result_in_csv=None, + generate_csv=None, ) diff --git a/dataherald/types.py b/dataherald/types.py index a285fadc..6cddcf9a 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -76,7 +76,6 @@ class Response(BaseModel): sql_query: str sql_query_result: SQLQueryResult | None csv_file_path: str | None - csv_download_url: str | None sql_generation_status: str = "INVALID" error_message: str | None exec_time: float | None = None diff --git a/dataherald/utils/s3.py b/dataherald/utils/s3.py index f73dcd1f..f0950618 100644 --- a/dataherald/utils/s3.py +++ b/dataherald/utils/s3.py @@ -38,16 +38,11 @@ def upload(self, file_location, file_storage: FileStorage | None = None) -> str: file_location, bucket_name, os.path.basename(file_location) ) os.remove(file_location) + return f"s3://{bucket_name}/{file_name}" - return s3_client.generate_presigned_url( - "get_object", - Params={"Bucket": bucket_name, "Key": file_name}, - ExpiresIn=3600, # The URL will expire in 3600 seconds (1 hour) - ) - - def download_url(self, path: str, file_storage: FileStorage | None = None) -> str: + def download(self, path: str, file_storage: FileStorage | None = None) -> str: + fernet_encrypt = FernetEncrypt() path = path.split("/") - if file_storage: fernet_encrypt = FernetEncrypt() s3_client = boto3.client( @@ -64,21 +59,6 @@ def download_url(self, path: str, file_storage: FileStorage | None = None) -> st aws_access_key_id=self.settings.s3_aws_access_key_id, aws_secret_access_key=self.settings.s3_aws_secret_access_key, ) - - return s3_client.generate_presigned_url( - "get_object", - Params={"Bucket": path[2], "Key": path[-1]}, - ExpiresIn=3600, # The URL will expire in 3600 seconds (1 hour) - ) - - def download(self, path: str) -> str: - fernet_encrypt = FernetEncrypt() - path = path.split("/") - s3_client = boto3.client( - "s3", - aws_access_key_id=self.settings.s3_aws_access_key_id, - aws_secret_access_key=self.settings.s3_aws_secret_access_key, - ) file_location = f"tmp/{path[-1]}" s3_client.download_file( Bucket=path[2], Key=f"{path[-1]}", Filename=file_location
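
Usage reference for the changes above: with this patch, clients request CSV generation by setting the renamed generate_csv flag (previously large_query_result_in_csv) when answering a question or creating a response, and then fetch the file itself from the new GET /api/v1/responses/{response_id}/file route, which streams the CSV back as a FileResponse instead of returning a JSON body with a presigned csv_download_url. A minimal client-side sketch, assuming the engine is reachable at http://localhost/api/v1 and that response_id refers to a response created with generate_csv=True (the base URL and the example id are placeholders, not defined by this patch):

    import requests

    BASE_URL = "http://localhost/api/v1"  # assumption: adjust to your deployment
    response_id = "651f2d76275132d5b65175eb"  # placeholder id of a response created with generate_csv=True

    # GET /api/v1/responses/{response_id}/file now returns the CSV content directly
    # (FileResponse, media_type="text/csv") rather than a JSON body with csv_download_url.
    resp = requests.get(f"{BASE_URL}/responses/{response_id}/file")
    resp.raise_for_status()

    with open("result.csv", "wb") as f:
        f.write(resp.content)

Note that when generate_csv is set, create_sql_query_status now always writes the CSV (the previous 50-row / 3,000-character thresholds are removed), and the response's sql_query_result is cleared once csv_file_path is populated.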
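
The S3 helper likewise stores a plain object path instead of a presigned URL: upload now returns s3://{bucket}/{key} (saved on the response as csv_file_path), and download splits that string to recover the bucket and key before fetching the object into tmp/. A standalone sketch of that path convention, with no AWS calls and an illustrative bucket name ("my-bucket" is not from the patch):

    import os
    import uuid


    def build_csv_file_path(bucket_name: str, file_location: str) -> str:
        # Mirrors the value S3.upload now returns: s3://{bucket}/{key},
        # where the key is the uploaded temp file's basename.
        return f"s3://{bucket_name}/{os.path.basename(file_location)}"


    def parse_csv_file_path(path: str) -> tuple[str, str]:
        # Mirrors how S3.download resolves the stored path back into Bucket / Key:
        # "s3://my-bucket/abc.csv".split("/") -> ["s3:", "", "my-bucket", "abc.csv"]
        parts = path.split("/")
        return parts[2], parts[-1]


    file_location = f"tmp/{uuid.uuid4()}.csv"  # same naming scheme create_csv_file uses
    stored = build_csv_file_path("my-bucket", file_location)
    bucket, key = parse_csv_file_path(stored)
    assert (bucket, key) == ("my-bucket", os.path.basename(file_location))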