diff --git a/.gitignore b/.gitignore
index 1615be25..16119bfe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,5 @@ config/authorized_keys
 config/rclone
 tpdocs/
 .env
+.venv
+
diff --git a/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requestmodels/models.py b/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requestmodels/models.py
index bd95a7a8..23653779 100644
--- a/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requestmodels/models.py
+++ b/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requestmodels/models.py
@@ -1,47 +1,48 @@
-from typing import List, Union, Dict, Annotated
+from typing import Dict
 from pydantic import BaseModel, Field
 import os
 import json
 
+if os.environ.get("GCP_CREDENTIALS"):
+    with open(os.environ["GCP_CREDENTIALS"]) as f:
+        _GCP_CREDENTIALS = json.load(f)
+else:
+    _GCP_CREDENTIALS = {}
+
 class S3Config(BaseModel):
-    access_key_id: str = Field(default="")
-    secret_access_key: str = Field(default="")
-    endpoint_url: str = Field(default="")
-    bucket_name: str = Field(default="")
+    access_key_id: str = Field(default=os.environ.get("S3_ACCESS_KEY_ID", ""))
+    secret_access_key: str = Field(
+        default=os.environ.get("S3_SECRET_ACCESS_KEY", ""))
+    endpoint_url: str = Field(default=os.environ.get("S3_ENDPOINT_URL", ""))
+    bucket_name: str = Field(default=os.environ.get("S3_BUCKET_NAME", ""))
     connect_timeout: int = Field(default=5)
     connect_attempts: int = Field(default=1)
 
-    @staticmethod
-    def get_defaults():
-        return {
-            "access_key_id": "",
-            "secret_access_key": "",
-            "endpoint_url": "",
-            "bucket_name": "",
-            "connect_timeout": "5",
-            "connect_attempts": "1"
-        }
+    def get_config(self):
+        config = {"access_key_id": self.access_key_id,
+                  "secret_access_key": self.secret_access_key,
+                  "endpoint_url": self.endpoint_url,
+                  "bucket_name": self.bucket_name,
+                  "connect_timeout": self.connect_timeout,
+                  "connect_attempts": self.connect_attempts}
+        set_values = sum(1 for v in config.values() if v)
+        return config if set_values > 2 else {}
+
+class GcpConfig(BaseModel):
+    credentials: Dict = Field(default_factory=_GCP_CREDENTIALS.copy)
+    project_id: str = Field(default=os.environ.get("GCP_PROJECT_ID", ""))
+    bucket_name: str = Field(default=os.environ.get("GCP_BUCKET_NAME", ""))
 
     def get_config(self):
-        return {
-            "access_key_id": getattr(self, "access_key_id", os.environ.get("S3_ACCESS_KEY_ID", "")),
-            "secret_access_key": getattr(self, "secret_access_key", os.environ.get("S3_SECRET_ACCESS_KEY", "")),
-            "endpoint_url": getattr(self, "endpoint_url", os.environ.get("S3_ENDPOINT_URL", "")),
-            "bucket_name": getattr(self, "bucket_name", os.environ.get("S3_BUCKET_NAME", "")),
-            "connect_timeout": "5",
-            "connect_attempts": "1"
-        }
+        config = {"credentials": self.credentials,
+                  "project_id": self.project_id,
+                  "bucket_name": self.bucket_name}
+        set_values = sum(1 for v in config.values() if v)
+        return config if set_values > 0 else {}
 
 class WebHook(BaseModel):
     url: str = Field(default="")
-    extra_params: Dict = Field(default={})
-
-    @staticmethod
-    def get_defaults():
-        return {
-            "url": "",
-            "extra_params": {}
-        }
+    extra_params: Dict = Field(default_factory=dict)
 
     def has_valid_url(self):
         return network.is_url(self.url)
@@ -49,10 +50,11 @@ def has_valid_url(self):
 class Input(BaseModel):
     request_id: str = Field(default="")
     modifier: str = Field(default="")
-    modifications: Dict = Field(default={})
-    workflow_json: Dict = Field(default={})
-    s3: S3Config = Field(default=S3Config.get_defaults())
-    webhook: WebHook = Field(default=WebHook.get_defaults())
+    modifications: Dict = Field(default_factory=dict)
+    workflow_json: Dict = Field(default_factory=dict)
+    s3: S3Config = Field(default_factory=S3Config)
+    gcp: GcpConfig = Field(default_factory=GcpConfig)
+    webhook: WebHook = Field(default_factory=WebHook)
 
 class Payload(BaseModel):
     input: Input
diff --git a/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requirements.txt b/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requirements.txt
index 3dc0201c..b8ff2a45 100644
--- a/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requirements.txt
+++ b/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/requirements.txt
@@ -2,8 +2,11 @@ aiocache
 pydantic>=2
 aiobotocore
 aiofiles
+aiogoogle
 aiohttp
 fastapi==0.103
+google-auth
+google-cloud-storage
 pathlib
 python-magic
 uvicorn==0.23
diff --git a/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/workers/postprocess_worker.py b/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/workers/postprocess_worker.py
index a2e38b9c..389055da 100644
--- a/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/workers/postprocess_worker.py
+++ b/build/COPY_ROOT_1/opt/ai-dock/api-wrapper/workers/postprocess_worker.py
@@ -1,7 +1,13 @@
+import datetime
+import aiogoogle.auth.creds
+import aiogoogle.client
 import asyncio
+import itertools
 import aiobotocore.session
 import aiofiles
 import aiofiles.os
+from google.oauth2 import service_account
+from google.cloud.storage import _signing as signing
 
 from config import config
 from pathlib import Path
@@ -33,7 +39,31 @@ async def work(self):
             result = await self.response_store.get(request_id)
 
             await self.move_assets(request_id, result)
-            await self.upload_assets(request_id, request.input.s3.get_config(), result)
+
+            named_upload_tasks = []
+            if (s3_config := request.input.s3.get_config()):
+                async def upload_s3_assets():
+                    return ("s3", await self.upload_s3_assets(request_id, s3_config, result))
+                named_upload_tasks.append(
+                    asyncio.create_task(upload_s3_assets()))
+            if (gcp_config := request.input.gcp.get_config()):
+                async def upload_gcp_assets():
+                    return ("gcp", await self.upload_gcp_assets(request_id, gcp_config, result))
+                named_upload_tasks.append(
+                    asyncio.create_task(upload_gcp_assets()))
+            if named_upload_tasks:
+                named_presigned_urls = dict(await asyncio.gather(*named_upload_tasks))
+                presigned_urls = itertools.zip_longest(
+                    named_presigned_urls.get("s3", []),
+                    named_presigned_urls.get("gcp", []),
+                    fillvalue=None)
+                for obj, (s3_url, gcp_url) in zip(result.output, presigned_urls):
+                    if s3_url:
+                        # Keeping for backward compatibility
+                        obj["url"] = s3_url
+                        obj["s3_url"] = s3_url
+                    if gcp_url:
+                        obj["gcp_url"] = gcp_url
 
             result.status = "success"
             result.message = "Process complete."
@@ -77,7 +107,7 @@ async def move_assets(self, request_id, result):
                 "local_path": new_path
             })
 
-    async def upload_assets(self, request_id, s3_config, result):
+    async def upload_s3_assets(self, request_id, s3_config, result):
         session = aiobotocore.session.get_session()
         async with session.create_client(
             's3',
@@ -96,16 +126,12 @@ async def upload_assets(self, request_id, s3_config, result):
             tasks.append(task)
 
             # Run all tasks concurrently
-            presigned_urls = await asyncio.gather(*tasks)
-
-            # Append the presigned URLs to the respective objects
-            for obj, url in zip(result.output, presigned_urls):
-                obj["url"] = url
+            return await asyncio.gather(*tasks)
 
-    async def upload_file_and_get_url(self, requst_id, s3_client, bucket_name, local_path):
+    async def upload_file_and_get_url(self, request_id, s3_client, bucket_name, local_path):
         # Get the file name from the local path
-        file_name = f"{requst_id}/{Path(local_path).name}"
-        print (f"uploading {file_name}")
+        file_name = f"{request_id}/{Path(local_path).name}"
+        print(f"uploading to s3 {file_name}")
 
         try:
             # Upload the file
@@ -116,9 +142,51 @@ async def upload_file_and_get_url(self, requst_id, s3_client, bucket_name, local
             presigned_url = await s3_client.generate_presigned_url(
                 'get_object',
                 Params={'Bucket': bucket_name, 'Key': file_name},
-                ExpiresIn=604800 # URL expiration time in seconds
+                ExpiresIn=int(datetime.timedelta(days=7).total_seconds()),
             )
             return presigned_url
         except Exception as e:
-            print(f"Error uploading {local_path}: {e}")
-            return None
\ No newline at end of file
+            print(f"Error uploading to s3 {local_path}: {e}")
+            return None
+
+    async def upload_gcp_assets(self, request_id, gcp_config, result):
+        creds = aiogoogle.auth.creds.ServiceAccountCreds(
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            **gcp_config["credentials"],
+        )
+        google_credentials = service_account.Credentials.from_service_account_info(
+            gcp_config["credentials"])
+        aiog_client = aiogoogle.client.Aiogoogle(service_account_creds=creds)
+        async with aiog_client:
+            # Not needed as we are using provided service account creds. Uncomment if using discovery.
+            # await aiog_client.service_account_manager.detect_default_creds_source()
+            storage = await aiog_client.discover("storage", "v1")
+            tasks = []
+            for obj in result.output:
+                local_path = obj["local_path"]
+                task = asyncio.create_task(self.upload_file_to_gcp_and_get_url(
+                    request_id, aiog_client, storage, gcp_config["bucket_name"], local_path, google_credentials))
+                tasks.append(task)
+
+            # Run all tasks concurrently
+            return await asyncio.gather(*tasks)
+
+    async def upload_file_to_gcp_and_get_url(self, request_id, aiog_client, storage, bucket_name, local_path, google_credentials):
+        destination_path = f"{request_id}/{Path(local_path).name}"
+        print(f"uploading to gcp {destination_path}")
+
+        try:
+            await aiog_client.as_service_account(storage.objects.insert(
+                bucket=bucket_name,
+                name=destination_path,
+                upload_file=local_path,
+            ), full_res=True)
+            return signing.generate_signed_url_v4(
+                google_credentials,
+                f"/{bucket_name}/{destination_path}",
+                expiration=datetime.timedelta(days=7),
+                method="GET",
+            )
+        except Exception as e:
+            print(f"Error uploading to gcp {local_path}: {e}")
+            return None
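
Reviewer note, not part of the diff: a minimal usage sketch of the new request models. It assumes the module is importable as requestmodels.models (mirroring the file path above) and that none of the S3_*/GCP_* environment variables are set; the request IDs, project, bucket, and credential values are illustrative placeholders only. It shows why S3Config.get_config() needs more than two truthy values (connect_timeout and connect_attempts are always set) while GcpConfig.get_config() needs only one.

# Illustrative sketch; module path and values are assumptions, not part of the PR.
from requestmodels.models import GcpConfig, Input

# With no fields supplied and no S3_*/GCP_* environment variables, both configs
# collapse to {}: S3Config always has connect_timeout/connect_attempts set
# (2 truthy values, not > 2), and GcpConfig has none.
bare = Input(request_id="req-1")
assert bare.s3.get_config() == {}
assert bare.gcp.get_config() == {}

# Supplying GCP details in the request yields a usable config dict, so the
# postprocess worker would schedule a GCP upload for this request.
with_gcp = Input(
    request_id="req-2",
    gcp=GcpConfig(
        credentials={"type": "service_account"},  # placeholder, not real creds
        project_id="example-project",
        bucket_name="example-bucket",
    ),
)
assert with_gcp.gcp.get_config()["bucket_name"] == "example-bucket"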
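
A second small sketch (also not part of the diff) of the pairing logic in work(): each configured backend returns one presigned URL per output object, keyed by backend name, and itertools.zip_longest lines the lists up so every object can receive its s3_url/gcp_url, with None filled in when a backend was skipped. The URLs below are made up.

import itertools

# Assumed shape: one URL per output object, per configured backend.
named_presigned_urls = {
    "s3": ["https://s3.example.com/req-2/a.png", "https://s3.example.com/req-2/b.png"],
    # "gcp" absent here, e.g. because request.input.gcp.get_config() returned {}
}
pairs = itertools.zip_longest(
    named_presigned_urls.get("s3", []),
    named_presigned_urls.get("gcp", []),
    fillvalue=None)
for s3_url, gcp_url in pairs:
    print(s3_url, gcp_url)  # gcp_url is None when the GCP upload was skipped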