Skip to content

Commit

Permalink
model_nickname -> resource_id
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Feb 7, 2024
1 parent d86412c commit 6fec3b5
Show file tree
Hide file tree
Showing 14 changed files with 478 additions and 513 deletions.
87 changes: 37 additions & 50 deletions .github/scripts/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,12 @@ def bucket_exists(self, bucket):
return self._client.bucket_exists(bucket)

def put(
self,
path,
file_object,
length=-1,
content_type="application/octet-stream"):
self, path, file_object, length=-1, content_type="application/octet-stream"
):
# For unknown length (ie without reading file into mem) give `part_size`
part_size = 0
if length == -1:
part_size = 10*1024*1024
part_size = 10 * 1024 * 1024
path = f"{self.prefix}/{path}"
self._client.put_object(
self.bucket,
Expand All @@ -62,19 +59,16 @@ def put(
)

def get_file_urls(
self,
path="",
exclude_files=("status.json"),
lifetime=timedelta(hours=1),
) -> list[str]:
self,
path="",
exclude_files=("status.json"),
lifetime=timedelta(hours=1),
) -> list[str]:
"""Checks an S3 'folder' for its list of files"""
logger.debug("Getting file list using {}, at {}", self, path)
path = f"{self.prefix}/{path}"
objects = self._client.list_objects(
self.bucket,
prefix=path,
recursive=True)
file_urls : list[str] = []
objects = self._client.list_objects(self.bucket, prefix=path, recursive=True)
file_urls: list[str] = []
for obj in objects:
if obj.is_dir:
continue
Expand All @@ -92,7 +86,6 @@ def get_file_urls(
# Option 2: Work with minio.datatypes.Object directly
return file_urls


def ls(self, path, only_folders=False, only_files=False) -> Iterator[str]:
"""
List folder contents, non-recursive, ala `ls`
Expand All @@ -101,18 +94,14 @@ def ls(self, path, only_folders=False, only_files=False) -> Iterator[str]:
# path = str(Path(self.prefix, path))
path = f"{self.prefix}/{path}"
logger.debug("Running ls at path: {}", path)
objects = self._client.list_objects(
self.bucket,
prefix=path,
recursive=False)
objects = self._client.list_objects(self.bucket, prefix=path, recursive=False)
for obj in objects:
if only_files and obj.is_dir:
continue
if only_folders and not obj.is_dir:
continue
yield Path(obj.object_name).name


def load_file(self, path) -> str:
"""Load file from S3"""
path = f"{self.prefix}/{path}"
Expand All @@ -131,31 +120,31 @@ def load_file(self, path) -> str:
return content

# url = self.client.get_presigned_url(
# "GET",
# self.bucket,
# str(Path(self.prefix, path)),
# expires=timedelta(minutes=10),
# "GET",
# self.bucket,
# str(Path(self.prefix, path)),
# expires=timedelta(minutes=10),
# )
# response = requests.get(url)
# return response.content

def check_versions(self, model_name: str) -> Iterator[VersionStatus]:
def check_versions(self, resource_id: str) -> Iterator[VersionStatus]:
"""
Check model repository for version of model-name.
Returns dictionary of version-status pairs.
"""
logger.debug("Checking versions for {}", model_name)
version_folders = self.ls(f"{model_name}/", only_folders=True)
logger.debug("Checking versions for {}", resource_id)
version_folders = self.ls(f"{resource_id}/", only_folders=True)

# For each folder get the contents of status.json
for version in version_folders:

yield self.get_version_status(model_name, version)
yield self.get_version_status(resource_id, version)

def get_unpublished_version(self, model_name:str) -> str:
def get_unpublished_version(self, resource_id: str) -> str:
"""Get the unpublisted version"""
versions = list(self.check_versions(model_name))
versions = list(self.check_versions(resource_id))
if len(versions) == 0:
return "1"
unpublished = [version for version in versions if version.status == "staging"]
Expand All @@ -166,49 +155,49 @@ def get_unpublished_version(self, model_name:str) -> str:
raise ValueError("Opps! We seem to have > 1 staging versions!!")
return unpublished[0].version

def get_version_status(self, model_name: str, version: str) -> VersionStatus:
status = self.get_status(model_name, version)
status_str = status.get('status', 'status-field-unset')
version_path = f"{model_name}/{version}"
def get_version_status(self, resource_id: str, version: str) -> VersionStatus:
status = self.get_status(resource_id, version)
status_str = status.get("status", "status-field-unset")
version_path = f"{resource_id}/{version}"
return VersionStatus(version, status_str, version_path)

def get_status(self, model_name: str, version: str) -> dict:
version_path = f"{model_name}/{version}"
logger.debug("model_name: {}, version: {}", model_name, version)
def get_status(self, resource_id: str, version: str) -> dict:
version_path = f"{resource_id}/{version}"
logger.debug("resource_id: {}, version: {}", resource_id, version)
status_path = f"{version_path}/status.json"
logger.debug("Getting status using path {}", status_path)
status = self.load_file(status_path)
status = json.loads(status)
return status

def put_status(self, model_name: str, version: str, status: dict):
logger.debug("Updating status for {}-{}, with {}", model_name, version, status)
def put_status(self, resource_id: str, version: str, status: dict):
logger.debug("Updating status for {}-{}, with {}", resource_id, version, status)
contents = json.dumps(status).encode()
file_object = io.BytesIO(contents)

self.put(
f"{model_name}/{version}/status.json",
f"{resource_id}/{version}/status.json",
file_object,
length=len(contents),
content_type="application/json",
)

def get_log(self, model_name: str, version: str) -> dict:
version_path = f"{model_name}/{version}"
logger.debug("model_name: {}, version: {}", model_name, version)
def get_log(self, resource_id: str, version: str) -> dict:
version_path = f"{resource_id}/{version}"
logger.debug("resource_id: {}, version: {}", resource_id, version)
path = f"{version_path}/log.json"
logger.debug("Getting log using path {}", path)
log = self.load_file(path)
log = json.loads(log)
return log

def put_log(self, model_name: str, version: str, log: dict):
logger.debug("Updating log for {}-{}, with {}", model_name, version, log)
def put_log(self, resource_id: str, version: str, log: dict):
logger.debug("Updating log for {}-{}, with {}", resource_id, version, log)
contents = json.dumps(log).encode()
file_object = io.BytesIO(contents)

self.put(
f"{model_name}/{version}/log.json",
f"{resource_id}/{version}/log.json",
file_object,
length=len(contents),
content_type="application/json",
Expand Down Expand Up @@ -239,5 +228,3 @@ def create_client() -> Client:
secret_key=secret_access_key,
)
return client


34 changes: 18 additions & 16 deletions .github/scripts/unzip_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import argparse
import io
import traceback
from typing import Optional
import urllib.request
import zipfile
from typing import Optional


from update_status import update_status
from s3_client import create_client
from update_status import update_status


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("model_name", help="Model name")
parser.add_argument("model_zip_url", help="Model URL (needs to be publicly accessible or presigned)")
parser.add_argument("resource_id", help="Resource ID")
parser.add_argument(
"package_url",
help="Resource package URL (needs to be publicly accessible or presigned)",
)
return parser


Expand All @@ -27,22 +29,22 @@ def get_args(argv: Optional[list] = None):

def main():
args = get_args()
model_name = args.model_name
model_zip_url = args.model_zip_url
resource_id = args.resource_id
package_url = args.package_url
try:
unzip_from_url(model_name, model_zip_url)
unzip_from_url(resource_id, package_url)
except Exception:
err_message = f"An error occurred in the CI:\n {traceback.format_exc()}"
print(err_message)
update_status(model_name, {'status' : err_message})
update_status(resource_id, {"status": err_message})
raise


def unzip_from_url(model_name, model_zip_url):
def unzip_from_url(resource_id, package_url):
filename = "model.zip"
client = create_client()

versions = client.check_versions(model_name)
versions = client.check_versions(resource_id)
if len(versions) == 0:
version = "1"

Expand All @@ -52,22 +54,22 @@ def unzip_from_url(model_name, model_zip_url):
raise NotImplementedError("Updating/publishing new version not implemented")

# TODO: Need to make sure status is staging
status = client.get_status(model_name, version)
status = client.get_status(resource_id, version)
status_str = status.get("status", "missing-status")
if status_str != "staging":
raise ValueError(
"Model {} at version {} is status: {}",
model_name, version, status_str)
"Model {} at version {} is status: {}", resource_id, version, status_str
)

# Download the model zip file
remotezip = urllib.request.urlopen(model_zip_url)
remotezip = urllib.request.urlopen(package_url)
# Unzip the zip file
zipinmemory = io.BytesIO(remotezip.read())
zipobj = zipfile.ZipFile(zipinmemory)
for filename in zipobj.namelist():
# file_object = io.BytesIO(zipobj)
file_object = zipobj.open(filename)
path = f"{model_name}/{version}/{filename}"
path = f"{resource_id}/{version}/{filename}"

client.put(
path,
Expand Down
30 changes: 17 additions & 13 deletions .github/scripts/update_log.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import argparse
from typing import Optional
import datetime
from loguru import logger
from typing import Optional

from loguru import logger
from s3_client import create_client


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("model_name", help="Model name")
parser.add_argument("resource_id", help="Model name")
parser.add_argument("category", help="Log category")
parser.add_argument("summary", help="Log summary")
parser.add_argument("--version", help="Version")
Expand All @@ -24,31 +25,34 @@ def get_args(argv: Optional[list] = None):

def main():
args = get_args()
model_name = args.model_name
resource_id = args.resource_id
category = args.category
summary = args.summary
version = args.version
add_log_entry(model_name, category, summary, version=version)
add_log_entry(resource_id, category, summary, version=version)


def add_log_entry(model_name, category, summary, version=None):
def add_log_entry(resource_id, category, summary, version=None):
timenow = datetime.datetime.now().isoformat()
client = create_client()
logger.info("Updating log for {} with category {} and summary",
model_name,
category,
summary)
logger.info(
"Updating log for {} with category {} and summary",
resource_id,
category,
summary,
)

if version is None:
version = client.get_unpublished_version(model_name)
version = client.get_unpublished_version(resource_id)
logger.info("Version detected: {}", version)
else:
logger.info("Version requested: {}", version)
log = client.get_log(model_name, version)
log = client.get_log(resource_id, version)

if category not in log:
log[category] = []
log[category].append({"timestamp": timenow, "log": summary})
client.put_log(model_name, version, log)
client.put_log(resource_id, version, log)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 6fec3b5

Please sign in to comment.