Skip to content

Commit

Permalink
Merge branch 'main' into replace-pyaml-with-ruyaml
Browse files Browse the repository at this point in the history
  • Loading branch information
jmetz authored Feb 8, 2024
2 parents 969ed83 + 876bbe4 commit 3046fc2
Show file tree
Hide file tree
Showing 14 changed files with 485 additions and 516 deletions.
99 changes: 44 additions & 55 deletions .github/scripts/s3_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
import io
from pathlib import Path
import json
import os
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Iterator
import json

from minio import Minio # type: ignore
# import requests # type: ignore
from loguru import logger # type: ignore
from minio import Minio # type: ignore


@dataclass
Expand Down Expand Up @@ -42,15 +42,12 @@ def bucket_exists(self, bucket):
return self._client.bucket_exists(bucket)

def put(
self,
path,
file_object,
length=-1,
content_type="application/octet-stream"):
self, path, file_object, length=-1, content_type="application/octet-stream"
):
# For unknown length (ie without reading file into mem) give `part_size`
part_size = 0
if length == -1:
part_size = 10*1024*1024
part_size = 10 * 1024 * 1024
path = f"{self.prefix}/{path}"
self._client.put_object(
self.bucket,
Expand All @@ -62,19 +59,16 @@ def put(
)

def get_file_urls(
self,
path="",
exclude_files=("status.json"),
lifetime=timedelta(hours=1),
) -> list[str]:
self,
path="",
exclude_files=("status.json"),
lifetime=timedelta(hours=1),
) -> list[str]:
"""Checks an S3 'folder' for its list of files"""
logger.debug("Getting file list using {}, at {}", self, path)
path = f"{self.prefix}/{path}"
objects = self._client.list_objects(
self.bucket,
prefix=path,
recursive=True)
file_urls : list[str] = []
objects = self._client.list_objects(self.bucket, prefix=path, recursive=True)
file_urls: list[str] = []
for obj in objects:
if obj.is_dir:
continue
Expand All @@ -92,7 +86,6 @@ def get_file_urls(
# Option 2: Work with minio.datatypes.Object directly
return file_urls


def ls(self, path, only_folders=False, only_files=False) -> Iterator[str]:
"""
List folder contents, non-recursive, ala `ls`
Expand All @@ -101,19 +94,15 @@ def ls(self, path, only_folders=False, only_files=False) -> Iterator[str]:
# path = str(Path(self.prefix, path))
path = f"{self.prefix}/{path}"
logger.debug("Running ls at path: {}", path)
objects = self._client.list_objects(
self.bucket,
prefix=path,
recursive=False)
objects = self._client.list_objects(self.bucket, prefix=path, recursive=False)
for obj in objects:
if only_files and obj.is_dir:
continue
if only_folders and not obj.is_dir:
continue
yield Path(obj.object_name).name


def load_file(self, path):
def load_file(self, path) -> str:
"""Load file from S3"""
path = f"{self.prefix}/{path}"
try:
Expand All @@ -131,31 +120,31 @@ def load_file(self, path):
return content

# url = self.client.get_presigned_url(
# "GET",
# self.bucket,
# str(Path(self.prefix, path)),
# expires=timedelta(minutes=10),
# "GET",
# self.bucket,
# str(Path(self.prefix, path)),
# expires=timedelta(minutes=10),
# )
# response = requests.get(url)
# return response.content

def check_versions(self, model_name: str) -> Iterator[VersionStatus]:
def check_versions(self, resource_path: str) -> Iterator[VersionStatus]:
"""
Check model repository for version of model-name.
Returns dictionary of version-status pairs.
"""
logger.debug("Checking versions for {}", model_name)
version_folders = self.ls(f"{model_name}/", only_folders=True)
logger.debug("Checking versions for {}", resource_path)
version_folders = self.ls(f"{resource_path}/", only_folders=True)

# For each folder get the contents of status.json
for version in version_folders:

yield self.get_version_status(model_name, version)
yield self.get_version_status(resource_path, version)

def get_unpublished_version(self, model_name:str) -> str:
def get_unpublished_version(self, resource_path: str) -> str:
"""Get the unpublisted version"""
versions = list(self.check_versions(model_name))
versions = list(self.check_versions(resource_path))
if len(versions) == 0:
return "1"
unpublished = [version for version in versions if version.status == "staging"]
Expand All @@ -166,49 +155,51 @@ def get_unpublished_version(self, model_name:str) -> str:
raise ValueError("Opps! We seem to have > 1 staging versions!!")
return unpublished[0].version

def get_version_status(self, model_name: str, version: str) -> VersionStatus:
status = self.get_status(model_name, version)
status_str = status.get('status', 'status-field-unset')
version_path = f"{model_name}/{version}"
def get_version_status(self, resource_path: str, version: str) -> VersionStatus:
status = self.get_status(resource_path, version)
status_str = status.get("status", "status-field-unset")
version_path = f"{resource_path}/{version}"
return VersionStatus(version, status_str, version_path)

def get_status(self, model_name: str, version: str) -> dict:
version_path = f"{model_name}/{version}"
logger.debug("model_name: {}, version: {}", model_name, version)
def get_status(self, resource_path: str, version: str) -> dict:
version_path = f"{resource_path}/{version}"
logger.debug("resource_path: {}, version: {}", resource_path, version)
status_path = f"{version_path}/status.json"
logger.debug("Getting status using path {}", status_path)
status = self.load_file(status_path)
status = json.loads(status)
return status

def put_status(self, model_name: str, version: str, status: dict):
logger.debug("Updating status for {}-{}, with {}", model_name, version, status)
def put_status(self, resource_path: str, version: str, status: dict):
logger.debug(
"Updating status for {}-{}, with {}", resource_path, version, status
)
contents = json.dumps(status).encode()
file_object = io.BytesIO(contents)

self.put(
f"{model_name}/{version}/status.json",
f"{resource_path}/{version}/status.json",
file_object,
length=len(contents),
content_type="application/json",
)

def get_log(self, model_name: str, version: str) -> dict:
version_path = f"{model_name}/{version}"
logger.debug("model_name: {}, version: {}", model_name, version)
def get_log(self, resource_path: str, version: str) -> dict:
version_path = f"{resource_path}/{version}"
logger.debug("resource_path: {}, version: {}", resource_path, version)
path = f"{version_path}/log.json"
logger.debug("Getting log using path {}", path)
log = self.load_file(path)
log = json.loads(log)
return log

def put_log(self, model_name: str, version: str, log: dict):
logger.debug("Updating log for {}-{}, with {}", model_name, version, log)
def put_log(self, resource_path: str, version: str, log: dict):
logger.debug("Updating log for {}-{}, with {}", resource_path, version, log)
contents = json.dumps(log).encode()
file_object = io.BytesIO(contents)

self.put(
f"{model_name}/{version}/log.json",
f"{resource_path}/{version}/log.json",
file_object,
length=len(contents),
content_type="application/json",
Expand Down Expand Up @@ -239,5 +230,3 @@ def create_client() -> Client:
secret_key=secret_access_key,
)
return client


34 changes: 18 additions & 16 deletions .github/scripts/unzip_model.py → .github/scripts/unzip_package.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import argparse
import io
import traceback
from typing import Optional
import urllib.request
import zipfile
from typing import Optional


from update_status import update_status
from s3_client import create_client
from update_status import update_status


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("model_name", help="Model name")
parser.add_argument("model_zip_url", help="Model URL (needs to be publicly accessible or presigned)")
parser.add_argument("resource_path", help="Resource ID")
parser.add_argument(
"package_url",
help="Resource package URL (needs to be publicly accessible or presigned)",
)
return parser


Expand All @@ -27,22 +29,22 @@ def get_args(argv: Optional[list] = None):

def main():
args = get_args()
model_name = args.model_name
model_zip_url = args.model_zip_url
resource_path = args.resource_path
package_url = args.package_url
try:
unzip_from_url(model_name, model_zip_url)
unzip_from_url(resource_path, package_url)
except Exception:
err_message = f"An error occurred in the CI:\n {traceback.format_exc()}"
print(err_message)
update_status(model_name, {'status' : err_message})
update_status(resource_path, {"status": err_message})
raise


def unzip_from_url(model_name, model_zip_url):
def unzip_from_url(resource_path, package_url):
filename = "model.zip"
client = create_client()

versions = client.check_versions(model_name)
versions = client.check_versions(resource_path)
if len(versions) == 0:
version = "1"

Expand All @@ -52,22 +54,22 @@ def unzip_from_url(model_name, model_zip_url):
raise NotImplementedError("Updating/publishing new version not implemented")

# TODO: Need to make sure status is staging
status = client.get_status(model_name, version)
status = client.get_status(resource_path, version)
status_str = status.get("status", "missing-status")
if status_str != "staging":
raise ValueError(
"Model {} at version {} is status: {}",
model_name, version, status_str)
"Model {} at version {} is status: {}", resource_path, version, status_str
)

# Download the model zip file
remotezip = urllib.request.urlopen(model_zip_url)
remotezip = urllib.request.urlopen(package_url)
# Unzip the zip file
zipinmemory = io.BytesIO(remotezip.read())
zipobj = zipfile.ZipFile(zipinmemory)
for filename in zipobj.namelist():
# file_object = io.BytesIO(zipobj)
file_object = zipobj.open(filename)
path = f"{model_name}/{version}/{filename}"
path = f"{resource_path}/{version}/files/{filename}"

client.put(
path,
Expand Down
30 changes: 17 additions & 13 deletions .github/scripts/update_log.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import argparse
from typing import Optional
import datetime
from loguru import logger
from typing import Optional

from loguru import logger
from s3_client import create_client


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("model_name", help="Model name")
parser.add_argument("resource_path", help="Model name")
parser.add_argument("category", help="Log category")
parser.add_argument("summary", help="Log summary")
parser.add_argument("--version", help="Version")
Expand All @@ -24,31 +25,34 @@ def get_args(argv: Optional[list] = None):

def main():
args = get_args()
model_name = args.model_name
resource_path = args.resource_path
category = args.category
summary = args.summary
version = args.version
add_log_entry(model_name, category, summary, version=version)
add_log_entry(resource_path, category, summary, version=version)


def add_log_entry(model_name, category, summary, version=None):
def add_log_entry(resource_path, category, summary, version=None):
timenow = datetime.datetime.now().isoformat()
client = create_client()
logger.info("Updating log for {} with category {} and summary",
model_name,
category,
summary)
logger.info(
"Updating log for {} with category {} and summary",
resource_path,
category,
summary,
)

if version is None:
version = client.get_unpublished_version(model_name)
version = client.get_unpublished_version(resource_path)
logger.info("Version detected: {}", version)
else:
logger.info("Version requested: {}", version)
log = client.get_log(model_name, version)
log = client.get_log(resource_path, version)

if category not in log:
log[category] = []
log[category].append({"timestamp": timenow, "log": summary})
client.put_log(model_name, version, log)
client.put_log(resource_path, version, log)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 3046fc2

Please sign in to comment.