From 78c5286b7906d210662da918f1fd0ae0b7e3625e Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Fri, 3 Nov 2023 13:01:06 -0400 Subject: [PATCH] feat: adding aws_sagemaker_scheduler (#785) --- dev-requirements.txt | 1 + docs/source/index.rst | 1 + docs/source/schedulers/aws_sagemaker.rst | 19 + torchx/schedulers/__init__.py | 1 + torchx/schedulers/aws_sagemaker_scheduler.py | 591 ++++++++++++++++++ .../test/aws_sagemaker_scheduler_test.py | 386 ++++++++++++ 6 files changed, 999 insertions(+) create mode 100644 docs/source/schedulers/aws_sagemaker.rst create mode 100644 torchx/schedulers/aws_sagemaker_scheduler.py create mode 100644 torchx/schedulers/test/aws_sagemaker_scheduler_test.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 1b0ae50a3..1e8538958 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,6 +2,7 @@ aiobotocore ax-platform[mysql]==0.2.3 black==23.3.0 boto3 +sagemaker>=2.149.0 captum>=0.4.0 docker flake8==3.9.0 diff --git a/docs/source/index.rst b/docs/source/index.rst index bf8370bce..0a560530f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -77,6 +77,7 @@ Works With schedulers/slurm schedulers/ray schedulers/aws_batch + schedulers/aws_sagemaker schedulers/lsf schedulers/gcp_batch diff --git a/docs/source/schedulers/aws_sagemaker.rst b/docs/source/schedulers/aws_sagemaker.rst new file mode 100644 index 000000000..9fc0f4ec2 --- /dev/null +++ b/docs/source/schedulers/aws_sagemaker.rst @@ -0,0 +1,19 @@ +AWS SageMaker +================= + +.. automodule:: torchx.schedulers.aws_sagemaker_scheduler + +.. currentmodule:: torchx.schedulers.aws_sagemaker_scheduler + +.. autoclass:: AWSSageMakerScheduler + :members: + :show-inheritance: + +.. autoclass:: AWSSageMakerJob + :members: + +Reference +~~~~~~~~~~~~ + +.. autofunction:: create_scheduler + diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index a9200f388..05b29411c 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -19,6 +19,7 @@ "kubernetes": "torchx.schedulers.kubernetes_scheduler", "kubernetes_mcad": "torchx.schedulers.kubernetes_mcad_scheduler", "aws_batch": "torchx.schedulers.aws_batch_scheduler", + "aws_sagemaker": "torchx.schedulers.aws_sagemaker_scheduler", "gcp_batch": "torchx.schedulers.gcp_batch_scheduler", "ray": "torchx.schedulers.ray_scheduler", "lsf": "torchx.schedulers.lsf_scheduler", diff --git a/torchx/schedulers/aws_sagemaker_scheduler.py b/torchx/schedulers/aws_sagemaker_scheduler.py new file mode 100644 index 000000000..cb27fad9f --- /dev/null +++ b/torchx/schedulers/aws_sagemaker_scheduler.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python3 + +import getpass +import os +import re +import threading +from collections import OrderedDict as OrdDict +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Mapping, + Optional, + OrderedDict, + Tuple, + TYPE_CHECKING, + TypeVar, +) + +import boto3 +import yaml + +# pyre-fixme[21] +from sagemaker.pytorch import PyTorch +from torchx.components.structured_arg import StructuredNameArgument +from torchx.schedulers.api import ( + AppDryRunInfo, + DescribeAppResponse, + ListAppResponse, + Scheduler, + Stream, +) +from torchx.schedulers.ids import make_unique +from torchx.specs.api import AppDef, AppState, CfgVal, runopts +from torchx.workspace.docker_workspace import DockerWorkspaceMixin +from typing_extensions import TypedDict + + +if TYPE_CHECKING: + from docker import DockerClient # pragma: no cover + +JOB_STATE: Dict[str, AppState] = { + "InProgress": AppState.RUNNING, + "Completed": AppState.SUCCEEDED, + "Failed": AppState.FAILED, + "Stopping": AppState.CANCELLED, + "Stopped": AppState.CANCELLED, +} + + +class AWSSageMakerOpts(TypedDict, total=False): + """ + Opts where we can get from .torchxconfig or user command args + """ + + role: str + instance_count: int + instance_type: str + keep_alive_period_in_seconds: Optional[int] + volume_size: Optional[int] + volume_kms_key: Optional[str] + max_run: Optional[int] + input_mode: Optional[str] + output_path: Optional[str] + output_kms_key: Optional[str] + base_job_name: Optional[str] + tags: Optional[Dict[str, str]] + subnets: Optional[List[str]] + security_group_ids: Optional[List[str]] + model_uri: Optional[str] + model_channel_name: Optional[str] + metric_definitions: Optional[Dict[str, str]] + encrypt_inter_container_traffic: Optional[bool] + use_spot_instances: Optional[bool] + max_wait: Optional[int] + checkpoint_s3_uri: Optional[str] + checkpoint_local_path: Optional[str] + debugger_hook_config: Optional[bool] + enable_sagemaker_metrics: Optional[bool] + enable_network_isolation: Optional[bool] + disable_profiler: Optional[bool] + environment: Optional[Dict[str, str]] + max_retry_attempts: Optional[int] + source_dir: Optional[str] + git_config: Optional[Dict[str, str]] + hyperparameters: Optional[Dict[str, str]] + container_log_level: Optional[int] + code_location: Optional[str] + dependencies: Optional[List[str]] + training_repository_access_mode: Optional[str] + training_repository_credentials_provider_arn: Optional[str] + disable_output_compression: Optional[bool] + enable_infra_check: Optional[bool] + + +@dataclass +class AWSSageMakerJob: + """ + Jobs defined the key values that is requried to schedule a job. This will be the value + of `request` in the AppDryRunInfo object. + + - job_name: defines the job name shown in SageMaker + - job_def: defines the job description that will be used to schedule the job on SageMaker + - images_to_push: used by torchx to push to image_repo + """ + + job_name: str + job_def: Dict[str, Any] + images_to_push: Dict[str, Tuple[str, str]] + + def __str__(self) -> str: + return yaml.dump(asdict(self)) + + def __repr__(self) -> str: + return str(self) + + +T = TypeVar("T") + + +def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]: + # decorator function for keeping object in cache + local: threading.local = threading.local() + key: str = "value" + + def wrapper() -> T: + if key in local.__dict__: + return local.__dict__[key] + v = f() + local.__dict__[key] = v + return v + + return wrapper + + +@_thread_local_cache +def _local_session() -> boto3.session.Session: + return boto3.session.Session() + + +def _merge_ordered( + src: Optional[Dict[str, str]], extra: Dict[str, str] +) -> OrderedDict[str, str]: + merged = OrdDict(src or {}) + merged.update(extra) + return merged + + +class SageMakerScheduler(DockerWorkspaceMixin, Scheduler[AWSSageMakerOpts]): # type: ignore[misc] + """ + SageMakerScheduler is a TorchX scheduling interface to AWS SageMaker. + + .. code-block:: bash + + $ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello + aws_batch://torchx_user/1234 + $ torchx status aws_batch://torchx_user/1234 + ... + + Authentication is loaded from the environment using the ``boto3`` credential + handling. + + **Config Options** + + .. runopts:: + class: torchx.schedulers.sagemaker_scheduler.create_scheduler + + **Compatibility** + + .. compatibility:: + type: scheduler + features: + cancel: true + logs: false + distributed: true + describe: | + Partial support. SageMakerScheduler will return job and replica + status but does not provide the complete original AppSpec. + workspaces: true + mounts: false + elasticity: false + """ + + def __init__( + self, + session_name: str, + client: Optional[Any] = None, # pyre-ignore[2] + docker_client: Optional["DockerClient"] = None, + ) -> None: + super().__init__("aws_sagemaker", session_name, docker_client=docker_client) + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + self.__client = client + + @property + # pyre-fixme[3]: Return annotation cannot be `Any`. + def _client(self) -> Any: + if self.__client: + return self.__client + return _local_session().client("sagemaker") + + def schedule(self, dryrun_info: AppDryRunInfo[AWSSageMakerJob]) -> str: + cfg = dryrun_info._cfg + assert cfg is not None, f"{dryrun_info} missing cfg" + + images_to_push = dryrun_info.request.images_to_push + self.push_images(images_to_push) + + req = dryrun_info.request + pt_estimator = PyTorch(**req.job_def) + pt_estimator.fit(wait=False, job_name=req.job_name) + + return req.job_name + + def _submit_dryrun( + self, app: AppDef, cfg: AWSSageMakerOpts + ) -> AppDryRunInfo[AWSSageMakerJob]: + role = app.roles[0] + entrypoint, hyperparameters = self._parse_args(role.args) + + # map any local images to the remote image + images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) + structured_name_kwargs = {} + if entrypoint.startswith("-m"): + structured_name_kwargs["m"] = entrypoint.replace("-m", "").strip() + else: + structured_name_kwargs["script"] = entrypoint + structured_name = StructuredNameArgument.parse_from( + app.name, **structured_name_kwargs + ) + job_name = make_unique(structured_name.run_name) + + role.env["TORCHX_JOB_ID"] = job_name + + # see https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase + job_def = { + "entry_point": entrypoint, + "image_uri": role.image, + "distribution": {"torch_distributed": {"enabled": True}}, + } + + cfg["environment"] = _merge_ordered(cfg.get("environment"), role.env) + # hyperparameters are used for both script/module entrypoint args and the values from .torchxconfig + # order matters, adding script args last to handle wildcard parameters + cfg["hyperparameters"] = _merge_ordered( + cfg.get("hyperparameters"), hyperparameters + ) + # tags are used for AppDef metadata and the values from .torchxconfig + cfg["tags"] = [ # pyre-ignore[54] + *(cfg.get("tags") or []), + *({"Key": k, "Value": v} for k, v in app.metadata.items()), + ] + # following the principle of least astonishment defaulting source_dir to current working directory + cfg["source_dir"] = cfg.get("source_dir") or os.getcwd() + + for key in cfg: + if key in job_def: + raise ValueError( + f"{key} is controlled by aws_sagemaker_scheduler and is set to {job_def[key]}" + ) + value = cfg.get(key) # pyre-ignore[26] + if value is not None: + job_def[key] = value + + req = AWSSageMakerJob( + job_name=job_name, + job_def=job_def, + images_to_push=images_to_push, + ) + return AppDryRunInfo(req, repr) + + def _parse_args(self, args: List[str]) -> Tuple[str, Dict[str, str]]: + if len(args) < 1: + raise ValueError("Not enough args to resolve entrypoint") + offset = 1 + if args[0] == "-m": + if len(args) < 2: + raise ValueError("Missing module name") + offset += 1 + entrypoint = " ".join(args[:offset]) + hyperparameters = OrdDict() # the order matters, e.g. for wildcard params + while offset < len(args): + arg = args[offset] + sp_pos = arg.find("=") + if sp_pos < 0: + if offset + 1 >= len(args): + raise ValueError( + "SageMaker currently only supports named arguments" + ) + key = arg + offset += 1 + value = args[offset] + else: + key = arg[:sp_pos] + value = arg[sp_pos + 1 :] + if not key.startswith("--"): + raise ValueError("SageMaker only supports arguments that start with --") + offset += 1 + hyperparameters[key[2:]] = value + return entrypoint, hyperparameters + + def _run_opts(self) -> runopts: + opts = runopts() + opts.add( + "role", + type_=str, + help="an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.", + required=True, + ) + opts.add( + "instance_count", + type_=int, + default=1, + help="number of Amazon EC2 instances to use for training. Required if instance_groups is not set.", + ) + opts.add( + "instance_type", + type_=str, + help="type of EC2 instance to use for training, for example, 'ml.c4.xlarge'", + required=True, + ) + opts.add( + "user", + type_=str, + default=getpass.getuser(), + help="the username to tag the job with. `getpass.getuser()` if not specified.", + ) + opts.add( + "keep_alive_period_in_seconds", + type_=int, + default=None, + help="the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.", + ) + opts.add( + "volume_size", + type_=int, + default=None, + help="size in GB of the storage volume to use for storing input and output data during training (default: 30).", + ) + opts.add( + "volume_kms_key", + type_=str, + default=None, + help="KMS key ID for encrypting EBS volume attached to the training instance.", + ) + opts.add( + "max_run", + type_=int, + default=None, + help="timeout in seconds for training (default: 24 * 60 * 60).", + ) + opts.add( + "input_mode", + type_=str, + default=None, + help="the input mode that the algorithm supports (default: ‘File’).", + ) + opts.add( + "output_path", + type_=str, + default=None, + help="S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.", + ) + opts.add( + "output_kms_key", + type_=str, + default=None, + help="KMS key ID for encrypting the training output (default: Your IAM role’s KMS key for Amazon S3).", + ) + opts.add( + "base_job_name", + type_=str, + default=None, + help="prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp.", + ) + opts.add( + "tags", + type_=List[Dict[str, str]], + default=None, + help="list of tags for labeling a training job.", + ) + opts.add( + "subnets", + type_=List[str], + default=None, + help="list of subnet ids. If not specified training job will be created without VPC config.", + ) + opts.add( + "security_group_ids", + type_=List[str], + default=None, + help="list of security group ids. If not specified training job will be created without VPC config.", + ) + opts.add( + "model_uri", + type_=str, + default=None, + help="URI where a pre-trained model is stored, either locally or in S3.", + ) + opts.add( + "model_channel_name", + type_=str, + default=None, + help="name of the channel where ‘model_uri’ will be downloaded (default: ‘model’).", + ) + opts.add( + "metric_definitions", + type_=List[Dict[str, str]], + default=None, + help="list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs.", + ) + opts.add( + "encrypt_inter_container_traffic", + type_=bool, + default=None, + help="specifies whether traffic between training containers is encrypted for the training job (default: False).", + ) + opts.add( + "use_spot_instances", + type_=bool, + default=None, + help="specifies whether to use SageMaker Managed Spot instances for training. If enabled then the max_wait arg should also be set.", + ) + opts.add( + "max_wait", + type_=int, + default=None, + help="timeout in seconds waiting for spot training job.", + ) + opts.add( + "checkpoint_s3_uri", + type_=str, + default=None, + help="S3 URI in which to persist checkpoints that the algorithm persists (if any) during training.", + ) + opts.add( + "checkpoint_local_path", + type_=str, + default=None, + help="local path that the algorithm writes its checkpoints to.", + ) + opts.add( + "debugger_hook_config", + type_=bool, + default=None, + help="configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator’s output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False.", + ) + opts.add( + "enable_sagemaker_metrics", + type_=bool, + default=None, + help="enable SageMaker Metrics Time Series.", + ) + opts.add( + "enable_network_isolation", + type_=bool, + default=None, + help="specifies whether container will run in network isolation mode (default: False).", + ) + opts.add( + "disable_profiler", + type_=bool, + default=None, + help="specifies whether Debugger monitoring and profiling will be disabled (default: False).", + ) + opts.add( + "environment", + type_=Dict[str, str], + default=None, + help="environment variables to be set for use during training job", + ) + opts.add( + "max_retry_attempts", + type_=int, + default=None, + help="number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.", + ) + opts.add( + "source_dir", + type_=str, + default=None, + help="absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory)", + ) + opts.add( + "git_config", + type_=Dict[str, str], + default=None, + help="git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token.", + ) + opts.add( + "hyperparameters", + type_=Dict[str, str], + default=None, + help="dictionary containing the hyperparameters to initialize this estimator with.", + ) + opts.add( + "container_log_level", + type_=int, + default=None, + help="log level to use within the container (default: logging.INFO).", + ) + opts.add( + "code_location", + type_=str, + default=None, + help="S3 prefix URI where custom code is uploaded.", + ) + opts.add( + "dependencies", + type_=List[str], + default=None, + help="list of absolute or relative paths to directories with any additional libraries that should be exported to the container.", + ) + opts.add( + "training_repository_access_mode", + type_=str, + default=None, + help="specifies how SageMaker accesses the Docker image that contains the training algorithm.", + ) + opts.add( + "training_repository_credentials_provider_arn", + type_=str, + default=None, + help="Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.", + ) + opts.add( + "disable_output_compression", + type_=bool, + default=None, + help="when set to true, Model is uploaded to Amazon S3 without compression after training finishes.", + ) + opts.add( + "enable_infra_check", + type_=bool, + default=None, + help="specifies whether it is running Sagemaker built-in infra check jobs.", + ) + return opts + + def describe(self, app_id: str) -> Optional[DescribeAppResponse]: + job = self._get_job(app_id) + if job is None: + return None + + return DescribeAppResponse( + app_id=app_id, + state=JOB_STATE[job["TrainingJobStatus"]], + ui_url=self._job_ui_url(job["TrainingJobArn"]), + ) + + def list(self) -> List[ListAppResponse]: + raise NotImplementedError() + + def _cancel_existing(self, app_id: str) -> None: + self._client.stop_training_job(TrainingJobName=app_id) + + def log_iter( + self, + app_id: str, + role_name: str, + k: int = 0, + regex: Optional[str] = None, + since: Optional[datetime] = None, + until: Optional[datetime] = None, + should_tail: bool = False, + streams: Optional[Stream] = None, + ) -> Iterable[str]: + raise NotImplementedError() + + def _get_job(self, app_id: str) -> Optional[Dict[str, Any]]: + job = self._client.describe_training_job(TrainingJobName=app_id) + return job + + def _job_ui_url(self, job_arn: str) -> Optional[str]: + match = re.match( + "arn:aws:sagemaker:(?P[a-z-0-9]+):[0-9]+:training-job/(?P[a-z-0-9]+)", + job_arn, + ) + if match is None: + return None + region = match.group("region") + job_id = match.group("job_id") + return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#jobs/{job_id}" + + +def create_scheduler(session_name: str, **kwargs: object) -> SageMakerScheduler: + return SageMakerScheduler(session_name=session_name) diff --git a/torchx/schedulers/test/aws_sagemaker_scheduler_test.py b/torchx/schedulers/test/aws_sagemaker_scheduler_test.py new file mode 100644 index 000000000..d9ebd3d35 --- /dev/null +++ b/torchx/schedulers/test/aws_sagemaker_scheduler_test.py @@ -0,0 +1,386 @@ +import threading +import unittest +from collections import OrderedDict +from contextlib import contextmanager +from datetime import datetime +from typing import Any, Dict, Generator, Iterable, Optional +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from torchx.schedulers.api import AppDryRunInfo + +from torchx.schedulers.aws_sagemaker_scheduler import ( + _local_session, + AWSSageMakerJob, + AWSSageMakerOpts, + create_scheduler, + SageMakerScheduler, + JOB_STATE, +) +from torchx.specs.api import runopts + +ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME" +MODULE = "torchx.schedulers.aws_sagemaker_scheduler" + + +def to_millis_since_epoch(ts: datetime) -> int: + # datetime's timestamp returns seconds since epoch + return int(round(ts.timestamp() * 1000)) + + +class TestAWSSageMakerOpts(TestCase): + def setUp(self) -> None: + self.test_dict: AWSSageMakerOpts = { + "role": "test-arn", + "subnets": ["subnet-1", "subnet-2"], + "security_group_ids": ["sg-1", "sg-2"], + } + + def test_role(self) -> None: + self.assertEqual(self.test_dict["role"], "test-arn") + self.assertIsInstance(self.test_dict["role"], str) + + def test_subnets(self) -> None: + self.assertEqual(self.test_dict["subnets"], ["subnet-1", "subnet-2"]) + self.assertIsInstance(self.test_dict["subnets"], list) + + def test_security_group_ids(self) -> None: + self.assertEqual(self.test_dict["security_group_ids"], ["sg-1", "sg-2"]) + self.assertIsInstance(self.test_dict["security_group_ids"], list) + + +@contextmanager +def mock_rand() -> Generator[None, None, None]: + with patch(f"{MODULE}.make_unique") as make_unique_ctx: + make_unique_ctx.return_value = "app-name-42" + yield + + +boto3Response = Dict[str, Any] # boto3 responses are JSON + + +class MockPaginator: + """ + Used for mocking ``boto3.client("").get_paginator("")`` calls. + """ + + def __init__(self, **op_to_pages: Iterable[boto3Response]) -> None: + # boto3 paginators return an iterable of API responses + self.op_to_pages: Dict[str, Iterable[boto3Response]] = op_to_pages + self.op_name: Optional[str] = None + + def __call__(self, op_name: str) -> "MockPaginator": + self.op_name = op_name + return self + + def paginate(self, *_1: Any, **_2: Any) -> Iterable[Dict[str, Any]]: + if self.op_name: + return self.op_to_pages[self.op_name] + raise RuntimeError( + "`op_name` not set. Did you forget to call `__call__(op_name)`?" + ) + + +class TestSageMakerScheduler(TestCase): + def setUp(self) -> None: + self.sagemaker_client = MagicMock() + self.scheduler = SageMakerScheduler( + session_name="test-session", client=self.sagemaker_client, docker_client=MagicMock() + ) + self.job = AWSSageMakerJob( + job_name="test-name", + job_def={ + "entry_point": "some_entry_point", + "image_uri": "some_image_uri", + "role_arn": "some_role_arn", + }, + images_to_push={"image1": ("tag1", "repo1")}, + ) + self.dryrun_info = AppDryRunInfo(self.job, repr) + + def _mock_scheduler(self) -> SageMakerScheduler: + scheduler = SageMakerScheduler( + "test", + client=MagicMock(), + docker_client=MagicMock(), + ) + + scheduler._client.get_paginator.side_effect = MockPaginator( + describe_job_queues=[ + { + "ResponseMetadata": {}, + "jobQueues": [ + { + "jobQueueName": "torchx", + "jobQueueArn": "arn:aws:sagemaker:test-region:4000005:job-queue/torchx", + "state": "ENABLED", + }, + ], + } + ], + list_jobs=[ + { + "jobSummaryList": [ + { + "jobArn": "arn:aws:sagemaker:us-west-2:1234567890:job/6afc27d7-3559-43ca-89fd-1007b6bf2546", + "jobId": "6afc27d7-3559-43ca-89fd-1007b6bf2546", + "jobName": "app-name-42", + "createdAt": 1643949940162, + "status": "SUCCEEDED", + "stoppedAt": 1643950324125, + "container": {"exitCode": 0}, + "nodeProperties": {"numNodes": 2}, + "jobDefinition": "arn:aws:sagemaker:us-west-2:1234567890:job-definition/app-name-42:1", + } + ] + } + ], + ) + + scheduler._client.describe_jobs.return_value = { + "jobs": [ + { + "jobArn": "arn:aws:sagemaker:us-west-2:1234567890:job/6afc27d7-3559-43ca-89fd-1007b6bf2546", + "jobName": "app-name-42", + "jobId": "6afc27d7-3559-43ca-89fd-1007b6bf2546", + "jobQueue": "testqueue", + "status": "SUCCEEDED", + "attempts": [ + { + "container": { + "exitCode": 0, + "logStreamName": "log_stream", + "networkInterfaces": [], + }, + "startedAt": 1643950310819, + "stoppedAt": 1643950324125, + "statusReason": "Essential container in task exited", + } + ], + "statusReason": "Essential container in task exited", + "createdAt": 1643949940162, + "retryStrategy": { + "attempts": 1, + "evaluateOnExit": [{"onExitCode": "0", "action": "exit"}], + }, + "startedAt": 1643950310819, + "stoppedAt": 1643950324125, + "dependsOn": [], + "jobDefinition": "job-def", + "parameters": {}, + "nodeProperties": { + "numNodes": 2, + "mainNode": 0, + "nodeRangeProperties": [ + { + "targetNodes": "0:1", + "container": { + "image": "ghcr.io/pytorch/torchx:0.1.2dev0", + "command": ["echo", "your name"], + "volumes": [], + "environment": [ + { + "name": "TORCHX_ROLE_IDX", + "value": "0", + }, + { + "name": "TORCHX_ROLE_NAME", + "value": "echo", + }, + { + "name": "TORCHX_RANK0_HOST", + "value": "localhost", + }, + ], + "mountPoints": [], + "ulimits": [], + "resourceRequirements": [ + {"value": "1", "type": "VCPU"}, + {"value": "1000", "type": "MEMORY"}, + ], + "logConfiguration": { + "logDriver": "awslogs", + "options": {}, + "secretOptions": [], + }, + "secrets": [], + }, + }, + ], + }, + "tags": { + "torchx.pytorch.org/version": "0.1.2dev0", + "torchx.pytorch.org/app-name": "echo", + }, + "platformCapabilities": [], + } + ] + } + + return scheduler + + @patch(f"{MODULE}.PyTorch") + def test_schedule(self, mock_pytorch_estimator) -> None: + expected_name = "test-name" + returned_name = self.scheduler.schedule(self.dryrun_info) + self.assertEqual(returned_name, expected_name) + + def test_run_opts(self) -> None: + scheduler = self._mock_scheduler() + # Call the _run_opts method + result = scheduler._run_opts() + # Assert that the returned value is an instance of runopts + self.assertIsInstance(result, runopts) + + def test_cancel_existing(self) -> None: + scheduler = self._mock_scheduler() + # Call the _cancel_existing method + scheduler._cancel_existing(app_id="testqueue:app-name-42") + # Assert that it's called once + scheduler._client.stop_training_job.assert_called_once() + + def test_list(self) -> None: + with self.assertRaises(NotImplementedError): + scheduler = self._mock_scheduler() + scheduler.list() + + def test_describe_job(self) -> None: + region = "us-east-1" + job_id = "42" + state = "InProgress" + training_job = { + "TrainingJobStatus": state, + "TrainingJobArn": f"arn:aws:sagemaker:{region}:1234567890:training-job/{job_id})", + } + self.sagemaker_client.describe_training_job.return_value = training_job + job = self.scheduler.describe(app_id=(app_id := "testqueue:app-name-42")) + self.assertEqual(job.app_id, app_id) + self.assertEqual(job.state, JOB_STATE[state]) + self.assertEqual(job.ui_url, f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#jobs/{job_id}") + + def test_log_iter(self) -> None: + with self.assertRaises(NotImplementedError): + scheduler = self._mock_scheduler() + scheduler.log_iter( + app_id="testqueue:app-name-42", + role_name="echo", + k=1, + regex="foo.*", + ) + + def test_get_job(self) -> None: + # Arrange + scheduler = self._mock_scheduler() + + # Act + test_job = scheduler._get_job(app_id="testqueue:app-name-42") + + # Assert + self.assertEqual(test_job, scheduler._client.describe_training_job.return_value) + + def test_job_ui_url(self) -> None: + # Set up the input job ARN and expected URL + job_arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/job-id" + expected_url = "https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#jobs/job-id" + + # Call the _job_ui_url method + result = self.scheduler._job_ui_url(job_arn) + + # Assert that the returned URL matches the expected URL + self.assertEqual(result, expected_url) + + def test_job_ui_url_with_invalid_arn(self) -> None: + # Set up an invalid job ARN + job_arn = "invalid-arn" + + # Call the _job_ui_url method + result = self.scheduler._job_ui_url(job_arn) + + # Assert that the returned value is None + self.assertIsNone(result) + + def test_job_ui_url_with_no_match(self) -> None: + # Set up a job ARN that does not match the regex pattern + job_arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job" + + # Call the _job_ui_url method + result = self.scheduler._job_ui_url(job_arn) + + # Assert that the returned value is None + self.assertIsNone(result) + + def test_parse_args(self) -> None: + # Set up the role_args with no match + role_args = ["arg1", "arg2", "arg3"] + + # Call the _parse_entrypoint_and_source_dir method + with self.assertRaises(ValueError): + self.scheduler._parse_args(role_args) + + def test_parse_args_with_overrides(self) -> None: + # Set up the args + test_args = [ + "--", + "--config-path", + "test-path/test-config", + "--config-name", + "config.yaml", + "--overrides", + "key1=value1", + ] + + # Call the _parse_arguments method + result = self.scheduler._parse_args(test_args) + + # Assert the returned values + expected_args = OrderedDict( + [ + ("config-path", "test-path/test-config"), + ("config-name", "config.yaml"), + ("overrides", "key1=value1"), + ] + ) + self.assertEqual(result, ("--", expected_args)) + + def test_parse_args_without_overrides(self) -> None: + # Set up the args + test_args = [ + "--", + "--config-path", + "test-path/test-config", + "--config-name", + "config.yaml", + ] + + # Call the _parse_arguments method + result = self.scheduler._parse_args(test_args) + + # Assert the returned values + expected_args = OrderedDict( + [ + ("config-path", "test-path/test-config"), + ("config-name", "config.yaml"), + ] + ) + self.assertEqual(result, ("--", expected_args)) + + def test_local_session(self) -> None: + a: object = _local_session() + self.assertIs(a, _local_session()) + + def worker() -> None: + b = _local_session() + self.assertIs(b, _local_session()) + self.assertIsNot(a, b) + + t = threading.Thread(target=worker) + t.start() + t.join() + + def test_create_scheduler(self) -> None: + scheduler = create_scheduler(session_name="test-sm") + self.assertIsInstance(scheduler, SageMakerScheduler) + + +if __name__ == "__main__": + unittest.main()