Skip to content

Commit

Permalink
db_connections endpoints support new file storages
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Nov 1, 2023
1 parent e18cb7f commit 8c5acd0
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 72 deletions.
12 changes: 5 additions & 7 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List

from fastapi import BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse

from dataherald.api.types import Query
from dataherald.config import Component
Expand Down Expand Up @@ -39,15 +39,15 @@ def scan_db(
@abstractmethod
def answer_question(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass

@abstractmethod
def answer_question_with_timeout(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
pass
Expand Down Expand Up @@ -109,7 +109,7 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
@abstractmethod
def create_response(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
pass
Expand All @@ -123,9 +123,7 @@ def get_response(self, response_id: str) -> Response:
pass

@abstractmethod
def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
def get_response_file(self, response_id: str) -> JSONResponse:
pass

@abstractmethod
Expand Down
42 changes: 18 additions & 24 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bson import json_util
from bson.objectid import InvalidId, ObjectId
from fastapi import BackgroundTasks, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.responses import JSONResponse
from overrides import override

from dataherald.api import API
Expand Down Expand Up @@ -64,10 +64,6 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_removing_file(file_path: str):
os.remove(file_path)


class FastAPI(API):
def __init__(self, system: System):
super().__init__(system)
Expand Down Expand Up @@ -125,7 +121,7 @@ def scan_db(
@override
def answer_question(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
"""Takes in an English question and answers it based on content from the registered databases"""
Expand Down Expand Up @@ -161,7 +157,7 @@ def answer_question(
user_question,
database_connection,
context[0],
store_substantial_query_result_in_csv,
large_query_result_in_csv,
)
logger.info("Starts evaluator...")
confidence_score = evaluator.get_confidence_score(
Expand All @@ -172,6 +168,8 @@ def answer_question(
status_code=400,
content={"question_id": user_question.id, "error_message": str(e)},
)
if generated_answer.csv_download_url:
generated_answer.sql_query_result = None
generated_answer.confidence_score = confidence_score
generated_answer.exec_time = time.time() - start_generated_answer
response_repository = ResponseRepository(self.storage)
Expand All @@ -180,7 +178,7 @@ def answer_question(
@override
def answer_question_with_timeout(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
result = None
Expand All @@ -197,7 +195,7 @@ def run_and_catch_exceptions():
nonlocal result, exception
if not stop_event.is_set():
result = self.answer_question(
store_substantial_query_result_in_csv, question_request
large_query_result_in_csv, question_request
)

thread = threading.Thread(target=run_and_catch_exceptions)
Expand Down Expand Up @@ -226,6 +224,7 @@ def create_database_connection(
llm_api_key=database_connection_request.llm_api_key,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
file_storage=database_connection_request.file_storage,
)

SQLDatabase.get_sql_engine(db_connection, True)
Expand Down Expand Up @@ -260,6 +259,7 @@ def update_database_connection(
llm_api_key=database_connection_request.llm_api_key,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
file_storage=database_connection_request.file_storage,
)

SQLDatabase.get_sql_engine(db_connection, True)
Expand Down Expand Up @@ -365,9 +365,7 @@ def get_response(self, response_id: str) -> Response:
return result

@override
def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
def get_response_file(self, response_id: str) -> JSONResponse:
response_repository = ResponseRepository(self.storage)

try:
Expand All @@ -378,15 +376,11 @@ def get_response_file(
if not result:
raise HTTPException(status_code=404, detail="Question not found")

# todo download
s3 = S3()
file_path = s3.download(result.csv_file_path)
background_tasks.add_task(async_removing_file, file_path)
return FileResponse(
file_path,
media_type="text/csv",
headers={
"Content-Disposition": f"attachment; filename={file_path.split('/')[-1]}"
return JSONResponse(
status_code=201,
content={
"csv_download_url": s3.download_url(result.csv_file_path),
},
)

Expand Down Expand Up @@ -440,7 +434,7 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
@override
def create_response(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
query_request: CreateResponseRequest = None, # noqa: ARG002
) -> Response:
evaluator = self.system.instance(Evaluator)
Expand All @@ -461,13 +455,13 @@ def create_response(
start_generated_answer = time.time()
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(
response, store_substantial_query_result_in_csv
)
response = generates_nl_answer.execute(response, large_query_result_in_csv)
confidence_score = evaluator.get_confidence_score(
user_question, response, database_connection
)
response.confidence_score = confidence_score
if response.csv_download_url:
response.sql_query_result = None
response.exec_time = time.time() - start_generated_answer
response_repository.update(response)
except ValueError as e:
Expand Down
24 changes: 9 additions & 15 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def __init__(self, settings: Settings):
)

self.router.add_api_route(
"/api/v1/responses/{response_id}/file",
"/api/v1/responses/{response_id}/generate-csv-download-url",
self.get_response_file,
methods=["GET"],
methods=["POST"],
tags=["Responses"],
)

Expand Down Expand Up @@ -225,16 +225,14 @@ def scan_db(

def answer_question(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
question_request: QuestionRequest = None,
) -> Response:
if os.getenv("DH_ENGINE_TIMEOUT", None):
return self._api.answer_question_with_timeout(
store_substantial_query_result_in_csv, question_request
large_query_result_in_csv, question_request
)
return self._api.answer_question(
store_substantial_query_result_in_csv, question_request
)
return self._api.answer_question(large_query_result_in_csv, question_request)

def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
return self._api.get_questions(db_connection_id)
Expand Down Expand Up @@ -297,25 +295,21 @@ def get_response(self, response_id: str) -> Response:
"""Get a response"""
return self._api.get_response(response_id)

def get_response_file(
self, response_id: str, background_tasks: BackgroundTasks
) -> FileResponse:
def get_response_file(self, response_id: str) -> JSONResponse:
"""Get a response file"""
return self._api.get_response_file(response_id, background_tasks)
return self._api.get_response_file(response_id)

def execute_sql_query(self, query: Query) -> tuple[str, dict]:
"""Executes a query on the given db_connection_id"""
return self._api.execute_sql_query(query)

def create_response(
self,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
query_request: CreateResponseRequest = None,
) -> Response:
"""Executes a query on the given db_connection_id"""
return self._api.create_response(
store_substantial_query_result_in_csv, query_request
)
return self._api.create_response(large_query_result_in_csv, query_request)

def delete_golden_record(self, golden_record_id: str) -> dict:
"""Deletes a golden record"""
Expand Down
24 changes: 24 additions & 0 deletions dataherald/sql_database/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ def __getitem__(self, key: str) -> Any:
return getattr(self, key)


class FileStorage(BaseModel):
name: str
access_key_id: str
secret_access_key: str
region: str | None
bucket: str

class Config:
extra = Extra.ignore

@validator("access_key_id", "secret_access_key", pre=True, always=True)
def encrypt(cls, value: str):
fernet_encrypt = FernetEncrypt()
try:
fernet_encrypt.decrypt(value)
return value
except Exception:
return fernet_encrypt.encrypt(value)

def __getitem__(self, key: str) -> Any:
return getattr(self, key)


class SSHSettings(BaseSettings):
db_name: str | None
host: str | None
Expand Down Expand Up @@ -59,6 +82,7 @@ class DatabaseConnection(BaseModel):
path_to_credentials_file: str | None
llm_api_key: str | None = None
ssh_settings: SSHSettings | None = None
file_storage: FileStorage | None = None

@validator("uri", pre=True, always=True)
def set_uri_without_ssh(cls, v, values):
Expand Down
12 changes: 9 additions & 3 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ def create_sql_query_status(
query: str,
response: Response,
top_k: int = None,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
database_connection: DatabaseConnection | None = None,
) -> Response:
return create_sql_query_status(
db, query, response, top_k, store_substantial_query_result_in_csv
db,
query,
response,
top_k,
large_query_result_in_csv,
database_connection=database_connection,
)

def format_intermediate_representations(
Expand Down Expand Up @@ -83,7 +89,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
) -> Response:
"""Generates a response to a user question."""
pass
19 changes: 14 additions & 5 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy import text

from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.types import Response, SQLQueryResult
from dataherald.utils.s3 import S3

Expand All @@ -28,12 +29,13 @@ def format_error_message(response: Response, error_message: str) -> Response:


def create_csv_file(
store_substantial_query_result_in_csv: bool,
large_query_result_in_csv: bool,
columns: list,
rows: list,
response: Response,
database_connection: DatabaseConnection | None = None,
):
if store_substantial_query_result_in_csv and (
if large_query_result_in_csv and (
len(rows) >= MAX_ROWS_TO_CREATE_CSV_FILE
or len(str(rows)) > MAX_CHARACTERS_TO_CREATE_CSV_FILE
):
Expand All @@ -45,7 +47,9 @@ def create_csv_file(
for row in rows:
writer.writerow(row.values())
s3 = S3()
s3.upload(file_location)
response.csv_download_url = s3.upload(
file_location, database_connection.file_storage
)
response.csv_file_path = f's3://k2-core/{file_location.split("/")[-1]}'
response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)

Expand All @@ -55,7 +59,8 @@ def create_sql_query_status(
query: str,
response: Response,
top_k: int = None,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
database_connection: DatabaseConnection | None = None,
) -> Response:
"""Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message"""
if query == "":
Expand Down Expand Up @@ -95,7 +100,11 @@ def create_sql_query_status(
rows.append(modified_row)

create_csv_file(
store_substantial_query_result_in_csv, columns, rows, response
large_query_result_in_csv,
columns,
rows,
response,
database_connection,
)

response.sql_generation_status = "VALID"
Expand Down
5 changes: 3 additions & 2 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def generate_response(
user_question: Question,
database_connection: DatabaseConnection,
context: List[dict] = None,
store_substantial_query_result_in_csv: bool = False,
large_query_result_in_csv: bool = False,
) -> Response:
start_time = time.time()
context_store = self.system.instance(ContextStore)
Expand Down Expand Up @@ -685,5 +685,6 @@ def generate_response(
response.sql_query,
response,
top_k=TOP_K,
store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
large_query_result_in_csv=large_query_result_in_csv,
database_connection=database_connection,
)
Loading

0 comments on commit 8c5acd0

Please sign in to comment.