Skip to content

Commit

Permalink
[DH-5028] Improve flag to set ONLY_STORE_CSV_FILES_LOCALLY (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 authored Nov 27, 2023
1 parent ee22f5a commit acee6ab
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ MONGODB_DB_PASSWORD = 'admin'
# The encryption key is used to encrypt database connection info before storing in Mongo. Please refer to the README on how to set it.
S3_AWS_ACCESS_KEY_ID=
S3_AWS_SECRET_ACCESS_KEY=
ONLY_STORE_CSV_FILES_LOCALLY = True # Set to True if you only want to save generated CSV files locally instead of S3. Note that if stored locally they should be treated as ephemeral, i.e., they will disappear when the engine is restarted.
ONLY_STORE_CSV_FILES_LOCALLY = False # Set to True if you only want to save generated CSV files locally instead of S3. Note that if stored locally they should be treated as ephemeral, i.e., they will disappear when the engine is restarted.
1 change: 0 additions & 1 deletion dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def get_response_file(
) -> FileResponse:
pass


@abstractmethod
def delete_golden_record(self, golden_record_id: str) -> dict:
pass
Expand Down
16 changes: 11 additions & 5 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from dataherald.api import API
from dataherald.api.types import Query
from dataherald.config import System
from dataherald.config import Settings, System
from dataherald.context_store import ContextStore
from dataherald.db import DB
from dataherald.db_scanner import Scanner
Expand Down Expand Up @@ -52,7 +52,6 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.config import Settings
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -409,18 +408,25 @@ def get_response_file(
raise HTTPException(status_code=400, detail=str(e)) from e

if not result:
raise HTTPException(status_code=404, detail="Question, response, or db_connection not found")
raise HTTPException(
status_code=404, detail="Question, response, or db_connection not found"
)

# Check if the file is to be returned from server (locally) or from S3
if Settings().only_store_csv_files_locally:
file_location = result.csv_file_path
# check if the file exists
if not os.path.exists(file_location):
raise HTTPException(status_code=404, detail="CSV file not found. Possibly deleted/removed from server.")
raise HTTPException(
status_code=404,
detail="CSV file not found. Possibly deleted/removed from server.",
)
else:
s3 = S3()

file_location = s3.download(result.csv_file_path, db_connection.file_storage)
file_location = s3.download(
result.csv_file_path, db_connection.file_storage
)
background_tasks.add_task(delete_file, file_location)

return FileResponse(
Expand Down
4 changes: 3 additions & 1 deletion dataherald/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class Settings(BaseSettings):
encrypt_key: str = os.environ.get("ENCRYPT_KEY")
s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID")
s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
only_store_csv_files_locally: str | None = os.environ.get("ONLY_STORE_CSV_FILES_LOCALLY")
only_store_csv_files_locally: bool | None = os.environ.get(
"ONLY_STORE_CSV_FILES_LOCALLY", False
)

def require(self, key: str) -> Any:
val = self[key]
Expand Down
16 changes: 7 additions & 9 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

from sqlalchemy import text

from dataherald.config import Settings
from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.types import Response, SQLQueryResult
from dataherald.utils.s3 import S3

from dataherald.config import Settings



def format_error_message(response: Response, error_message: str) -> Response:
# Remove the complete query
Expand Down Expand Up @@ -44,12 +42,12 @@ def create_csv_file(
for row in rows:
writer.writerow(row.values())
if Settings().only_store_csv_files_locally:
response.csv_file_path = file_location
else:
s3 = S3()
response.csv_file_path = s3.upload(
file_location, database_connection.file_storage
)
response.csv_file_path = file_location
else:
s3 = S3()
response.csv_file_path = s3.upload(
file_location, database_connection.file_storage
)
response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)


Expand Down

0 comments on commit acee6ab

Please sign in to comment.