From a33d1b3bb683450e7853b25946775aa4b0263675 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 10:43:37 -0800 Subject: [PATCH 01/38] refactor --- .../__init__.py | 242 +----------------- .../batch_client.py | 34 +++ .../batch_descriptor.py | 71 +++++ .../batch_job_builder.py | 103 ++++++++ 4 files changed, 215 insertions(+), 235 deletions(-) create mode 100644 snakemake_executor_plugin_aws_batch/batch_client.py create mode 100644 snakemake_executor_plugin_aws_batch/batch_descriptor.py create mode 100644 snakemake_executor_plugin_aws_batch/batch_job_builder.py diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 1fb466c..1f245da 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -1,13 +1,8 @@ from dataclasses import dataclass, field from pprint import pformat -import boto3 -import uuid -import heapq -import botocore import shlex -import time -import threading from typing import List, Generator, Optional +from snakemake_executor_plugin_aws_batch import BatchJobDescriber, BatchClient, BatchJobBuilder from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor from snakemake_interface_executor_plugins.settings import ( @@ -28,16 +23,6 @@ # of None or anything else that makes sense in your case. @dataclass class ExecutorSettings(ExecutorSettingsBase): - access_key_id: Optional[int] = field( - default=None, - metadata={"help": "AWS access key id", "env_var": True, "required": False}, - repr=False, - ) - access_key: Optional[int] = field( - default=None, - metadata={"help": "AWS access key", "env_var": True, "required": False}, - repr=False, - ) region: Optional[int] = field( default=None, metadata={ @@ -46,29 +31,6 @@ class ExecutorSettings(ExecutorSettingsBase): "required": True, }, ) - fsap_id: Optional[str] = ( - field( - default=None, - metadata={ - "help": ( - "The fsap id of the EFS instance you want to use that " - "is shared with your local environment" - ), - "env_var": False, - "required": False, - }, - ), - ) - efs_project_path: Optional[str] = ( - field( - default=None, - metadata={ - "help": "The EFS path that contains the project Snakemake is running", - "env_var": False, - "required": False, - }, - ), - ) job_queue: Optional[str] = field( default=None, metadata={ @@ -140,6 +102,7 @@ class ExecutorSettings(ExecutorSettingsBase): # Implementation of your executor class Executor(RemoteExecutor): def __post_init__(self): + # set snakemake/snakemake container image self.container_image = self.workflow.remote_execution_settings.container_image @@ -149,49 +112,8 @@ def __post_init__(self): # keep track of job definitions self.created_job_defs = list() - self.mount_path = None self._describer = BatchJobDescriber() - - # init batch client - try: - self.batch_client = ( - boto3.Session().client( # Session() needed for thread safety - "batch", - aws_access_key_id=self.settings.access_key_id, - aws_secret_access_key=self.settings.access_key, - region_name=self.settings.region, - config=botocore.config.Config( - retries={"max_attempts": 5, "mode": "standard"} - ), - ) - ) - except Exception as e: - raise WorkflowError(e) - - # TODO: - # def _prepare_mounts(self): - # """ - # Prepare the "volumes" and "mountPoints" for the Batch job definition, - # assembling the in-container filesystem with the shared working directory, - # read-only input files, and 
command/stdout/stderr files. - # """ - - # # EFS mount point - # volumes = [ - # { - # "name": "efs", - # "efsVolumeConfiguration": { - # "fileSystemId": self.fs_id, - # "transitEncryption": "ENABLED", - # }, - # } - # ] - # volumes[0]["efsVolumeConfiguration"]["authorizationConfig"] = { - # "accessPointId": self.fsap_id - # } - # mount_points = [{"containerPath": self.mount_path, "sourceVolume": "efs"}] - - # return volumes, mount_points + self.batch_clint = BatchClient(region_name=self.settings.region) def run_job(self, job: JobExecutorInterface): # Implement here how to run a job. @@ -204,88 +126,11 @@ def run_job(self, job: JobExecutorInterface): # If required, make sure to pass the job's id to the job_info object, as keyword # argument 'external_job_id'. - # set job name - job_uuid = str(uuid.uuid4()) - job_name = f"snakejob-{job.name}-{job_uuid}" - - # set job definition name - job_definition_name = f"snakejob-def-{job.name}-{job_uuid}" - job_definition_type = "container" - - # get job resources or default - vcpu = str(job.resources.get("_cores", str(1))) - mem = str(job.resources.get("mem_mb", str(2048))) - - # job definition container properties - container_properties = { - "command": ["snakemake"], - "image": self.container_image, - # fargate required privileged False - "privileged": False, - "resourceRequirements": [ - # resource requirements have to be compatible - # see: https://docs.aws.amazon.com/batch/latest/APIReference/API_ResourceRequirement.html # noqa - {"type": "VCPU", "value": vcpu}, - {"type": "MEMORY", "value": mem}, - ], - "networkConfiguration": { - "assignPublicIp": "ENABLED", - }, - "executionRoleArn": self.settings.execution_role, - } - - # TODO: or not todo ? - # ( - # container_properties["volumes"], - # container_properties["mountPoints"], - # ) = self._prepare_mounts() - - # register the job definition - tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict() try: - job_def = self.batch_client.register_job_definition( - jobDefinitionName=job_definition_name, - type=job_definition_type, - containerProperties=container_properties, - platformCapabilities=["FARGATE"], - tags=tags, - ) - self.created_job_defs.append(job_def) - except Exception as e: - raise WorkflowError(e) - - job_command = self._generate_snakemake_command(job) - - # configure job parameters - job_params = { - "jobName": job_name, - "jobQueue": self.settings.job_queue, - "jobDefinition": "{}:{}".format( - job_def["jobDefinitionName"], job_def["revision"] - ), - "containerOverrides": { - "command": job_command, - "resourceRequirements": [ - {"type": "VCPU", "value": vcpu}, - {"type": "MEMORY", "value": mem}, - ], - }, - } - - if self.settings.tags: - job_params["tags"] = self.settings.tags - - if self.settings.task_timeout is not None: - job_params["timeout"] = { - "attemptDurationSeconds": self.settings.task_timeout - } - - # submit the job - try: - submitted = self.batch_client.submit_job(**job_params) + job_info = BatchJobBuilder().submit() self.logger.debug( "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( - self.settings.job_queue, submitted["jobId"], self.settings.tags + self.settings.job_queue, job_info["jobId"], self.settings.tags ) ) except Exception as e: @@ -294,7 +139,7 @@ def run_job(self, job: JobExecutorInterface): self.report_job_submission( SubmittedJobInfo( job=job, - external_jobid=submitted["jobId"], + external_jobid=job_info["jobId"], aux={ "jobs_params": job_params, "job_def_arn": job_def["jobDefinitionArn"], @@ -302,11 +147,6 @@ 
def run_job(self, job: JobExecutorInterface): ) ) - def _generate_snakemake_command(self, job: JobExecutorInterface) -> str: - """generates the snakemake command for the job""" - exec_job = self.format_job_exec(job) - return ["sh", "-c", shlex.quote(exec_job)] - async def check_active_jobs( self, active_jobs: List[SubmittedJobInfo] ) -> Generator[SubmittedJobInfo, None, None]: @@ -450,72 +290,4 @@ def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]): # cleanup jobs for j in active_jobs: self._terminate_job(j) - self._deregister_job(j) - - -class BatchJobDescriber: - """ - This singleton class handles calling the AWS Batch DescribeJobs API with up to 100 - job IDs per request, then dispensing each job description to the thread interested - in it. This helps avoid AWS API request rate limits when tracking concurrent jobs. - """ - - JOBS_PER_REQUEST = 100 # maximum jobs per DescribeJob request - - def __init__(self): - self.lock = threading.Lock() - self.last_request_time = 0 - self.job_queue = [] - self.jobs = {} - - def describe(self, aws, job_id, period): - """get the latest Batch job description""" - while True: - with self.lock: - if job_id not in self.jobs: - # register new job to be described ASAP - heapq.heappush(self.job_queue, (0.0, job_id)) - self.jobs[job_id] = None - # update as many job descriptions as possible - self._update(aws, period) - # return the desired job description if we have it - desc = self.jobs[job_id] - if desc: - return desc - # otherwise wait (outside the lock) and try again - time.sleep(period / 4) - - def unsubscribe(self, job_id): - """unsubscribe from job_id when no longer interested""" - with self.lock: - if job_id in self.jobs: - del self.jobs[job_id] - - def _update(self, aws, period): - # if enough time has passed since our last DescribeJobs request - if time.time() - self.last_request_time >= period: - # take the N least-recently described jobs - job_ids = set() - assert self.job_queue - while self.job_queue and len(job_ids) < self.JOBS_PER_REQUEST: - job_id = heapq.heappop(self.job_queue)[1] - assert job_id not in job_ids - if job_id in self.jobs: - job_ids.add(job_id) - if not job_ids: - return - # describe them - try: - job_descs = aws.describe_jobs(jobs=list(job_ids)) - finally: - # always: bump last_request_time and re-enqueue these jobs - self.last_request_time = time.time() - for job_id in job_ids: - heapq.heappush(self.job_queue, (self.last_request_time, job_id)) - # update self.jobs with the new descriptions - for job_desc in job_descs["jobs"]: - job_ids.remove(job_desc["jobId"]) - self.jobs[job_desc["jobId"]] = job_desc - assert ( - not job_ids - ), "AWS Batch DescribeJobs didn't return all expected results" + self._deregister_job(j) \ No newline at end of file diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py new file mode 100644 index 0000000..caf33b2 --- /dev/null +++ b/snakemake_executor_plugin_aws_batch/batch_client.py @@ -0,0 +1,34 @@ +import boto3 + +class AWSClient: + def __init__(self, service_name, region_name=None): + """ + Initialize an AWS client for a specific service using default credentials. + + :param service_name: The name of the AWS service (e.g., 's3', 'ec2', 'dynamodb'). + :param region_name: The region name to use for the client (optional). 
+ """ + self.service_name = service_name + self.region_name = region_name + self.client = self.initialize_client() + + def initialize_client(self): + """ + Create an AWS client using boto3 with the default credentials. + + :return: The boto3 client for the specified service. + """ + if self.region_name: + client = boto3.client(self.service_name, region_name=self.region_name) + else: + client = boto3.client(self.service_name) + return client + +class BatchClient(AWSClient): + def __init__(self, region_name=None): + """ + Initialize an AWS Batch client using default credentials. + + :param region_name: The region name to use for the client (optional). + """ + super().__init__('batch', region_name) \ No newline at end of file diff --git a/snakemake_executor_plugin_aws_batch/batch_descriptor.py b/snakemake_executor_plugin_aws_batch/batch_descriptor.py new file mode 100644 index 0000000..01c8980 --- /dev/null +++ b/snakemake_executor_plugin_aws_batch/batch_descriptor.py @@ -0,0 +1,71 @@ +import threading +import heapq +import time + + +class BatchJobDescriber: + """ + This singleton class handles calling the AWS Batch DescribeJobs API with up to 100 + job IDs per request, then dispensing each job description to the thread interested + in it. This helps avoid AWS API request rate limits when tracking concurrent jobs. + """ + + JOBS_PER_REQUEST = 100 # maximum jobs per DescribeJob request + + def __init__(self): + self.lock = threading.Lock() + self.last_request_time = 0 + self.job_queue = [] + self.jobs = {} + + def describe(self, aws, job_id, period): + """get the latest Batch job description""" + while True: + with self.lock: + if job_id not in self.jobs: + # register new job to be described ASAP + heapq.heappush(self.job_queue, (0.0, job_id)) + self.jobs[job_id] = None + # update as many job descriptions as possible + self._update(aws, period) + # return the desired job description if we have it + desc = self.jobs[job_id] + if desc: + return desc + # otherwise wait (outside the lock) and try again + time.sleep(period / 4) + + def unsubscribe(self, job_id): + """unsubscribe from job_id when no longer interested""" + with self.lock: + if job_id in self.jobs: + del self.jobs[job_id] + + def _update(self, aws, period): + # if enough time has passed since our last DescribeJobs request + if time.time() - self.last_request_time >= period: + # take the N least-recently described jobs + job_ids = set() + assert self.job_queue + while self.job_queue and len(job_ids) < self.JOBS_PER_REQUEST: + job_id = heapq.heappop(self.job_queue)[1] + assert job_id not in job_ids + if job_id in self.jobs: + job_ids.add(job_id) + if not job_ids: + return + # describe them + try: + job_descs = aws.describe_jobs(jobs=list(job_ids)) + finally: + # always: bump last_request_time and re-enqueue these jobs + self.last_request_time = time.time() + for job_id in job_ids: + heapq.heappush(self.job_queue, (self.last_request_time, job_id)) + # update self.jobs with the new descriptions + for job_desc in job_descs["jobs"]: + job_ids.remove(job_desc["jobId"]) + self.jobs[job_desc["jobId"]] = job_desc + assert ( + not job_ids + ), "AWS Batch DescribeJobs didn't return all expected results" \ No newline at end of file diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py new file mode 100644 index 0000000..53319c7 --- /dev/null +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -0,0 +1,103 @@ +import uuid +from batch_client import BatchClient 
+from snakemake.exceptions import WorkflowError +import shlex +from snakemake_interface_executor_plugins.jobs import ( + JobExecutorInterface, +) +from enum import Enum + +class BATCH_JOB_DEFINITION_TYPE(Enum): + CONTAINER = "container" + MULTINODE = "multinode" + +class BATCH_JOB_PLATFORM_CAPABILITIES(Enum): + FARGATE = "FARGATE" + EC2 = "EC2" + +class BatchJobBuilder: + def __init__(self, job, container_image, settings, batch_client=None): + self.job = job + self.container_image = container_image + self.settings = settings + self.batch_client = batch_client or BatchClient() + self.created_job_defs = [] + + def build_job_definition(self): + job_uuid = str(uuid.uuid4()) + job_name = f"snakejob-{self.job.name}-{job_uuid}" + job_definition_name = f"snakejob-def-{self.job.name}-{job_uuid}" + + vcpu = str(self.job.resources.get("_cores", str(1))) + mem = str(self.job.resources.get("mem_mb", str(2048))) + + container_properties = { + "command": ["snakemake"], + "image": self.container_image, + "privileged": True, + "resourceRequirements": [ + {"type": "VCPU", "value": vcpu}, + {"type": "MEMORY", "value": mem}, + ], + "networkConfiguration": { + "assignPublicIp": "ENABLED", + }, + "executionRoleArn": self.settings.execution_role, + } + + tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict() + try: + job_def = self.batch_client.client.register_job_definition( + jobDefinitionName=job_definition_name, + type=BATCH_JOB_DEFINITION_TYPE.CONTAINER.value, + containerProperties=container_properties, + platformCapabilities=[BATCH_JOB_PLATFORM_CAPABILITIES.FARGATE.value], + tags=tags, + ) + self.created_job_defs.append(job_def) + return job_def, job_name + except Exception as e: + raise WorkflowError(e) + + def _generate_snakemake_command(self, job: JobExecutorInterface) -> str: + """generates the snakemake command for the job""" + exec_job = self.format_job_exec(job) + return ["sh", "-c", shlex.quote(exec_job)] + + def submit_job(self): + job_def, job_name = self.build_job_definition() + job_command = self._generate_snakemake_command(self.job) + + job_params = { + "jobName": job_name, + "jobQueue": self.settings.job_queue, + "jobDefinition": "{}:{}".format( + job_def["jobDefinitionName"], job_def["revision"] + ), + "containerOverrides": { + "command": job_command, + "resourceRequirements": [ + {"type": "VCPU", "value": str(self.job.resources.get("_cores", str(1)))}, + {"type": "MEMORY", "value": str(self.job.resources.get("mem_mb", str(2048)))}, + ], + }, + } + + if self.settings.tags: + job_params["tags"] = self.settings.tags + + if self.settings.task_timeout is not None: + job_params["timeout"] = { + "attemptDurationSeconds": self.settings.task_timeout + } + + try: + submitted = self.batch_client.client.submit_job(**job_params) + self.logger.debug( + "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( + self.settings.job_queue, submitted["jobId"], self.settings.tags + ) + ) + return submitted + except Exception as e: + raise WorkflowError(e) From c0e6018ae34e76235e6894da55d6481d7a549961 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 10:44:02 -0800 Subject: [PATCH 02/38] no keys --- tests/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 51883fb..2b9b8ca 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -14,8 +14,6 @@ def get_executor(self) -> str: def get_executor_settings(self) -> Optional[ExecutorSettingsBase]: # instantiate ExecutorSettings of this plugin as appropriate return 
ExecutorSettings( - access_key_id=os.getenv("SNAKEMAKE_AWS_BATCH_ACCESS_KEY_ID"), - access_key=os.getenv("SNAKEMAKE_AWS_BATCH_ACCESS_KEY"), region=os.environ.get("SNAKEMAKE_AWS_BATCH_REGION", "us-east-1"), job_queue=os.environ.get("SNAKEMAKE_AWS_BATCH_JOB_QUEUE"), execution_role=os.environ.get("SNAKEMAKE_AWS_BATCH_EXECUTION_ROLE"), From 6a8af109ace05760c102f0c9cef499255e34b67f Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 10:47:59 -0800 Subject: [PATCH 03/38] rm container image --- tests/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 2b9b8ca..e538bd6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,6 +28,4 @@ def get_remote_execution_settings( return snakemake.settings.RemoteExecutionSettings( seconds_between_status_checks=5, envvars=self.get_envvars(), - # TODO remove once we have switched to stable snakemake for dev - container_image="snakemake/snakemake:latest", ) From 56e65bae6990620daed8c4214c9c3ab223a81f9e Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 11:57:21 -0800 Subject: [PATCH 04/38] build job definition --- .../__init__.py | 7 ++--- .../batch_job_builder.py | 31 ++++++++++++------- tests/__init__.py | 1 - 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 1f245da..b83ff06 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from pprint import pformat -import shlex from typing import List, Generator, Optional from snakemake_executor_plugin_aws_batch import BatchJobDescriber, BatchClient, BatchJobBuilder from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo @@ -103,7 +102,7 @@ class ExecutorSettings(ExecutorSettingsBase): class Executor(RemoteExecutor): def __post_init__(self): - # set snakemake/snakemake container image + # snakemake/snakemake:latest container image self.container_image = self.workflow.remote_execution_settings.container_image # access executor specific settings @@ -125,9 +124,9 @@ def run_job(self, job: JobExecutorInterface): # snakemake_interface_executor_plugins.executors.base.SubmittedJobInfo. # If required, make sure to pass the job's id to the job_info object, as keyword # argument 'external_job_id'. 
- try: - job_info = BatchJobBuilder().submit() + job_definition = BatchJobBuilder() + job_info = job_definition.submit() self.logger.debug( "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( self.settings.job_queue, job_info["jobId"], self.settings.tags diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 53319c7..5c8bfd0 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -1,11 +1,13 @@ import uuid -from batch_client import BatchClient from snakemake.exceptions import WorkflowError import shlex from snakemake_interface_executor_plugins.jobs import ( JobExecutorInterface, ) from enum import Enum +from snakemake_executor_plugin_aws_batch import BatchClient, ExecutorSettings + +SNAKEMAKE_COMMAND = "snakemake" class BATCH_JOB_DEFINITION_TYPE(Enum): CONTAINER = "container" @@ -15,12 +17,20 @@ class BATCH_JOB_PLATFORM_CAPABILITIES(Enum): FARGATE = "FARGATE" EC2 = "EC2" +class BATCH_JOB_RESOURCE_REQUIREMENT_TYPE(Enum): + GPU = "GPU" + VCPU = "VCPU" + MEMORY = "MEMEORY" + class BatchJobBuilder: - def __init__(self, job, container_image, settings, batch_client=None): + def __init__(self, job: JobExecutorInterface, + container_image: str, + settings: ExecutorSettings, + batch_client: BatchClient=None): self.job = job self.container_image = container_image self.settings = settings - self.batch_client = batch_client or BatchClient() + self.batch_client = batch_client self.created_job_defs = [] def build_job_definition(self): @@ -28,21 +38,20 @@ def build_job_definition(self): job_name = f"snakejob-{self.job.name}-{job_uuid}" job_definition_name = f"snakejob-def-{self.job.name}-{job_uuid}" + gpu = str(self.job.resources.get("_gpus", str(0))) vcpu = str(self.job.resources.get("_cores", str(1))) mem = str(self.job.resources.get("mem_mb", str(2048))) container_properties = { - "command": ["snakemake"], "image": self.container_image, + "command": [SNAKEMAKE_COMMAND], + "jobRoleArn": self.settings.job_role, "privileged": True, "resourceRequirements": [ - {"type": "VCPU", "value": vcpu}, - {"type": "MEMORY", "value": mem}, + {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.GPU.value, "value": gpu}, + {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.VCPU.value, "value": vcpu}, + {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.MEMORY.value, "value": mem}, ], - "networkConfiguration": { - "assignPublicIp": "ENABLED", - }, - "executionRoleArn": self.settings.execution_role, } tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict() @@ -51,7 +60,7 @@ def build_job_definition(self): jobDefinitionName=job_definition_name, type=BATCH_JOB_DEFINITION_TYPE.CONTAINER.value, containerProperties=container_properties, - platformCapabilities=[BATCH_JOB_PLATFORM_CAPABILITIES.FARGATE.value], + platformCapabilities=[BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value], tags=tags, ) self.created_job_defs.append(job_def) diff --git a/tests/__init__.py b/tests/__init__.py index e538bd6..03401f3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -12,7 +12,6 @@ def get_executor(self) -> str: return "aws-batch" def get_executor_settings(self) -> Optional[ExecutorSettingsBase]: - # instantiate ExecutorSettings of this plugin as appropriate return ExecutorSettings( region=os.environ.get("SNAKEMAKE_AWS_BATCH_REGION", "us-east-1"), job_queue=os.environ.get("SNAKEMAKE_AWS_BATCH_JOB_QUEUE"), From 0c30eac51249d88eb1e6a331ae86940230b67736 Mon Sep 17 
00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 13:30:36 -0800 Subject: [PATCH 05/38] update submit job def --- .../batch_job_builder.py | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 5c8bfd0..8341292 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -23,10 +23,11 @@ class BATCH_JOB_RESOURCE_REQUIREMENT_TYPE(Enum): MEMORY = "MEMEORY" class BatchJobBuilder: - def __init__(self, job: JobExecutorInterface, + def __init__(self, logger, job: JobExecutorInterface, container_image: str, settings: ExecutorSettings, batch_client: BatchClient=None): + self.logger = logger self.job = job self.container_image = container_image self.settings = settings @@ -42,9 +43,11 @@ def build_job_definition(self): vcpu = str(self.job.resources.get("_cores", str(1))) mem = str(self.job.resources.get("mem_mb", str(2048))) + job_command = _generate_snakemake_command(self.job) + container_properties = { "image": self.container_image, - "command": [SNAKEMAKE_COMMAND], + "command": [job_command], "jobRoleArn": self.settings.job_role, "privileged": True, "resourceRequirements": [ @@ -54,14 +57,16 @@ def build_job_definition(self): ], } + timeout = dict() tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict() try: job_def = self.batch_client.client.register_job_definition( jobDefinitionName=job_definition_name, type=BATCH_JOB_DEFINITION_TYPE.CONTAINER.value, containerProperties=container_properties, - platformCapabilities=[BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value], + timeout=timeout, tags=tags, + platformCapabilities=[BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value], ) self.created_job_defs.append(job_def) return job_def, job_name @@ -75,7 +80,6 @@ def _generate_snakemake_command(self, job: JobExecutorInterface) -> str: def submit_job(self): job_def, job_name = self.build_job_definition() - job_command = self._generate_snakemake_command(self.job) job_params = { "jobName": job_name, @@ -83,23 +87,8 @@ def submit_job(self): "jobDefinition": "{}:{}".format( job_def["jobDefinitionName"], job_def["revision"] ), - "containerOverrides": { - "command": job_command, - "resourceRequirements": [ - {"type": "VCPU", "value": str(self.job.resources.get("_cores", str(1)))}, - {"type": "MEMORY", "value": str(self.job.resources.get("mem_mb", str(2048)))}, - ], - }, } - if self.settings.tags: - job_params["tags"] = self.settings.tags - - if self.settings.task_timeout is not None: - job_params["timeout"] = { - "attemptDurationSeconds": self.settings.task_timeout - } - try: submitted = self.batch_client.client.submit_job(**job_params) self.logger.debug( From 4a4bccfe0bf017a85cb0d71c74275647fecdaeb2 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 13:31:56 -0800 Subject: [PATCH 06/38] update internal method --- snakemake_executor_plugin_aws_batch/batch_job_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 8341292..6db4ff6 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -43,7 +43,7 @@ def build_job_definition(self): vcpu = str(self.job.resources.get("_cores", str(1))) mem = str(self.job.resources.get("mem_mb", str(2048))) - 
job_command = _generate_snakemake_command(self.job) + job_command = self._generate_snakemake_command() container_properties = { "image": self.container_image, @@ -73,9 +73,9 @@ def build_job_definition(self): except Exception as e: raise WorkflowError(e) - def _generate_snakemake_command(self, job: JobExecutorInterface) -> str: + def _generate_snakemake_command(self) -> str: """generates the snakemake command for the job""" - exec_job = self.format_job_exec(job) + exec_job = self.format_job_exec(self.job) return ["sh", "-c", shlex.quote(exec_job)] def submit_job(self): From 7bc9d2aa929ac407618b2d7708a3b5abf072fb77 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 13:50:32 -0800 Subject: [PATCH 07/38] move command up --- snakemake_executor_plugin_aws_batch/__init__.py | 12 +++++++++++- .../batch_job_builder.py | 11 +++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index b83ff06..3e45ca6 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +import shlex from pprint import pformat from typing import List, Generator, Optional from snakemake_executor_plugin_aws_batch import BatchJobDescriber, BatchClient, BatchJobBuilder @@ -124,8 +125,17 @@ def run_job(self, job: JobExecutorInterface): # snakemake_interface_executor_plugins.executors.base.SubmittedJobInfo. # If required, make sure to pass the job's id to the job_info object, as keyword # argument 'external_job_id'. + + remote_command = f"/bin/bash -c {shlex.quote(self.format_job_exec(job))}" + self.logger.debug(f"Remote command: {remote_command}") + try: - job_definition = BatchJobBuilder() + job_definition = BatchJobBuilder(logger=self.logger, + job=job, + container_image=self.container_image, + settings=self.settings, + job_command=remote_command, + batch_client=self.batch_client) job_info = job_definition.submit() self.logger.debug( "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 6db4ff6..4700c10 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -26,11 +26,13 @@ class BatchJobBuilder: def __init__(self, logger, job: JobExecutorInterface, container_image: str, settings: ExecutorSettings, + job_command: str, batch_client: BatchClient=None): self.logger = logger self.job = job self.container_image = container_image self.settings = settings + self.job_command = job_command self.batch_client = batch_client self.created_job_defs = [] @@ -43,11 +45,9 @@ def build_job_definition(self): vcpu = str(self.job.resources.get("_cores", str(1))) mem = str(self.job.resources.get("mem_mb", str(2048))) - job_command = self._generate_snakemake_command() - container_properties = { "image": self.container_image, - "command": [job_command], + "command": [self.job_command], "jobRoleArn": self.settings.job_role, "privileged": True, "resourceRequirements": [ @@ -73,11 +73,6 @@ def build_job_definition(self): except Exception as e: raise WorkflowError(e) - def _generate_snakemake_command(self) -> str: - """generates the snakemake command for the job""" - exec_job = self.format_job_exec(self.job) - return ["sh", "-c", shlex.quote(exec_job)] - def submit_job(self): job_def, job_name 
= self.build_job_definition() From bca0dd0148c6f73a834940d735d3d7889d2b10bb Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 14:00:25 -0800 Subject: [PATCH 08/38] bump deps --- pyproject.toml | 6 +++--- snakemake_executor_plugin_aws_batch/batch_job_builder.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 331af23..51f4cbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,9 +10,9 @@ keywords = ["snakemake", "plugin", "executor", "aws-batch"] [tool.poetry.dependencies] python = "^3.11" -snakemake-interface-common = "^1.15.0" -snakemake-interface-executor-plugins = "^8.1.1" -boto3 = "^1.33.11" +snakemake-interface-common = "^1.17.4" +snakemake-interface-executor-plugins = "^9.3.2" +boto3 = "^1.36.5" [tool.poetry.group.dev.dependencies] black = "^23.11.0" diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 4700c10..5a8e466 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -27,7 +27,7 @@ def __init__(self, logger, job: JobExecutorInterface, container_image: str, settings: ExecutorSettings, job_command: str, - batch_client: BatchClient=None): + batch_client: BatchClient): self.logger = logger self.job = job self.container_image = container_image @@ -85,7 +85,7 @@ def submit_job(self): } try: - submitted = self.batch_client.client.submit_job(**job_params) + submitted = self.batch_client.submit_job(**job_params) self.logger.debug( "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( self.settings.job_queue, submitted["jobId"], self.settings.tags From acac8965468b8c64e345d3d043751e23f26f0bc7 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 19:11:48 -0800 Subject: [PATCH 09/38] fixes --- snakemake_executor_plugin_aws_batch/__init__.py | 2 +- snakemake_executor_plugin_aws_batch/batch_job_builder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 3e45ca6..f24cccf 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -113,7 +113,7 @@ def __post_init__(self): # keep track of job definitions self.created_job_defs = list() self._describer = BatchJobDescriber() - self.batch_clint = BatchClient(region_name=self.settings.region) + self.batch_client = BatchClient(region_name=self.settings.region) def run_job(self, job: JobExecutorInterface): # Implement here how to run a job. 
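
For orientation, the submission flow these patches assemble looks roughly like the sketch below. This is a minimal usage sketch, not part of the patch series: the class names and keyword arguments come from the diffs in this series (as they stand by its later patches), while the queue name, role ARN, region, command, and the stub `settings`/`job` objects are illustrative placeholders. Valid AWS credentials and an existing Batch job queue are assumed.

```
# Minimal usage sketch (assumptions: valid AWS credentials, an existing Batch
# job queue, and BatchClient/BatchJobBuilder as defined by the later patches
# in this series; `settings` and `job` below are stand-ins, not the real
# ExecutorSettings dataclass or Snakemake job object).
import logging
from types import SimpleNamespace

from snakemake_executor_plugin_aws_batch.batch_client import BatchClient
from snakemake_executor_plugin_aws_batch.batch_job_builder import BatchJobBuilder

settings = SimpleNamespace(
    job_queue="snakejobqueue",  # placeholder queue name
    job_role="arn:aws:iam::111122223333:role/example-job-role",  # placeholder ARN
    task_timeout=60,
    tags={},
)
job = SimpleNamespace(name="example", resources={})  # stand-in for a Snakemake job

builder = BatchJobBuilder(
    logger=logging.getLogger(__name__),
    job=job,
    container_image="snakemake/snakemake:latest",
    settings=settings,
    job_command="/bin/bash -c 'snakemake --help'",  # placeholder command
    batch_client=BatchClient(region_name="us-west-2"),
)

# submit() first registers a per-job Batch job definition, then calls
# SubmitJob and returns the SubmitJob response.
job_info = builder.submit()
print(job_info["jobId"])
```

The real executor performs these steps inside `run_job`, with the command produced by `self.format_job_exec(job)` rather than a hand-written string.
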
diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 5a8e466..6f63408 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -20,7 +20,7 @@ class BATCH_JOB_PLATFORM_CAPABILITIES(Enum): class BATCH_JOB_RESOURCE_REQUIREMENT_TYPE(Enum): GPU = "GPU" VCPU = "VCPU" - MEMORY = "MEMEORY" + MEMORY = "MEMORY" class BatchJobBuilder: def __init__(self, logger, job: JobExecutorInterface, From 8447d07f3318dacf241feed3d7abef90c7e129de Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 19:18:55 -0800 Subject: [PATCH 10/38] settings tweak --- snakemake_executor_plugin_aws_batch/__init__.py | 4 ++-- tests/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index f24cccf..d657813 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -39,10 +39,10 @@ class ExecutorSettings(ExecutorSettingsBase): "required": True, }, ) - execution_role: Optional[str] = field( + job_role: Optional[str] = field( default=None, metadata={ - "help": "The AWS execution role ARN that is used for running the tasks", + "help": "The AWS job role ARN that is used for running the tasks", "env_var": True, "required": True, }, diff --git a/tests/__init__.py b/tests/__init__.py index 03401f3..a3da7bf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -13,9 +13,9 @@ def get_executor(self) -> str: def get_executor_settings(self) -> Optional[ExecutorSettingsBase]: return ExecutorSettings( - region=os.environ.get("SNAKEMAKE_AWS_BATCH_REGION", "us-east-1"), + region=os.environ.get("SNAKEMAKE_AWS_BATCH_REGION", "us-west-2"), job_queue=os.environ.get("SNAKEMAKE_AWS_BATCH_JOB_QUEUE"), - execution_role=os.environ.get("SNAKEMAKE_AWS_BATCH_EXECUTION_ROLE"), + job_role=os.environ.get("SNAKEMAKE_AWS_BATCH_JOB_ROLE"), ) def get_assume_shared_fs(self) -> bool: From 4940eadf9c4fa65469d91a49f1511f26818fc292 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 19:54:01 -0800 Subject: [PATCH 11/38] authors --- snakemake_executor_plugin_aws_batch/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index d657813..d020913 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -1,3 +1,8 @@ +__author__ = "Jake VanCampen, Johannes Köster" +__copyright__ = f"Copyright 2025, Snakemake community" +__email__ = "jake.vancampen7@gmail.com" +__license__ = "MIT" + from dataclasses import dataclass, field import shlex from pprint import pformat From ce595f51f4b8d05ed516ed9ed9d9311b78686077 Mon Sep 17 00:00:00 2001 From: jakevc Date: Fri, 24 Jan 2025 20:53:18 -0800 Subject: [PATCH 12/38] format --- .vscode/settings.json | 9 +++++++ .../__init__.py | 19 +++++++------ .../batch_client.py | 12 +++++---- .../batch_descriptor.py | 2 +- .../batch_job_builder.py | 27 +++++++++++++------ tests/__init__.py | 1 - 6 files changed, 47 insertions(+), 23 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..97aa548 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,9 @@ +{ + "black-formatter.interpreter": [ + "./.venv/bin/python" + ], + 
"flake8.interpreter": [ + "./.venv/bin/python" + ], + "files.saveConflictResolution": "overwriteFileOnDisk" +} \ No newline at end of file diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index d020913..3e43271 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -1,13 +1,16 @@ __author__ = "Jake VanCampen, Johannes Köster" -__copyright__ = f"Copyright 2025, Snakemake community" +__copyright__ = "Copyright 2025, Snakemake community" __email__ = "jake.vancampen7@gmail.com" __license__ = "MIT" from dataclasses import dataclass, field import shlex +from typing import Union from pprint import pformat -from typing import List, Generator, Optional -from snakemake_executor_plugin_aws_batch import BatchJobDescriber, BatchClient, BatchJobBuilder +from typing import List, AsyncGenerator, Optional +from snakemake_executor_plugin_aws_batch.batch_client import BatchClient +from snakemake_executor_plugin_aws_batch.batch_job_builder import BatchJobBuilder +from snakemake_executor_plugin_aws_batch.batch_descriptor import BatchJobDescriber from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor from snakemake_interface_executor_plugins.settings import ( @@ -135,7 +138,7 @@ def run_job(self, job: JobExecutorInterface): self.logger.debug(f"Remote command: {remote_command}") try: - job_definition = BatchJobBuilder(logger=self.logger, + job_definition = BatchJobBuilder(logger=self.logger, job=job, container_image=self.container_image, settings=self.settings, @@ -155,15 +158,15 @@ def run_job(self, job: JobExecutorInterface): job=job, external_jobid=job_info["jobId"], aux={ - "jobs_params": job_params, - "job_def_arn": job_def["jobDefinitionArn"], + "jobs_params": job_info["job_params"], + "job_def_arn": job_definition["jobDefinitionArn"], }, ) ) async def check_active_jobs( self, active_jobs: List[SubmittedJobInfo] - ) -> Generator[SubmittedJobInfo, None, None]: + ) -> AsyncGenerator[SubmittedJobInfo, None, None]: # Check the status of active jobs. # You have to iterate over the given list active_jobs. @@ -196,7 +199,7 @@ async def check_active_jobs( else: yield job - def _get_job_status(self, job: SubmittedJobInfo) -> (int, Optional[str]): + def _get_job_status(self, job: SubmittedJobInfo) -> Union[int, Optional[str]]: """poll for Batch job success or failure returns exits code and failure information if exit code is not 0 diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py index caf33b2..839bc62 100644 --- a/snakemake_executor_plugin_aws_batch/batch_client.py +++ b/snakemake_executor_plugin_aws_batch/batch_client.py @@ -1,11 +1,12 @@ import boto3 + class AWSClient: def __init__(self, service_name, region_name=None): """ Initialize an AWS client for a specific service using default credentials. - - :param service_name: The name of the AWS service (e.g., 's3', 'ec2', 'dynamodb'). + + :param service_name: The name of the AWS service (e.g., 's3', 'ec2', 'dynamodb') :param region_name: The region name to use for the client (optional). """ self.service_name = service_name @@ -15,7 +16,7 @@ def __init__(self, service_name, region_name=None): def initialize_client(self): """ Create an AWS client using boto3 with the default credentials. - + :return: The boto3 client for the specified service. 
""" if self.region_name: @@ -24,11 +25,12 @@ def initialize_client(self): client = boto3.client(self.service_name) return client + class BatchClient(AWSClient): def __init__(self, region_name=None): """ Initialize an AWS Batch client using default credentials. - + :param region_name: The region name to use for the client (optional). """ - super().__init__('batch', region_name) \ No newline at end of file + super().__init__("batch", region_name) diff --git a/snakemake_executor_plugin_aws_batch/batch_descriptor.py b/snakemake_executor_plugin_aws_batch/batch_descriptor.py index 01c8980..f42080d 100644 --- a/snakemake_executor_plugin_aws_batch/batch_descriptor.py +++ b/snakemake_executor_plugin_aws_batch/batch_descriptor.py @@ -68,4 +68,4 @@ def _update(self, aws, period): self.jobs[job_desc["jobId"]] = job_desc assert ( not job_ids - ), "AWS Batch DescribeJobs didn't return all expected results" \ No newline at end of file + ), "AWS Batch DescribeJobs didn't return all expected results" diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 6f63408..bf78457 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -1,33 +1,41 @@ import uuid from snakemake.exceptions import WorkflowError -import shlex from snakemake_interface_executor_plugins.jobs import ( JobExecutorInterface, ) from enum import Enum -from snakemake_executor_plugin_aws_batch import BatchClient, ExecutorSettings +from snakemake_executor_plugin_aws_batch import ExecutorSettings +from snakemake_executor_plugin_aws_batch.batch_client import BatchClient SNAKEMAKE_COMMAND = "snakemake" + class BATCH_JOB_DEFINITION_TYPE(Enum): CONTAINER = "container" MULTINODE = "multinode" + class BATCH_JOB_PLATFORM_CAPABILITIES(Enum): FARGATE = "FARGATE" EC2 = "EC2" + class BATCH_JOB_RESOURCE_REQUIREMENT_TYPE(Enum): GPU = "GPU" VCPU = "VCPU" MEMORY = "MEMORY" + class BatchJobBuilder: - def __init__(self, logger, job: JobExecutorInterface, - container_image: str, - settings: ExecutorSettings, - job_command: str, - batch_client: BatchClient): + def __init__( + self, + logger, + job: JobExecutorInterface, + container_image: str, + settings: ExecutorSettings, + job_command: str, + batch_client: BatchClient, + ): self.logger = logger self.job = job self.container_image = container_image @@ -53,7 +61,10 @@ def build_job_definition(self): "resourceRequirements": [ {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.GPU.value, "value": gpu}, {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.VCPU.value, "value": vcpu}, - {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.MEMORY.value, "value": mem}, + { + "type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.MEMORY.value, + "value": mem, + }, # noqa ], } diff --git a/tests/__init__.py b/tests/__init__.py index 1bb9794..aa901ce 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,7 +3,6 @@ import snakemake.common.tests from snakemake_interface_executor_plugins.settings import ExecutorSettingsBase - from snakemake_executor_plugin_aws_batch import ExecutorSettings From 21b882f0f31d44d31ec3202ce2953a584bf27b2f Mon Sep 17 00:00:00 2001 From: jakevc Date: Sat, 25 Jan 2025 14:00:27 -0800 Subject: [PATCH 13/38] executor setting type causes circular import --- .../__init__.py | 23 ++++++++++--------- .../batch_job_builder.py | 3 +-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py 
b/snakemake_executor_plugin_aws_batch/__init__.py index 3e43271..25db126 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -9,7 +9,7 @@ from pprint import pformat from typing import List, AsyncGenerator, Optional from snakemake_executor_plugin_aws_batch.batch_client import BatchClient -from snakemake_executor_plugin_aws_batch.batch_job_builder import BatchJobBuilder +from snakemake_executor_plugin_aws_batch.batch_job_builder import BatchJobBuilder from snakemake_executor_plugin_aws_batch.batch_descriptor import BatchJobDescriber from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor @@ -110,12 +110,11 @@ class ExecutorSettings(ExecutorSettingsBase): # Implementation of your executor class Executor(RemoteExecutor): def __post_init__(self): - # snakemake/snakemake:latest container image self.container_image = self.workflow.remote_execution_settings.container_image # access executor specific settings - self.settings: ExecutorSettings = self.workflow.executor_settings + self.settings = self.workflow.executor_settings self.logger.debug(f"ExecutorSettings: {pformat(self.settings, indent=2)}") # keep track of job definitions @@ -138,12 +137,14 @@ def run_job(self, job: JobExecutorInterface): self.logger.debug(f"Remote command: {remote_command}") try: - job_definition = BatchJobBuilder(logger=self.logger, - job=job, - container_image=self.container_image, - settings=self.settings, - job_command=remote_command, - batch_client=self.batch_client) + job_definition = BatchJobBuilder( + logger=self.logger, + job=job, + container_image=self.container_image, + settings=self.settings, + job_command=remote_command, + batch_client=self.batch_client, + ) job_info = job_definition.submit() self.logger.debug( "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( @@ -166,7 +167,7 @@ def run_job(self, job: JobExecutorInterface): async def check_active_jobs( self, active_jobs: List[SubmittedJobInfo] - ) -> AsyncGenerator[SubmittedJobInfo, None, None]: + ) -> AsyncGenerator[SubmittedJobInfo, None]: # Check the status of active jobs. # You have to iterate over the given list active_jobs. 
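
This patch ("executor setting type causes circular import") resolves the cycle by dropping the `ExecutorSettings` import and annotation from `batch_job_builder.py`, as the hunk that follows shows. A common alternative, sketched below under the assumption that only the type annotation is needed, keeps the annotation while deferring the import to static type checkers via `typing.TYPE_CHECKING`; this is hypothetical and not what the patch does.

```
# Hypothetical alternative to dropping the annotation: import the settings
# type only during static type checking, so nothing is imported at runtime
# and the package __init__ <-> batch_job_builder cycle never materializes.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers such as mypy, never executed.
    from snakemake_executor_plugin_aws_batch import ExecutorSettings


class BatchJobBuilder:
    def __init__(self, settings: "ExecutorSettings"):
        # The string annotation is resolved lazily, only by type checkers.
        self.settings = settings
```
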
@@ -307,4 +308,4 @@ def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]): # cleanup jobs for j in active_jobs: self._terminate_job(j) - self._deregister_job(j) \ No newline at end of file + self._deregister_job(j) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index bf78457..dc527d6 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -4,7 +4,6 @@ JobExecutorInterface, ) from enum import Enum -from snakemake_executor_plugin_aws_batch import ExecutorSettings from snakemake_executor_plugin_aws_batch.batch_client import BatchClient SNAKEMAKE_COMMAND = "snakemake" @@ -32,7 +31,7 @@ def __init__( logger, job: JobExecutorInterface, container_image: str, - settings: ExecutorSettings, + settings, job_command: str, batch_client: BatchClient, ): From 6e6916da6b17ae87415f5ab3ea90c439b53fc2d7 Mon Sep 17 00:00:00 2001 From: jakevc Date: Sat, 25 Jan 2025 14:22:51 -0800 Subject: [PATCH 14/38] tweak --- snakemake_executor_plugin_aws_batch/__init__.py | 2 +- snakemake_executor_plugin_aws_batch/batch_client.py | 9 +++++++++ snakemake_executor_plugin_aws_batch/batch_job_builder.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 25db126..2a069ff 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -160,7 +160,7 @@ def run_job(self, job: JobExecutorInterface): external_jobid=job_info["jobId"], aux={ "jobs_params": job_info["job_params"], - "job_def_arn": job_definition["jobDefinitionArn"], + "job_def_arn": job_info["jobDefinitionArn"], }, ) ) diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py index 839bc62..168b558 100644 --- a/snakemake_executor_plugin_aws_batch/batch_client.py +++ b/snakemake_executor_plugin_aws_batch/batch_client.py @@ -34,3 +34,12 @@ def __init__(self, region_name=None): :param region_name: The region name to use for the client (optional). """ super().__init__("batch", region_name) + + def submit_job(self, **kwargs): + """ + Submit a job to AWS Batch. + + :param kwargs: The keyword arguments to pass to the submit_job method. + :return: The response from the submit_job method. 
+ """ + return self.client.submit_job(**kwargs) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index dc527d6..fab7d69 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -83,7 +83,7 @@ def build_job_definition(self): except Exception as e: raise WorkflowError(e) - def submit_job(self): + def submit(self): job_def, job_name = self.build_job_definition() job_params = { From 86dfe89ca6a938e764bddd9bbd1a120e4dae42ee Mon Sep 17 00:00:00 2001 From: jakevc Date: Sat, 25 Jan 2025 14:56:57 -0800 Subject: [PATCH 15/38] tweak --- snakemake_executor_plugin_aws_batch/__init__.py | 2 +- snakemake_executor_plugin_aws_batch/batch_job_builder.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 2a069ff..50e0b4d 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -67,7 +67,7 @@ class ExecutorSettings(ExecutorSettingsBase): }, ) task_timeout: Optional[int] = field( - default=None, + default=60, metadata={ "help": ( "Task timeout (seconds) will force AWS Batch to terminate " diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index fab7d69..3adb254 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -58,7 +58,6 @@ def build_job_definition(self): "jobRoleArn": self.settings.job_role, "privileged": True, "resourceRequirements": [ - {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.GPU.value, "value": gpu}, {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.VCPU.value, "value": vcpu}, { "type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.MEMORY.value, @@ -67,7 +66,12 @@ def build_job_definition(self): ], } - timeout = dict() + if int(gpu) > 0: + container_properties["resourceRequirements"].append( + {"type": BATCH_JOB_RESOURCE_REQUIREMENT_TYPE.GPU.value, "value": gpu} + ) + + timeout = {"attemptDurationSeconds": self.settings.task_timeout} tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict() try: job_def = self.batch_client.client.register_job_definition( From 459580e0bdc38e38aaf24dd1a4aa6b06e9486b92 Mon Sep 17 00:00:00 2001 From: jakevc Date: Sat, 25 Jan 2025 19:47:19 -0800 Subject: [PATCH 16/38] update WorkflowError msg --- snakemake_executor_plugin_aws_batch/batch_job_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 3adb254..3a254ae 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -85,7 +85,7 @@ def build_job_definition(self): self.created_job_defs.append(job_def) return job_def, job_name except Exception as e: - raise WorkflowError(e) + raise WorkflowError(f"Failed to register job definition: {e}") from e def submit(self): job_def, job_name = self.build_job_definition() @@ -107,4 +107,4 @@ def submit(self): ) return submitted except Exception as e: - raise WorkflowError(e) + raise WorkflowError(f"Failed to submit job: {e}") from e From e1b1bb0ce968bbace5129654b15caa7a56e083ce Mon Sep 17 00:00:00 2001 From: jakevc Date: Sat, 25 Jan 2025 19:51:33 -0800 
Subject: [PATCH 17/38] tuple

---
 snakemake_executor_plugin_aws_batch/__init__.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py
index 50e0b4d..442bebb 100644
--- a/snakemake_executor_plugin_aws_batch/__init__.py
+++ b/snakemake_executor_plugin_aws_batch/__init__.py
@@ -5,7 +5,6 @@

 from dataclasses import dataclass, field
 import shlex
-from typing import Union
 from pprint import pformat
 from typing import List, AsyncGenerator, Optional
 from snakemake_executor_plugin_aws_batch.batch_client import BatchClient
@@ -200,7 +199,7 @@ async def check_active_jobs(
         else:
             yield job

-    def _get_job_status(self, job: SubmittedJobInfo) -> Union[int, Optional[str]]:
+    def _get_job_status(self, job: SubmittedJobInfo) -> tuple[int, Optional[str]]:
         """poll for Batch job success or failure

         returns the exit code and failure information if the exit code is not 0

From 3c69a03865e5e01a5e654e72b1045099b561a839 Mon Sep 17 00:00:00 2001
From: jakevc
Date: Sat, 25 Jan 2025 19:54:39 -0800
Subject: [PATCH 18/38] todo

---
 snakemake_executor_plugin_aws_batch/batch_job_builder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py
index 3a254ae..245c911 100644
--- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py
+++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py
@@ -48,6 +48,7 @@ def build_job_definition(self):
         job_name = f"snakejob-{self.job.name}-{job_uuid}"
         job_definition_name = f"snakejob-def-{self.job.name}-{job_uuid}"

+        # TODO: validate resources
         gpu = str(self.job.resources.get("_gpus", str(0)))
         vcpu = str(self.job.resources.get("_cores", str(1)))
         mem = str(self.job.resources.get("mem_mb", str(2048)))

From bed0327578a98284850ff6b585e2398f1bfe4b66 Mon Sep 17 00:00:00 2001
From: jakevc
Date: Sat, 25 Jan 2025 21:56:24 -0800
Subject: [PATCH 19/38] terraform

---
 .gitignore          |   5 ++-
 terraform/README.md |  11 +++++
 terraform/main.tf   |  96 ++++++++++++++++++++++++++++++++++++++++
 terraform/vars.tf   | 104 ++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 215 insertions(+), 1 deletion(-)
 create mode 100644 terraform/README.md
 create mode 100644 terraform/main.tf
 create mode 100644 terraform/vars.tf

diff --git a/.gitignore b/.gitignore
index bcc3d59..feac811 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,7 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

-poetry.lock
\ No newline at end of file
+poetry.lock
+
+.terraform*
+terraform.tfstate*
\ No newline at end of file
diff --git a/terraform/README.md b/terraform/README.md
new file mode 100644
index 0000000..62a4c84
--- /dev/null
+++ b/terraform/README.md
@@ -0,0 +1,11 @@
+# README
+
+This directory contains Terraform templates to deploy the minimum required AWS
+infrastructure for the snakemake_executor_plugin_aws_batch.
Update vars.tf to +use the resource names and attribute values suitable for your environment, then run: + +``` +terraform init +terraform plan +terraform apply +``` \ No newline at end of file diff --git a/terraform/main.tf b/terraform/main.tf new file mode 100644 index 0000000..7ef3297 --- /dev/null +++ b/terraform/main.tf @@ -0,0 +1,96 @@ +provider "aws" { + region = var.aws_provider_region +} + +data "aws_iam_policy_document" "ec2_assume_role" { + statement { + effect = "Allow" + + principals { + type = "Service" + identifiers = ["ec2.amazonaws.com"] + } + + actions = ["sts:AssumeRole"] + } +} + +resource "aws_iam_role" "ecs_instance_role" { + name = var.ecs_instance_role_name + assume_role_policy = data.aws_iam_policy_document.ec2_assume_role.json +} + +resource "aws_iam_role_policy_attachment" "ecs_instance_role" { + role = aws_iam_role.ecs_instance_role.name + policy_arn = var.ecs_instance_role_policy_arn +} + +resource "aws_iam_instance_profile" "ecs_instance_role" { + name = var.ecs_instance_role_name + role = aws_iam_role.ecs_instance_role.name +} + +data "aws_iam_policy_document" "batch_assume_role" { + statement { + effect = "Allow" + + principals { + type = "Service" + identifiers = ["batch.amazonaws.com"] + } + + actions = ["sts:AssumeRole"] + } +} + +resource "aws_iam_role" "aws_batch_service_role" { + name = var.aws_batch_service_role_name + assume_role_policy = data.aws_iam_policy_document.batch_assume_role.json +} + +resource "aws_iam_role_policy_attachment" "aws_batch_service_role" { + role = aws_iam_role.aws_batch_service_role.name + policy_arn = var.aws_batch_service_role_policy_arn +} + +resource "aws_placement_group" "sample" { + name = var.aws_placement_group_name + strategy = var.aws_placement_group_strategy +} + +resource "aws_batch_compute_environment" "sample" { + compute_environment_name = var.aws_batch_compute_environment_name + + compute_resources { + instance_role = aws_iam_instance_profile.ecs_instance_role.arn + + instance_type = var.instance_types + + max_vcpus = var.max_vcpus + min_vcpus = var.min_vcpus + + placement_group = aws_placement_group.sample.name + + security_group_ids = var.aws_batch_security_group_ids + + subnets = var.aws_batch_subnet_ids + + type = var.aws_batch_compute_resource_type + } + + service_role = aws_iam_role.aws_batch_service_role.arn + type = var.aws_batch_compute_environment_type + depends_on = [aws_iam_role_policy_attachment.aws_batch_service_role] +} + + +resource "aws_batch_job_queue" "test_queue" { + name = var.aws_batch_job_queue_name + state = var.aws_batch_job_queue_state + priority = 1 + + compute_environment_order { + order = 1 + compute_environment = aws_batch_compute_environment.sample.arn + } +} \ No newline at end of file diff --git a/terraform/vars.tf b/terraform/vars.tf new file mode 100644 index 0000000..22d85b3 --- /dev/null +++ b/terraform/vars.tf @@ -0,0 +1,104 @@ +variable "aws_provider_region" { + description = "The AWS region to deploy resources in" + type = string + default = "us-west-2" +} + +variable "ecs_instance_role_name" { + description = "The name of the ECS instance role" + type = string + default = "ecs_instance_role01" +} + +variable "ecs_instance_role_policy_arn" { + description = "The ARN of the ECS instance role policy" + type = string + default = "arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role" +} + +variable "aws_batch_service_role_name" { + description = "The name of the AWS Batch service role" + type = string + default = "aws_batch_service_role01" +} + +variable 
"aws_batch_service_role_policy_arn" { + description = "The ARN of the AWS Batch service role policy" + type = string + default = "arn:aws:iam::aws:policy/service-role/AWSBatchServiceRole" +} + +# default subnet id +variable "aws_batch_subnet_ids" { + description = "The subnet IDs for the AWS Batch compute environment" + type = list(string) + default = ["subnet-9d6142e4"] + +} + +# default security group id +variable "aws_batch_security_group_ids" { + description = "The security group IDs for the AWS Batch compute environment" + type = list(string) + default = ["sg-ee1ccb9a"] +} + +variable "aws_placement_group_name" { + description = "The name of the placement group" + type = string + default = "sample" +} + +variable "aws_placement_group_strategy" { + description = "The strategy of the placement group" + type = string + default = "cluster" +} + +variable "aws_batch_compute_environment_name" { + description = "The name of the AWS Batch compute environment" + type = string + default = "snakecomputeenv" +} + +variable "instance_types" { + description = "The allowed instance types for the compute environment" + type = list(string) + default = ["c4.large", "c4.xlarge", "c4.2xlarge", "c4.4xlarge", "c4.8xlarge"] +} + +variable "max_vcpus" { + description = "The maximum number of vCPUs for the compute environment" + type = number + default = 16 +} + +variable "min_vcpus" { + description = "The minimum number of vCPUs for the compute environment" + type = number + default = 0 +} + +variable "aws_batch_compute_resource_type" { + description = "The type of the AWS Batch compute environment" + type = string + default = "EC2" +} + +variable "aws_batch_compute_environment_type" { + description = "The type of the AWS Batch compute environment" + type = string + default = "MANAGED" +} + +variable "aws_batch_job_queue_name" { + description = "The name of the AWS Batch job queue" + type = string + default = "snakejobqueue" +} + +variable "aws_batch_job_queue_state" { + description = "The state of the AWS Batch job queue" + type = string + default = "ENABLED" +} \ No newline at end of file From 14afb48b613e9aa76d07fd9ac1d8019eb5b61c4a Mon Sep 17 00:00:00 2001 From: jakevc Date: Sun, 26 Jan 2025 16:20:05 -0800 Subject: [PATCH 20/38] debug --- snakemake_executor_plugin_aws_batch/__init__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 442bebb..fdc3249 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -153,14 +153,10 @@ def run_job(self, job: JobExecutorInterface): except Exception as e: raise WorkflowError(e) + self.logger.debug(f"Job info: {pformat(job_info, indent=2)}") self.report_job_submission( SubmittedJobInfo( - job=job, - external_jobid=job_info["jobId"], - aux={ - "jobs_params": job_info["job_params"], - "job_def_arn": job_info["jobDefinitionArn"], - }, + job=job, external_jobid=job_info["jobId"], aux=dict(job_info) ) ) From 8c1d65d671b746508265fea30e8e998834279467 Mon Sep 17 00:00:00 2001 From: jakevc Date: Sun, 26 Jan 2025 17:01:26 -0800 Subject: [PATCH 21/38] tf --- terraform/main.tf | 26 +++++++++++++++++++++++--- terraform/vars.tf | 32 ++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/terraform/main.tf b/terraform/main.tf index 7ef3297..496594c 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -58,6 +58,26 @@ resource "aws_placement_group" 
"sample" { strategy = var.aws_placement_group_strategy } +resource "aws_security_group" "sg01" { + name = var.aws_security_group_name + + egress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } +} + +resource "aws_vpc" "vpc01" { + cidr_block = var.aws_vpc_cidr_block +} + +resource "aws_subnet" "subnet01" { + vpc_id = aws_vpc.vpc01.id + cidr_block = var.aws_subnet_cidr_block +} + resource "aws_batch_compute_environment" "sample" { compute_environment_name = var.aws_batch_compute_environment_name @@ -71,9 +91,9 @@ resource "aws_batch_compute_environment" "sample" { placement_group = aws_placement_group.sample.name - security_group_ids = var.aws_batch_security_group_ids + security_group_ids = [aws_security_group.sg01.id] - subnets = var.aws_batch_subnet_ids + subnets = [aws_subnet.subnet01.id] type = var.aws_batch_compute_resource_type } @@ -84,7 +104,7 @@ resource "aws_batch_compute_environment" "sample" { } -resource "aws_batch_job_queue" "test_queue" { +resource "aws_batch_job_queue" "snakequeue" { name = var.aws_batch_job_queue_name state = var.aws_batch_job_queue_state priority = 1 diff --git a/terraform/vars.tf b/terraform/vars.tf index 22d85b3..f321ef8 100644 --- a/terraform/vars.tf +++ b/terraform/vars.tf @@ -28,20 +28,6 @@ variable "aws_batch_service_role_policy_arn" { default = "arn:aws:iam::aws:policy/service-role/AWSBatchServiceRole" } -# default subnet id -variable "aws_batch_subnet_ids" { - description = "The subnet IDs for the AWS Batch compute environment" - type = list(string) - default = ["subnet-9d6142e4"] - -} - -# default security group id -variable "aws_batch_security_group_ids" { - description = "The security group IDs for the AWS Batch compute environment" - type = list(string) - default = ["sg-ee1ccb9a"] -} variable "aws_placement_group_name" { description = "The name of the placement group" @@ -55,6 +41,24 @@ variable "aws_placement_group_strategy" { default = "cluster" } +variable "aws_security_group_name" { + description = "The name of the security group" + type = string + default = "sg01" +} + +variable "aws_vpc_cidr_block" { + description = "The CIDR block for the VPC" + type = string + default = "10.1.0.0/16" +} + +variable "aws_subnet_cidr_block" { + description = "The CIDR block for the subnet" + type = string + default = "10.1.1.0/24" +} + variable "aws_batch_compute_environment_name" { description = "The name of the AWS Batch compute environment" type = string From b0865a71f8e547797d67202d54e89de9d204e63e Mon Sep 17 00:00:00 2001 From: jakevc Date: Sun, 26 Jan 2025 17:10:38 -0800 Subject: [PATCH 22/38] vpc_id --- terraform/main.tf | 1 + 1 file changed, 1 insertion(+) diff --git a/terraform/main.tf b/terraform/main.tf index 496594c..24821cc 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -60,6 +60,7 @@ resource "aws_placement_group" "sample" { resource "aws_security_group" "sg01" { name = var.aws_security_group_name + vpc_id = aws_vpc.vpc01.id egress { from_port = 0 From 9dfa6e7f6dd66a8c623ee2d9a51fc98e88abf4ee Mon Sep 17 00:00:00 2001 From: jakevc Date: Sun, 26 Jan 2025 19:48:57 -0800 Subject: [PATCH 23/38] tf --- snakemake_executor_plugin_aws_batch/batch_client.py | 10 ++++++++++ terraform/main.tf | 3 +++ 2 files changed, 13 insertions(+) diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py index 168b558..bfbbea8 100644 --- a/snakemake_executor_plugin_aws_batch/batch_client.py +++ b/snakemake_executor_plugin_aws_batch/batch_client.py @@ -26,6 +26,7 
@@ def initialize_client(self): return client +# client class stub for AWS Batch class BatchClient(AWSClient): def __init__(self, region_name=None): """ @@ -43,3 +44,12 @@ def submit_job(self, **kwargs): :return: The response from the submit_job method. """ return self.client.submit_job(**kwargs) + + def describe_jobs(self, **kwargs): + """ + Describe jobs in AWS Batch. + + :param kwargs: The keyword arguments to pass to the describe_jobs method. + :return: The response from the describe_jobs method. + """ + return self.client.describe_jobs(**kwargs) diff --git a/terraform/main.tf b/terraform/main.tf index 24821cc..87ec471 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -77,6 +77,9 @@ resource "aws_vpc" "vpc01" { resource "aws_subnet" "subnet01" { vpc_id = aws_vpc.vpc01.id cidr_block = var.aws_subnet_cidr_block + + # jobs will be stuck in runnable state if this is not set + map_public_ip_on_launch = true } resource "aws_batch_compute_environment" "sample" { From 504d8845c61a5cf8e51239f0818471dca93a9e62 Mon Sep 17 00:00:00 2001 From: jakevc Date: Sun, 26 Jan 2025 20:58:15 -0800 Subject: [PATCH 24/38] tf --- terraform/main.tf | 25 +++++++++++++++++++++++++ terraform/vars.tf | 5 +++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/terraform/main.tf b/terraform/main.tf index 87ec471..9acfc12 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -62,6 +62,13 @@ resource "aws_security_group" "sg01" { name = var.aws_security_group_name vpc_id = aws_vpc.vpc01.id + ingress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } + egress { from_port = 0 to_port = 0 @@ -82,6 +89,24 @@ resource "aws_subnet" "subnet01" { map_public_ip_on_launch = true } +resource "aws_internet_gateway" "igw" { + vpc_id = aws_vpc.vpc01.id +} + +resource "aws_route_table" "public_rt" { + vpc_id = aws_vpc.vpc01.id + + route { + cidr_block = "0.0.0.0/0" + gateway_id = aws_internet_gateway.igw.id + } +} + +resource "aws_route_table_association" "public_rt_assoc" { + subnet_id = aws_subnet.subnet01.id + route_table_id = aws_route_table.public_rt.id +} + resource "aws_batch_compute_environment" "sample" { compute_environment_name = var.aws_batch_compute_environment_name diff --git a/terraform/vars.tf b/terraform/vars.tf index f321ef8..a8246f9 100644 --- a/terraform/vars.tf +++ b/terraform/vars.tf @@ -68,13 +68,14 @@ variable "aws_batch_compute_environment_name" { variable "instance_types" { description = "The allowed instance types for the compute environment" type = list(string) - default = ["c4.large", "c4.xlarge", "c4.2xlarge", "c4.4xlarge", "c4.8xlarge"] + default = ["c4.large"] + # , "c4.xlarge", "c4.2xlarge", "c4.4xlarge", "c4.8xlarge"] } variable "max_vcpus" { description = "The maximum number of vCPUs for the compute environment" type = number - default = 16 + default = 2 } variable "min_vcpus" { From a17d9b60442a5e09fd166a5aacc763709bd60d59 Mon Sep 17 00:00:00 2001 From: jakevc Date: Mon, 27 Jan 2025 20:33:51 -0800 Subject: [PATCH 25/38] docker cmd --- snakemake_executor_plugin_aws_batch/__init__.py | 6 +++--- .../batch_job_builder.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index fdc3249..c70fa5b 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -4,7 +4,6 @@ __license__ = "MIT" from dataclasses import dataclass, field -import shlex from pprint import pformat 
from typing import List, AsyncGenerator, Optional from snakemake_executor_plugin_aws_batch.batch_client import BatchClient @@ -132,8 +131,7 @@ def run_job(self, job: JobExecutorInterface): # If required, make sure to pass the job's id to the job_info object, as keyword # argument 'external_job_id'. - remote_command = f"/bin/bash -c {shlex.quote(self.format_job_exec(job))}" - self.logger.debug(f"Remote command: {remote_command}") + remote_command = f"/bin/bash -c {self.format_job_exec(job)}" try: job_definition = BatchJobBuilder( @@ -212,10 +210,12 @@ def _get_job_status(self, job: SubmittedJobInfo) -> tuple[int, Optional[str]]: exit_code = None log_stream_name = None job_desc = self._describer.describe(self.batch_client, job.external_jobid, 1) + self.logger.debug(f"JOB DESCRIPTION: {job_desc}") job_status = job_desc["status"] # set log stream name if not none log_details = {"status": job_status, "jobId": job.external_jobid} + self.logger.debug(f"LOG DETAILS: {log_details}") if "container" in job_desc and "logStreamName" in job_desc["container"]: log_stream_name = job_desc["container"]["logStreamName"] diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 245c911..30c6282 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -1,4 +1,6 @@ import uuid +import shlex +from typing import List from snakemake.exceptions import WorkflowError from snakemake_interface_executor_plugins.jobs import ( JobExecutorInterface, @@ -43,6 +45,12 @@ def __init__( self.batch_client = batch_client self.created_job_defs = [] + def _make_container_command(remote_command: str) -> List[str]: + """ + Return docker CMD form of the command + """ + return [shlex.quote(part) for part in shlex.split(str)] + def build_job_definition(self): job_uuid = str(uuid.uuid4()) job_name = f"snakejob-{self.job.name}-{job_uuid}" @@ -55,7 +63,8 @@ def build_job_definition(self): container_properties = { "image": self.container_image, - "command": [self.job_command], + # command requires a list of strings ( docker CMD format ) + "command": self._make_container_command(self.job_command), "jobRoleArn": self.settings.job_role, "privileged": True, "resourceRequirements": [ From b91456403169a77460b8e0fd93df69c35bddc73f Mon Sep 17 00:00:00 2001 From: jakevc Date: Mon, 27 Jan 2025 21:33:09 -0800 Subject: [PATCH 26/38] its alive --- .../__init__.py | 11 ++++++----- .../batch_client.py | 18 ++++++++++++++++++ .../batch_job_builder.py | 4 ++-- terraform/vars.tf | 2 +- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index c70fa5b..f28b871 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -111,6 +111,9 @@ def __post_init__(self): # snakemake/snakemake:latest container image self.container_image = self.workflow.remote_execution_settings.container_image + # set the rate limit for status checks + self.next_seconds_between_status_checks = 5 + # access executor specific settings self.settings = self.workflow.executor_settings self.logger.debug(f"ExecutorSettings: {pformat(self.settings, indent=2)}") @@ -151,7 +154,6 @@ def run_job(self, job: JobExecutorInterface): except Exception as e: raise WorkflowError(e) - self.logger.debug(f"Job info: {pformat(job_info, indent=2)}") self.report_job_submission( SubmittedJobInfo( 
job=job, external_jobid=job_info["jobId"], aux=dict(job_info) @@ -185,11 +187,12 @@ async def check_active_jobs( self.logger.debug(f"Monitoring {len(active_jobs)} active Batch jobs") for job in active_jobs: async with self.status_rate_limiter: - status_code = self._get_job_status(job) + status_code, msg = self._get_job_status(job) if status_code == 0: self.report_job_success(job) elif status_code is not None: - self.report_job_error(job) + message = f"AWS Batch job failed. Code: {status_code}, Msg: {msg}." + self.report_job_error(job, msg=message) else: yield job @@ -210,12 +213,10 @@ def _get_job_status(self, job: SubmittedJobInfo) -> tuple[int, Optional[str]]: exit_code = None log_stream_name = None job_desc = self._describer.describe(self.batch_client, job.external_jobid, 1) - self.logger.debug(f"JOB DESCRIPTION: {job_desc}") job_status = job_desc["status"] # set log stream name if not none log_details = {"status": job_status, "jobId": job.external_jobid} - self.logger.debug(f"LOG DETAILS: {log_details}") if "container" in job_desc and "logStreamName" in job_desc["container"]: log_stream_name = job_desc["container"]["logStreamName"] diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py index bfbbea8..6644122 100644 --- a/snakemake_executor_plugin_aws_batch/batch_client.py +++ b/snakemake_executor_plugin_aws_batch/batch_client.py @@ -53,3 +53,21 @@ def describe_jobs(self, **kwargs): :return: The response from the describe_jobs method. """ return self.client.describe_jobs(**kwargs) + + def deregister_job_definition(self, **kwargs): + """ + Deregister a job definition in AWS Batch. + + :param kwargs: The keyword arguments passed to deregister_job_definition method. + :return: The response from the deregister_job_definition method. + """ + return self.client.deregister_job_definition(**kwargs) + + def terminate_job(self, **kwargs): + """ + Terminate a job in AWS Batch. + + :param kwargs: The keyword arguments to pass to the terminate_job method. + :return: The response from the terminate_job method. 
+ """ + return self.client.terminate_job(**kwargs) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 30c6282..0baf48f 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -45,11 +45,11 @@ def __init__( self.batch_client = batch_client self.created_job_defs = [] - def _make_container_command(remote_command: str) -> List[str]: + def _make_container_command(self, remote_command: str) -> List[str]: """ Return docker CMD form of the command """ - return [shlex.quote(part) for part in shlex.split(str)] + return [shlex.quote(part) for part in shlex.split(remote_command)] def build_job_definition(self): job_uuid = str(uuid.uuid4()) diff --git a/terraform/vars.tf b/terraform/vars.tf index a8246f9..8d7b12f 100644 --- a/terraform/vars.tf +++ b/terraform/vars.tf @@ -75,7 +75,7 @@ variable "instance_types" { variable "max_vcpus" { description = "The maximum number of vCPUs for the compute environment" type = number - default = 2 + default = 16 } variable "min_vcpus" { From df9fec329dd4565787352a2e1cb826bde49564e1 Mon Sep 17 00:00:00 2001 From: jakevc Date: Tue, 28 Jan 2025 21:01:44 -0800 Subject: [PATCH 27/38] refactor --- .../__init__.py | 108 ++++++------------ .../batch_client.py | 37 +++--- .../batch_descriptor.py | 71 ------------ .../batch_job_builder.py | 3 +- 4 files changed, 57 insertions(+), 162 deletions(-) delete mode 100644 snakemake_executor_plugin_aws_batch/batch_descriptor.py diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index f28b871..0e1f796 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -8,7 +8,6 @@ from typing import List, AsyncGenerator, Optional from snakemake_executor_plugin_aws_batch.batch_client import BatchClient from snakemake_executor_plugin_aws_batch.batch_job_builder import BatchJobBuilder -from snakemake_executor_plugin_aws_batch.batch_descriptor import BatchJobDescriber from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor from snakemake_interface_executor_plugins.settings import ( @@ -120,7 +119,6 @@ def __post_init__(self): # keep track of job definitions self.created_job_defs = list() - self._describer = BatchJobDescriber() self.batch_client = BatchClient(region_name=self.settings.region) def run_job(self, job: JobExecutorInterface): @@ -197,75 +195,40 @@ async def check_active_jobs( yield job def _get_job_status(self, job: SubmittedJobInfo) -> tuple[int, Optional[str]]: - """poll for Batch job success or failure - - returns exits code and failure information if exit code is not 0 """ - known_status_list = [ - "SUBMITTED", - "PENDING", - "RUNNABLE", - "STARTING", - "RUNNING", - "SUCCEEDED", - "FAILED", - ] - exit_code = None - log_stream_name = None - job_desc = self._describer.describe(self.batch_client, job.external_jobid, 1) - job_status = job_desc["status"] - - # set log stream name if not none - log_details = {"status": job_status, "jobId": job.external_jobid} - - if "container" in job_desc and "logStreamName" in job_desc["container"]: - log_stream_name = job_desc["container"]["logStreamName"] - - if log_stream_name: - log_details["logStreamName"] = log_stream_name - - if job_status not in known_status_list: - self.logger.info(f"unknown job status 
{job_status} from AWS Batch") - self.logger.debug("log details: {log_details} with status: {job_status}") - - failure_info = None - if job_status == "SUCCEEDED": - return 0, failure_info - - elif job_status == "FAILED": - reason = job_desc.get("container", {}).get("reason", None) - status_reason = job_desc.get("statusReason", None) - failure_info = {"jobId": job.external_jobid} - - if reason: - failure_info["reason"] = reason - - if status_reason: - failure_info["statusReason"] = status_reason + Poll for Batch job status and return exit code and message if job is complete. - if log_stream_name: - failure_info["logStreamName"] = log_stream_name - - if ( - status_reason - and "Host EC2" in status_reason - and "terminated" in status_reason - ): - raise WorkflowError( - "AWS Batch job interrupted (likely spot instance termination)" - f"with error {failure_info}" - ) - - if "exitCode" not in job_desc.get("container", {}): - raise WorkflowError( - f"AWS Batch job failed with error {failure_info['statusReason']}. " - f"View log stream {failure_info['logStreamName']}", - ) - - exit_code = job_desc["container"]["exitCode"] - assert isinstance(exit_code, int) and exit_code != 0 - - return exit_code, failure_info + Returns: + tuple: (exit_code, failure_message) + """ + try: + response = self.batch_client.describe_jobs(jobs=[job.external_jobid]) + jobs = response.get("jobs", []) + + if not jobs: + return None, f"No job found with ID {job.external_jobid}" + + job_info: dict = jobs[0] + job_status = job_info.get("status", "UNKNOWN") + job.aux["job_definition_arn"] = job_info.get("jobDefinition", None) + exit_code = job_info.get("container", {}).get("exitCode", None) + + if job_status == "SUCCEEDED": + return 0, None + elif job_status == "FAILED": + reason = job_info.get("statusReason", "Unknown reason") + return exit_code or 1, reason + else: + log_info = { + "job_name": job_info.get("jobName", "unknown"), + "job_id": job.external_jobid, + "status": job_status, + } + self.logger.debug(log_info) + return None, None + except Exception as e: + self.logger.error(f"Error getting job status: {e}") + return None, str(e) def _terminate_job(self, job: SubmittedJobInfo): """terminate job from submitted job info""" @@ -284,9 +247,10 @@ def _terminate_job(self, job: SubmittedJobInfo): def _deregister_job(self, job: SubmittedJobInfo): """deregister batch job definition""" try: - job_def_arn = job.aux["jobDefArn"] - self.logger.debug(f"de-registering Batch job definition {job_def_arn}") - self.batch_client.deregister_job_definition(jobDefinition=job_def_arn) + job_def_arn = job.aux.get("job_definition_arn") + if job_def_arn is not None: + self.logger.debug(f"de-registering Batch job definition {job_def_arn}") + self.batch_client.deregister_job_definition(jobDefinition=job_def_arn) except Exception as e: # AWS expires job definitions after 6mo # so failing to delete them isn't fatal diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py index 6644122..335caf6 100644 --- a/snakemake_executor_plugin_aws_batch/batch_client.py +++ b/snakemake_executor_plugin_aws_batch/batch_client.py @@ -1,40 +1,41 @@ import boto3 +BATCH_SERVICE_NAME = "batch" -class AWSClient: - def __init__(self, service_name, region_name=None): + +class BatchClient: + def __init__(self, region_name=None): """ Initialize an AWS client for a specific service using default credentials. 
- :param service_name: The name of the AWS service (e.g., 's3', 'ec2', 'dynamodb') :param region_name: The region name to use for the client (optional). """ - self.service_name = service_name + self.service_name = BATCH_SERVICE_NAME self.region_name = region_name - self.client = self.initialize_client() + self.client = self.initialize_batch_client() - def initialize_client(self): + def initialize_batch_client(self): """ - Create an AWS client using boto3 with the default credentials. + Create an AWS Batch Client using boto3 with the default credentials. :return: The boto3 client for the specified service. """ - if self.region_name: - client = boto3.client(self.service_name, region_name=self.region_name) - else: - client = boto3.client(self.service_name) - return client + try: + if self.region_name: + return boto3.client(self.service_name, region_name=self.region_name) + return boto3.client(self.service_name) + except Exception as e: + raise Exception(f"Failed to initialize {self.service_name} client: {e}") -# client class stub for AWS Batch -class BatchClient(AWSClient): - def __init__(self, region_name=None): + def register_job_definition(self, **kwargs): """ - Initialize an AWS Batch client using default credentials. + Register a job definition in AWS Batch. - :param region_name: The region name to use for the client (optional). + :param kwargs: The keyword arguments to pass to the register_job_definition method. + :return: The response from the register_job_definition method. """ - super().__init__("batch", region_name) + return self.client.register_job_definition(**kwargs) def submit_job(self, **kwargs): """ diff --git a/snakemake_executor_plugin_aws_batch/batch_descriptor.py b/snakemake_executor_plugin_aws_batch/batch_descriptor.py deleted file mode 100644 index f42080d..0000000 --- a/snakemake_executor_plugin_aws_batch/batch_descriptor.py +++ /dev/null @@ -1,71 +0,0 @@ -import threading -import heapq -import time - - -class BatchJobDescriber: - """ - This singleton class handles calling the AWS Batch DescribeJobs API with up to 100 - job IDs per request, then dispensing each job description to the thread interested - in it. This helps avoid AWS API request rate limits when tracking concurrent jobs. 
- """ - - JOBS_PER_REQUEST = 100 # maximum jobs per DescribeJob request - - def __init__(self): - self.lock = threading.Lock() - self.last_request_time = 0 - self.job_queue = [] - self.jobs = {} - - def describe(self, aws, job_id, period): - """get the latest Batch job description""" - while True: - with self.lock: - if job_id not in self.jobs: - # register new job to be described ASAP - heapq.heappush(self.job_queue, (0.0, job_id)) - self.jobs[job_id] = None - # update as many job descriptions as possible - self._update(aws, period) - # return the desired job description if we have it - desc = self.jobs[job_id] - if desc: - return desc - # otherwise wait (outside the lock) and try again - time.sleep(period / 4) - - def unsubscribe(self, job_id): - """unsubscribe from job_id when no longer interested""" - with self.lock: - if job_id in self.jobs: - del self.jobs[job_id] - - def _update(self, aws, period): - # if enough time has passed since our last DescribeJobs request - if time.time() - self.last_request_time >= period: - # take the N least-recently described jobs - job_ids = set() - assert self.job_queue - while self.job_queue and len(job_ids) < self.JOBS_PER_REQUEST: - job_id = heapq.heappop(self.job_queue)[1] - assert job_id not in job_ids - if job_id in self.jobs: - job_ids.add(job_id) - if not job_ids: - return - # describe them - try: - job_descs = aws.describe_jobs(jobs=list(job_ids)) - finally: - # always: bump last_request_time and re-enqueue these jobs - self.last_request_time = time.time() - for job_id in job_ids: - heapq.heappush(self.job_queue, (self.last_request_time, job_id)) - # update self.jobs with the new descriptions - for job_desc in job_descs["jobs"]: - job_ids.remove(job_desc["jobId"]) - self.jobs[job_desc["jobId"]] = job_desc - assert ( - not job_ids - ), "AWS Batch DescribeJobs didn't return all expected results" diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 0baf48f..21b3d29 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -8,6 +8,7 @@ from enum import Enum from snakemake_executor_plugin_aws_batch.batch_client import BatchClient + SNAKEMAKE_COMMAND = "snakemake" @@ -84,7 +85,7 @@ def build_job_definition(self): timeout = {"attemptDurationSeconds": self.settings.task_timeout} tags = self.settings.tags if isinstance(self.settings.tags, dict) else dict() try: - job_def = self.batch_client.client.register_job_definition( + job_def = self.batch_client.register_job_definition( jobDefinitionName=job_definition_name, type=BATCH_JOB_DEFINITION_TYPE.CONTAINER.value, containerProperties=container_properties, From 0cc969e76e7072c2721f0dca8894e2ad2776910d Mon Sep 17 00:00:00 2001 From: jakevc Date: Tue, 28 Jan 2025 21:20:38 -0800 Subject: [PATCH 28/38] tweak --- .../__init__.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 0e1f796..80a3db2 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -64,7 +64,7 @@ class ExecutorSettings(ExecutorSettingsBase): }, ) task_timeout: Optional[int] = field( - default=60, + default=300, metadata={ "help": ( "Task timeout (seconds) will force AWS Batch to terminate " @@ -110,15 +110,11 @@ def __post_init__(self): # snakemake/snakemake:latest container image 
         self.container_image = self.workflow.remote_execution_settings.container_image
 
-        # set the rate limit for status checks
         self.next_seconds_between_status_checks = 5
 
-        # access executor specific settings
         self.settings = self.workflow.executor_settings
         self.logger.debug(f"ExecutorSettings: {pformat(self.settings, indent=2)}")
 
-        # keep track of job definitions
-        self.created_job_defs = list()
         self.batch_client = BatchClient(region_name=self.settings.region)
 
     def run_job(self, job: JobExecutorInterface):
@@ -144,11 +140,12 @@ def run_job(self, job: JobExecutorInterface):
                 batch_client=self.batch_client,
             )
             job_info = job_definition.submit()
-            self.logger.debug(
-                "AWS Batch job submitted with queue {}, jobId {} and tags {}".format(
-                    self.settings.job_queue, job_info["jobId"], self.settings.tags
-                )
-            )
+            log_info = {
+                "job_name": job_info["jobName"],
+                "jobId": job_info["jobId"],
+                "job_queue": self.settings.job_queue,
+            }
+            self.logger.debug(f"AWS Batch job submitted: {log_info}")
         except Exception as e:
             raise WorkflowError(e)
 
@@ -210,6 +207,8 @@ def _get_job_status(self, job: SubmittedJobInfo) -> tuple[int, Optional[str]]:
 
             job_info: dict = jobs[0]
             job_status = job_info.get("status", "UNKNOWN")
+
+            # push the job_definition_arn to the aux dict for use in cleanup
             job.aux["job_definition_arn"] = job_info.get("jobDefinition", None)
             exit_code = job_info.get("container", {}).get("exitCode", None)

From ac92dfd695145fe180b88e7bdf74cf8f11f7393e Mon Sep 17 00:00:00 2001
From: jakevc
Date: Tue, 28 Jan 2025 21:26:16 -0800
Subject: [PATCH 29/38] log

---
 snakemake_executor_plugin_aws_batch/batch_client.py      | 2 +-
 snakemake_executor_plugin_aws_batch/batch_job_builder.py | 5 -----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py
index 335caf6..40a8e5d 100644
--- a/snakemake_executor_plugin_aws_batch/batch_client.py
+++ b/snakemake_executor_plugin_aws_batch/batch_client.py
@@ -32,7 +32,7 @@ def register_job_definition(self, **kwargs):
         """
         Register a job definition in AWS Batch.
 
-        :param kwargs: The keyword arguments to pass to the register_job_definition method.
+        :param kwargs: The keyword arguments to pass to register_job_definition method.
""" return self.client.register_job_definition(**kwargs) diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py index 21b3d29..b1dd735 100644 --- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py +++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py @@ -111,11 +111,6 @@ def submit(self): try: submitted = self.batch_client.submit_job(**job_params) - self.logger.debug( - "AWS Batch job submitted with queue {}, jobId {} and tags {}".format( - self.settings.job_queue, submitted["jobId"], self.settings.tags - ) - ) return submitted except Exception as e: raise WorkflowError(f"Failed to submit job: {e}") from e From 48c9f36571e710680115ca2b625397c2f0cc9d25 Mon Sep 17 00:00:00 2001 From: jakevc Date: Wed, 29 Jan 2025 10:00:19 -0800 Subject: [PATCH 30/38] mock success --- tests/tests_mocked_api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/tests_mocked_api.py b/tests/tests_mocked_api.py index 408d9b7..ced657e 100644 --- a/tests/tests_mocked_api.py +++ b/tests/tests_mocked_api.py @@ -1,12 +1,14 @@ -from unittest.mock import AsyncMock, MagicMock, patch # noqa +from unittest.mock import AsyncMock, patch # noqa from tests import TestWorkflowsBase class TestWorkflowsMocked(TestWorkflowsBase): __test__ = True - @patch("boto3.client") - # TODO: patch run_job internals + @patch( + "snakemake_executor_plugin_aws_batch.Executor._get_job_status", + return_value=(0, "SUCCEEDED"), + ) @patch( "snakemake.dag.DAG.check_and_touch_output", new=AsyncMock(autospec=True), From f59fe23e620c28e9f716483aeb7481071db09bc2 Mon Sep 17 00:00:00 2001 From: jakevc Date: Wed, 29 Jan 2025 10:04:43 -0800 Subject: [PATCH 31/38] error hangling --- snakemake_executor_plugin_aws_batch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 80a3db2..466912a 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -147,7 +147,7 @@ def run_job(self, job: JobExecutorInterface): } self.logger.debug(f"AWS Batch job submitted: {log_info}") except Exception as e: - raise WorkflowError(e) + raise WorkflowError(f"Failed to submit AWS Batch job: {e}") from e self.report_job_submission( SubmittedJobInfo( From 8e52d3be34cef5164bff3cd09aced484239fcca5 Mon Sep 17 00:00:00 2001 From: jakevc Date: Wed, 29 Jan 2025 10:15:10 -0800 Subject: [PATCH 32/38] env --- .github/workflows/ci_mocked_api.yml | 4 ++++ .github/workflows/ci_true_api.yml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/ci_mocked_api.yml b/.github/workflows/ci_mocked_api.yml index bcc3f55..84dac2b 100644 --- a/.github/workflows/ci_mocked_api.yml +++ b/.github/workflows/ci_mocked_api.yml @@ -88,6 +88,10 @@ jobs: run: poetry install - name: Run pytest + env: + SNAKEMAKE_AWS_BATCH_REGION: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" + SNAKEMAKE_AWS_BATCH_JOB_QUEUE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_QUEUE }}" + SNAKEMAKE_AWS_BATCH_JOB_ROLE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_ROLE }}" run: poetry run coverage run -m pytest tests/tests_mocked_api.py -v - name: Run Coverage diff --git a/.github/workflows/ci_true_api.yml b/.github/workflows/ci_true_api.yml index 3898b5f..5576de6 100644 --- a/.github/workflows/ci_true_api.yml +++ b/.github/workflows/ci_true_api.yml @@ -31,5 +31,9 @@ jobs: run: poetry install - name: Run pytest + env: + 
SNAKEMAKE_AWS_BATCH_REGION: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" + SNAKEMAKE_AWS_BATCH_JOB_QUEUE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_QUEUE }}" + SNAKEMAKE_AWS_BATCH_JOB_ROLE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_ROLE }}" run: poetry run pytest tests/tests_true_api.py -v \ No newline at end of file From 460e4221aaf2ceec78a349a82eb624e11175be84 Mon Sep 17 00:00:00 2001 From: Jake VanCampen Date: Wed, 29 Jan 2025 10:20:24 -0800 Subject: [PATCH 33/38] Update snakemake_executor_plugin_aws_batch/__init__.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- snakemake_executor_plugin_aws_batch/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/snakemake_executor_plugin_aws_batch/__init__.py b/snakemake_executor_plugin_aws_batch/__init__.py index 466912a..a44c3c2 100644 --- a/snakemake_executor_plugin_aws_batch/__init__.py +++ b/snakemake_executor_plugin_aws_batch/__init__.py @@ -115,7 +115,10 @@ def __post_init__(self): self.settings = self.workflow.executor_settings self.logger.debug(f"ExecutorSettings: {pformat(self.settings, indent=2)}") - self.batch_client = BatchClient(region_name=self.settings.region) + try: + self.batch_client = BatchClient(region_name=self.settings.region) + except Exception as e: + raise WorkflowError(f"Failed to initialize AWS Batch client: {e}") from e def run_job(self, job: JobExecutorInterface): # Implement here how to run a job. From 9756fc79d7ed74c3eb6386a05b419bf41f5a5b75 Mon Sep 17 00:00:00 2001 From: Jake VanCampen Date: Wed, 29 Jan 2025 10:20:36 -0800 Subject: [PATCH 34/38] Update .github/workflows/ci_mocked_api.yml Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .github/workflows/ci_mocked_api.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci_mocked_api.yml b/.github/workflows/ci_mocked_api.yml index 84dac2b..6889928 100644 --- a/.github/workflows/ci_mocked_api.yml +++ b/.github/workflows/ci_mocked_api.yml @@ -89,9 +89,9 @@ jobs: - name: Run pytest env: - SNAKEMAKE_AWS_BATCH_REGION: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" - SNAKEMAKE_AWS_BATCH_JOB_QUEUE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_QUEUE }}" - SNAKEMAKE_AWS_BATCH_JOB_ROLE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_ROLE }}" + SNAKEMAKE_AWS_BATCH_REGION: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" + SNAKEMAKE_AWS_BATCH_JOB_QUEUE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_QUEUE }}" + SNAKEMAKE_AWS_BATCH_JOB_ROLE: "${{ secrets.SNAKEMAKE_AWS_BATCH_JOB_ROLE }}" run: poetry run coverage run -m pytest tests/tests_mocked_api.py -v - name: Run Coverage From aa1a1f010e82dff30774071adb48021d9254de28 Mon Sep 17 00:00:00 2001 From: jakevc Date: Wed, 29 Jan 2025 10:42:52 -0800 Subject: [PATCH 35/38] OIDC --- .github/workflows/ci_mocked_api.yml | 11 +++++++++++ .github/workflows/ci_true_api.yml | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/.github/workflows/ci_mocked_api.yml b/.github/workflows/ci_mocked_api.yml index 6889928..626f463 100644 --- a/.github/workflows/ci_mocked_api.yml +++ b/.github/workflows/ci_mocked_api.yml @@ -9,6 +9,10 @@ on: env: PYTHON_VERSION: 3.11 +permissions: + id-token: write + contents: read + jobs: formatting: runs-on: ubuntu-latest @@ -87,6 +91,13 @@ jobs: - name: Install dependencies run: poetry install + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" + role-to-assume: "{{ secrets.GH_AWS_ROLE_ARN }}" + 
role-session-name: "GitHubActions" + - name: Run pytest env: SNAKEMAKE_AWS_BATCH_REGION: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" diff --git a/.github/workflows/ci_true_api.yml b/.github/workflows/ci_true_api.yml index 5576de6..a1bf379 100644 --- a/.github/workflows/ci_true_api.yml +++ b/.github/workflows/ci_true_api.yml @@ -6,6 +6,10 @@ on: env: PYTHON_VERSION: "3.11" +permissions: + id-token: write + contents: read + jobs: testing-true-api: runs-on: ubuntu-latest @@ -30,6 +34,13 @@ jobs: - name: Install Dependencies using Poetry run: poetry install + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" + role-to-assume: "{{ secrets.GH_AWS_ROLE_ARN }}" + role-session-name: "GitHubActions" + - name: Run pytest env: SNAKEMAKE_AWS_BATCH_REGION: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" From 93fab0b09a86c357431ef412725cbdc43e987588 Mon Sep 17 00:00:00 2001 From: jakevc Date: Wed, 29 Jan 2025 10:44:21 -0800 Subject: [PATCH 36/38] black --- tests/tests_mocked_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_mocked_api.py b/tests/tests_mocked_api.py index ced657e..00c9894 100644 --- a/tests/tests_mocked_api.py +++ b/tests/tests_mocked_api.py @@ -6,8 +6,8 @@ class TestWorkflowsMocked(TestWorkflowsBase): __test__ = True @patch( - "snakemake_executor_plugin_aws_batch.Executor._get_job_status", - return_value=(0, "SUCCEEDED"), + "snakemake_executor_plugin_aws_batch.Executor._get_job_status", + return_value=(0, "SUCCEEDED"), ) @patch( "snakemake.dag.DAG.check_and_touch_output", From 5da9542ce64331dfeb186d05443c23663275b851 Mon Sep 17 00:00:00 2001 From: jakevc Date: Wed, 29 Jan 2025 10:46:12 -0800 Subject: [PATCH 37/38] secret --- .github/workflows/ci_mocked_api.yml | 2 +- .github/workflows/ci_true_api.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_mocked_api.yml b/.github/workflows/ci_mocked_api.yml index 626f463..e841683 100644 --- a/.github/workflows/ci_mocked_api.yml +++ b/.github/workflows/ci_mocked_api.yml @@ -95,7 +95,7 @@ jobs: uses: aws-actions/configure-aws-credentials@v4 with: aws-region: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" - role-to-assume: "{{ secrets.GH_AWS_ROLE_ARN }}" + role-to-assume: "${{ secrets.GH_AWS_ROLE_ARN }}" role-session-name: "GitHubActions" - name: Run pytest diff --git a/.github/workflows/ci_true_api.yml b/.github/workflows/ci_true_api.yml index a1bf379..8a64bbf 100644 --- a/.github/workflows/ci_true_api.yml +++ b/.github/workflows/ci_true_api.yml @@ -38,7 +38,7 @@ jobs: uses: aws-actions/configure-aws-credentials@v4 with: aws-region: "${{ secrets.SNAKEMAKE_AWS_BATCH_REGION }}" - role-to-assume: "{{ secrets.GH_AWS_ROLE_ARN }}" + role-to-assume: "${{ secrets.GH_AWS_ROLE_ARN }}" role-session-name: "GitHubActions" - name: Run pytest From 0ddf9bf0b1fcd1620adcbacbd0eefa0bae91a319 Mon Sep 17 00:00:00 2001 From: jakevc Date: Sun, 2 Feb 2025 20:33:43 -0800 Subject: [PATCH 38/38] MinioLocal --- .github/workflows/ci_mocked_api.yml | 16 ++++++++ tests/__init__.py | 62 ++++++++++++++++++++++++++++- tests/docker-compose.yml | 11 +++++ 3 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 tests/docker-compose.yml diff --git a/.github/workflows/ci_mocked_api.yml b/.github/workflows/ci_mocked_api.yml index e841683..e4bcada 100644 --- a/.github/workflows/ci_mocked_api.yml +++ b/.github/workflows/ci_mocked_api.yml @@ -88,6 +88,22 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} 
cache: poetry + - name: Set up Docker Compose + run: | + sudo apt-get update + sudo apt-get install docker-compose + - name: Start MinIO service + run: docker-compose -f tests/docker-compose.yml up -d + + - name: Install MinIO Client CLI + run: | + curl -O https://dl.min.io/client/mc/release/linux-amd64/mc + chmod +x mc + sudo mv mc /usr/local/bin/ + - name: Configure MinIO client + run: | + mc alias set minio http://localhost:9000 minio minio123 + - name: Install dependencies run: poetry install diff --git a/tests/__init__.py b/tests/__init__.py index aa901ce..fac3d40 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,12 +1,70 @@ import os -from typing import Optional +import uuid +from typing import Optional, Mapping import snakemake.common.tests from snakemake_interface_executor_plugins.settings import ExecutorSettingsBase from snakemake_executor_plugin_aws_batch import ExecutorSettings +from snakemake.common.tests import TestWorkflowsBase as TestWorkflowsBase +from snakemake_interface_common.plugin_registry.plugin import TaggedSettings +from snakemake_interface_common.utils import lazy_property -class TestWorkflowsBase(snakemake.common.tests.TestWorkflowsMinioPlayStorageBase): +class TestWorkflowsMinioLocalStorageBase(TestWorkflowsBase): + def get_default_storage_provider(self) -> Optional[str]: + return "s3" + + def get_default_storage_prefix(self) -> Optional[str]: + return f"s3://{self.bucket}" + + def get_default_storage_provider_settings( + self, + ) -> Optional[Mapping[str, TaggedSettings]]: + from snakemake_storage_plugin_s3 import StorageProviderSettings + + self._storage_provider_settings = StorageProviderSettings( + endpoint_url=self.endpoint_url, + access_key=self.access_key, + secret_key=self.secret_key, + ) + + tagged_settings = TaggedSettings() + tagged_settings.register_settings(self._storage_provider_settings) + return {"s3": tagged_settings} + + def cleanup_test(self): + import boto3 + + # clean up using boto3 + s3c = boto3.resource( + "s3", + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + ) + try: + s3c.Bucket(self.bucket).delete() + except Exception: + pass + + @lazy_property + def bucket(self): + return f"snakemake-{uuid.uuid4().hex}" + + @property + def endpoint_url(self): + return "http://127.0.0.1:9000" + + @property + def access_key(self): + return "minio" + + @property + def secret_key(self): + return "minio123" + + +class TestWorkflowsBase(TestWorkflowsMinioLocalStorageBase): def get_executor(self) -> str: return "aws-batch" diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 0000000..dbb3db4 --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,11 @@ +version: '3.8' + +services: + minio: + image: minio/minio + ports: + - "9000:9000" + environment: + MINIO_ACCESS_KEY: minio + MINIO_SECRET_KEY: minio123 + command: server /data \ No newline at end of file
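
To reproduce the mocked-API CI job locally, the MinIO service above can be started before invoking the test suite. A minimal sketch, assuming Docker Compose and Poetry are installed and using the endpoint and credentials hard-coded in tests/__init__.py and tests/docker-compose.yml:

```
docker-compose -f tests/docker-compose.yml up -d
poetry run coverage run -m pytest tests/tests_mocked_api.py -v
```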