diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index 21ef7df7..54316c6d 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -2,6 +2,7 @@ from typing import List from fastapi import BackgroundTasks +from fastapi.responses import FileResponse from dataherald.api.types import Query from dataherald.config import Component @@ -36,12 +37,18 @@ def scan_db( pass @abstractmethod - def answer_question(self, question_request: QuestionRequest) -> Response: + def answer_question( + self, + store_substantial_query_result_in_csv: bool = False, + question_request: QuestionRequest = None, + ) -> Response: pass @abstractmethod def answer_question_with_timeout( - self, question_request: QuestionRequest + self, + store_substantial_query_result_in_csv: bool = False, + question_request: QuestionRequest = None, ) -> Response: pass @@ -100,7 +107,11 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]: pass @abstractmethod - def create_response(self, query_request: CreateResponseRequest) -> Response: + def create_response( + self, + store_substantial_query_result_in_csv: bool = False, + query_request: CreateResponseRequest = None, + ) -> Response: pass @abstractmethod @@ -111,6 +122,12 @@ def get_responses(self, question_id: str | None = None) -> list[Response]: def get_response(self, response_id: str) -> Response: pass + @abstractmethod + def get_response_file( + self, response_id: str, background_tasks: BackgroundTasks + ) -> FileResponse: + pass + @abstractmethod def delete_golden_record(self, golden_record_id: str) -> dict: pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index d4ed1c50..d3cb7d4a 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -8,7 +8,7 @@ from bson import json_util from bson.objectid import InvalidId, ObjectId from fastapi import BackgroundTasks, HTTPException -from fastapi.responses import JSONResponse +from fastapi.responses import FileResponse, JSONResponse from overrides import override from dataherald.api import API @@ -50,6 +50,7 @@ TableDescriptionRequest, UpdateInstruction, ) +from dataherald.utils.s3 import S3 logger = logging.getLogger(__name__) @@ -63,6 +64,10 @@ def async_scanning(scanner, database, scanner_request, storage): ) +def async_removing_file(file_path: str): + os.remove(file_path) + + class FastAPI(API): def __init__(self, system: System): super().__init__(system) @@ -118,7 +123,11 @@ def scan_db( return True @override - def answer_question(self, question_request: QuestionRequest) -> Response: + def answer_question( + self, + store_substantial_query_result_in_csv: bool = False, + question_request: QuestionRequest = None, + ) -> Response: """Takes in an English question and answers it based on content from the registered databases""" logger.info(f"Answer question: {question_request.question}") sql_generation = self.system.instance(SQLGenerator) @@ -149,7 +158,10 @@ def answer_question(self, question_request: QuestionRequest) -> Response: start_generated_answer = time.time() try: generated_answer = sql_generation.generate_response( - user_question, database_connection, context[0] + user_question, + database_connection, + context[0], + store_substantial_query_result_in_csv, ) logger.info("Starts evaluator...") confidence_score = evaluator.get_confidence_score( @@ -167,7 +179,9 @@ def answer_question(self, question_request: QuestionRequest) -> Response: @override def answer_question_with_timeout( - self, question_request: QuestionRequest + self, + store_substantial_query_result_in_csv: bool = False, + question_request: QuestionRequest = None, ) -> Response: result = None exception = None @@ -182,7 +196,9 @@ def answer_question_with_timeout( def run_and_catch_exceptions(): nonlocal result, exception if not stop_event.is_set(): - result = self.answer_question(question_request) + result = self.answer_question( + store_substantial_query_result_in_csv, question_request + ) thread = threading.Thread(target=run_and_catch_exceptions) thread.start() @@ -348,6 +364,32 @@ def get_response(self, response_id: str) -> Response: return result + @override + def get_response_file( + self, response_id: str, background_tasks: BackgroundTasks + ) -> FileResponse: + response_repository = ResponseRepository(self.storage) + + try: + result = response_repository.find_by_id(response_id) + except InvalidId as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + if not result: + raise HTTPException(status_code=404, detail="Question not found") + + # todo download + s3 = S3() + file_path = s3.download(result.csv_file_path) + background_tasks.add_task(async_removing_file, file_path) + return FileResponse( + file_path, + media_type="text/csv", + headers={ + "Content-Disposition": f"attachment; filename={file_path.split('/')[-1]}" + }, + ) + @override def get_questions(self, db_connection_id: str | None = None) -> list[Question]: question_repository = QuestionRepository(self.storage) @@ -397,7 +439,9 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]: @override def create_response( - self, query_request: CreateResponseRequest # noqa: ARG002 + self, + store_substantial_query_result_in_csv: bool = False, + query_request: CreateResponseRequest = None, # noqa: ARG002 ) -> Response: evaluator = self.system.instance(Evaluator) question_repository = QuestionRepository(self.storage) @@ -417,7 +461,9 @@ def create_response( start_generated_answer = time.time() try: generates_nl_answer = GeneratesNlAnswer(self.system, self.storage) - response = generates_nl_answer.execute(response) + response = generates_nl_answer.execute( + response, store_substantial_query_result_in_csv + ) confidence_score = evaluator.get_confidence_score( user_question, response, database_connection ) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 89da742e..eccac896 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -4,7 +4,7 @@ import fastapi from fastapi import BackgroundTasks, status from fastapi import FastAPI as _FastAPI -from fastapi.responses import JSONResponse +from fastapi.responses import FileResponse, JSONResponse from fastapi.routing import APIRoute import dataherald @@ -164,6 +164,13 @@ def __init__(self, settings: Settings): tags=["Responses"], ) + self.router.add_api_route( + "/api/v1/responses/{response_id}/file", + self.get_response_file, + methods=["GET"], + tags=["Responses"], + ) + self.router.add_api_route( "/api/v1/sql-query-executions", self.execute_sql_query, @@ -216,10 +223,18 @@ def scan_db( ) -> bool: return self._api.scan_db(scanner_request, background_tasks) - def answer_question(self, question_request: QuestionRequest) -> Response: + def answer_question( + self, + store_substantial_query_result_in_csv: bool = False, + question_request: QuestionRequest = None, + ) -> Response: if os.getenv("DH_ENGINE_TIMEOUT", None): - return self._api.answer_question_with_timeout(question_request) - return self._api.answer_question(question_request) + return self._api.answer_question_with_timeout( + store_substantial_query_result_in_csv, question_request + ) + return self._api.answer_question( + store_substantial_query_result_in_csv, question_request + ) def get_questions(self, db_connection_id: str | None = None) -> list[Question]: return self._api.get_questions(db_connection_id) @@ -282,13 +297,25 @@ def get_response(self, response_id: str) -> Response: """Get a response""" return self._api.get_response(response_id) + def get_response_file( + self, response_id: str, background_tasks: BackgroundTasks + ) -> FileResponse: + """Get a response file""" + return self._api.get_response_file(response_id, background_tasks) + def execute_sql_query(self, query: Query) -> tuple[str, dict]: """Executes a query on the given db_connection_id""" return self._api.execute_sql_query(query) - def create_response(self, query_request: CreateResponseRequest) -> Response: + def create_response( + self, + store_substantial_query_result_in_csv: bool = False, + query_request: CreateResponseRequest = None, + ) -> Response: """Executes a query on the given db_connection_id""" - return self._api.create_response(query_request) + return self._api.create_response( + store_substantial_query_result_in_csv, query_request + ) def delete_golden_record(self, golden_record_id: str) -> dict: """Deletes a golden record""" diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 9d5673c4..dbc77b44 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -39,9 +39,16 @@ def check_for_time_out_or_tool_limit(self, response: dict) -> dict: return response def create_sql_query_status( - self, db: SQLDatabase, query: str, response: Response, top_k: int = None + self, + db: SQLDatabase, + query: str, + response: Response, + top_k: int = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: - return create_sql_query_status(db, query, response, top_k) + return create_sql_query_status( + db, query, response, top_k, store_substantial_query_result_in_csv + ) def format_intermediate_representations( self, intermediate_representation: List[Tuple[AgentAction, str]] @@ -76,6 +83,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: """Generates a response to a user question.""" pass diff --git a/dataherald/sql_generator/create_sql_query_status.py b/dataherald/sql_generator/create_sql_query_status.py index 310158ba..11a8a3df 100644 --- a/dataherald/sql_generator/create_sql_query_status.py +++ b/dataherald/sql_generator/create_sql_query_status.py @@ -1,3 +1,5 @@ +import csv +import uuid from datetime import date, datetime from decimal import Decimal @@ -5,6 +7,10 @@ from dataherald.sql_database.base import SQLDatabase, SQLInjectionError from dataherald.types import Response, SQLQueryResult +from dataherald.utils.s3 import S3 + +MAX_ROWS_TO_CREATE_CSV_FILE = 50 +MAX_CHARACTERS_TO_CREATE_CSV_FILE = 3_000 def format_error_message(response: Response, error_message: str) -> Response: @@ -21,8 +27,35 @@ def format_error_message(response: Response, error_message: str) -> Response: return response +def create_csv_file( + store_substantial_query_result_in_csv: bool, + columns: list, + rows: list, + response: Response, +): + if store_substantial_query_result_in_csv and ( + len(rows) >= MAX_ROWS_TO_CREATE_CSV_FILE + or len(str(rows)) > MAX_CHARACTERS_TO_CREATE_CSV_FILE + ): + file_location = f"tmp/{str(uuid.uuid4())}.csv" + with open(file_location, "w", newline="") as file: + writer = csv.writer(file) + + writer.writerow(rows[0].keys()) + for row in rows: + writer.writerow(row.values()) + s3 = S3() + s3.upload(file_location) + response.csv_file_path = f's3://k2-core/{file_location.split("/")[-1]}' + response.sql_query_result = SQLQueryResult(columns=columns, rows=rows) + + def create_sql_query_status( - db: SQLDatabase, query: str, response: Response, top_k: int = None + db: SQLDatabase, + query: str, + response: Response, + top_k: int = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: """Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message""" if query == "": @@ -60,7 +93,11 @@ def create_sql_query_status( else: modified_row[key] = value rows.append(modified_row) - response.sql_query_result = SQLQueryResult(columns=columns, rows=rows) + + create_csv_file( + store_substantial_query_result_in_csv, columns, rows, response + ) + response.sql_generation_status = "VALID" response.error_message = None except SQLInjectionError as e: diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 70f70dfa..03db4c26 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -601,6 +601,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: start_time = time.time() context_store = self.system.instance(ContextStore) @@ -690,5 +691,9 @@ def generate_response( sql_query=sql_query_list[-1] if len(sql_query_list) > 0 else "", ) return self.create_sql_query_status( - self.database, response.sql_query, response, top_k=TOP_K + self.database, + response.sql_query, + response, + top_k=TOP_K, + store_substantial_query_result_in_csv=store_substantial_query_result_in_csv, ) diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index d412a4d0..26318128 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -31,7 +31,11 @@ def __init__(self, system, storage): self.storage = storage self.model = ChatModel(self.system) - def execute(self, query_response: Response) -> Response: + def execute( + self, + query_response: Response, + store_substantial_query_result_in_csv: bool = False, + ) -> Response: question_repository = QuestionRepository(self.storage) question = question_repository.find_by_id(query_response.question_id) @@ -50,6 +54,7 @@ def execute(self, query_response: Response) -> Response: query_response.sql_query, query_response, top_k=int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50")), + store_substantial_query_result_in_csv=store_substantial_query_result_in_csv, ) system_message_prompt = SystemMessagePromptTemplate.from_template( SYSTEM_TEMPLATE diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py index 99201478..328041e1 100644 --- a/dataherald/sql_generator/langchain_sqlagent.py +++ b/dataherald/sql_generator/langchain_sqlagent.py @@ -29,6 +29,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: # type: ignore logger.info(f"Generating SQL response to question: {str(user_question.dict())}") self.llm = self.model.get_model( @@ -85,4 +86,9 @@ def generate_response( total_cost=cb.total_cost, sql_query=sql_query_list[-1] if len(sql_query_list) > 0 else "", ) - return self.create_sql_query_status(self.database, response.sql_query, response) + return self.create_sql_query_status( + self.database, + response.sql_query, + response, + store_substantial_query_result_in_csv=store_substantial_query_result_in_csv, + ) diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py index 48e0bdad..30c3d2c9 100644 --- a/dataherald/sql_generator/langchain_sqlchain.py +++ b/dataherald/sql_generator/langchain_sqlchain.py @@ -47,6 +47,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: start_time = time.time() self.llm = self.model.get_model( @@ -94,4 +95,9 @@ def generate_response( total_tokens=cb.total_tokens, sql_query=self.format_sql_query(result["intermediate_steps"][1]), ) - return self.create_sql_query_status(self.database, response.sql_query, response) + return self.create_sql_query_status( + self.database, + response.sql_query, + response, + store_substantial_query_result_in_csv=store_substantial_query_result_in_csv, + ) diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py index 59de3c90..224de63d 100644 --- a/dataherald/sql_generator/llamaindex.py +++ b/dataherald/sql_generator/llamaindex.py @@ -35,6 +35,7 @@ def generate_response( user_question: Question, database_connection: DatabaseConnection, context: List[dict] = None, + store_substantial_query_result_in_csv: bool = False, ) -> Response: start_time = time.time() logger.info(f"Generating SQL response to question: {str(user_question.dict())}") @@ -109,4 +110,9 @@ def generate_response( intermediate_steps=[str(result.metadata)], sql_query=self.format_sql_query(result.metadata["sql_query"]), ) - return self.create_sql_query_status(self.database, response.sql_query, response) + return self.create_sql_query_status( + self.database, + response.sql_query, + response, + store_substantial_query_result_in_csv=store_substantial_query_result_in_csv, + ) diff --git a/dataherald/types.py b/dataherald/types.py index c85ef9ad..70524e29 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -75,6 +75,7 @@ class Response(BaseModel): intermediate_steps: list[str] | None = None sql_query: str sql_query_result: SQLQueryResult | None + csv_file_path: str | None sql_generation_status: str = "INVALID" error_message: str | None exec_time: float | None = None diff --git a/dataherald/utils/s3.py b/dataherald/utils/s3.py index e185e482..dd575529 100644 --- a/dataherald/utils/s3.py +++ b/dataherald/utils/s3.py @@ -1,3 +1,5 @@ +import os + import boto3 from cryptography.fernet import InvalidToken @@ -9,6 +11,22 @@ class S3: def __init__(self): self.settings = Settings() + def upload(self, file_location) -> str: + file_name = file_location.split("/")[-1] + bucket_name = "k2-core" + + # Upload the file + s3_client = boto3.client( + "s3", + aws_access_key_id=self.settings.s3_aws_access_key_id, + aws_secret_access_key=self.settings.s3_aws_secret_access_key, + ) + s3_client.upload_file( + file_location, bucket_name, os.path.basename(file_location) + ) + os.remove(file_location) + return f"s3://{bucket_name}/{file_name}" + def download(self, path: str) -> str: fernet_encrypt = FernetEncrypt() path = path.split("/")