diff --git a/changelog.d/20230503_101434_yadudoc1729_unfork_htex_2.rst b/changelog.d/20230503_101434_yadudoc1729_unfork_htex_2.rst new file mode 100644 index 000000000..c9632b443 --- /dev/null +++ b/changelog.d/20230503_101434_yadudoc1729_unfork_htex_2.rst @@ -0,0 +1,43 @@ +.. A new scriv changelog fragment. +.. +.. Uncomment the header that is right (remove the leading dots). +.. +New Functionality +^^^^^^^^^^^^^^^^^ + +- Support for 3 new execution ``Engines``, designed to replace the ``HighThroughputExecutor`` + + - ``GlobusComputeEngine``: Wraps Parsl's ``HighThroughputExecutor`` to match the current + default executor (globus-compute's fork of ``HighThroughputExecutor``) + - ``ProcessPoolEngine``: Wraps ``concurrent.futures.ProcessPoolExecutor`` for concurrent + local execution + - ``ThreadPoolEngine``: Wraps ``concurrent.futures.ThreadPoolExecutor`` for concurrent + local execution on macOS. + +.. - A bullet item for the New Functionality category. +.. +.. Bug Fixes +.. ^^^^^^^^^ +.. +.. - A bullet item for the Bug Fixes category. +.. +.. Removed +.. ^^^^^^^ +.. +.. - A bullet item for the Removed category. +.. +.. Deprecated +.. ^^^^^^^^^^ +.. +.. - A bullet item for the Deprecated category. +.. +.. Changed +.. ^^^^^^^ +.. +.. - A bullet item for the Changed category. +.. +.. Security +.. ^^^^^^^^ +.. +.. - A bullet item for the Security category. +.. 
diff --git a/compute_endpoint/globus_compute_endpoint/engines/__init__.py b/compute_endpoint/globus_compute_endpoint/engines/__init__.py new file mode 100644 index 000000000..7e4d748c6 --- /dev/null +++ b/compute_endpoint/globus_compute_endpoint/engines/__init__.py @@ -0,0 +1,9 @@ +from globus_compute_endpoint.engines.globus_compute import GlobusComputeEngine +from globus_compute_endpoint.engines.process_pool import ProcessPoolEngine +from globus_compute_endpoint.engines.thread_pool import ThreadPoolEngine + +__all__ = [ + "GlobusComputeEngine", + "ProcessPoolEngine", + "ThreadPoolEngine", +] diff --git a/compute_endpoint/globus_compute_endpoint/engines/base.py b/compute_endpoint/globus_compute_endpoint/engines/base.py new file mode 100644 index 000000000..92106693d --- /dev/null +++ b/compute_endpoint/globus_compute_endpoint/engines/base.py @@ -0,0 +1,176 @@ +import logging +import queue +import threading +import time +import typing as t +import uuid +from abc import ABC, abstractmethod +from concurrent.futures import Future + +from globus_compute_common import messagepack +from globus_compute_common.messagepack.message_types import ( + EPStatusReport, + Result, + TaskTransition, +) +from globus_compute_common.tasks import ActorName, TaskState +from globus_compute_endpoint.engines.helper import execute_task +from globus_compute_endpoint.exception_handling import ( + get_error_string, + get_result_error_details, +) + +logger = logging.getLogger(__name__) + + +class ReportingThread: + def __init__( + self, target: t.Callable, args: t.List, reporting_period: float = 30.0 + ): + """This class wraps threading.Thread to run a callable in a loop + periodically until the user calls `stop`. A status attribute can + report exceptions to the parent thread upon failure. 
+ Parameters + ---------- + target: Target function to be invoked to get report and post to queue + args: args to be passed to target fn + kwargs: kwargs to be passed to target fn + reporting_period + """ + self.status: Future = Future() + self._shutdown_event = threading.Event() + self.reporting_period = reporting_period + self._thread = threading.Thread( + target=self.run_in_loop, args=[target] + args, name="GCReportingThread" + ) + + def start(self): + logger.info("Start called") + self._thread.start() + + def run_in_loop(self, target: t.Callable, *args) -> None: + while True: + try: + target(*args) + except Exception as e: + # log and update future before exiting, if it is not already set + self.status.set_exception(exception=e) + self._shutdown_event.set() + if self._shutdown_event.wait(timeout=self.reporting_period): + break + + logger.warning("ReportingThread exiting") + + def stop(self) -> None: + self._shutdown_event.set() + self._thread.join(timeout=0.1) + + +class GlobusComputeEngineBase(ABC): + """Shared functionality and interfaces required by all GlobusCompute Engines. + This is designed to plug-in executors following the concurrent.futures.Executor + interface as execution backends to GlobusCompute + """ + + def __init__( + self, + *args: object, + heartbeat_period_s: float = 30.0, + endpoint_id: t.Optional[uuid.UUID] = None, + **kwargs: object, + ): + self._shutdown_event = threading.Event() + self._heartbeat_period_s = heartbeat_period_s + self.endpoint_id = endpoint_id + + # remove these unused vars that we are adding to just keep + # endpoint interchange happy + self.container_type: t.Optional[str] = None + self.funcx_service_address: t.Optional[str] = None + self.run_dir: t.Optional[str] = None + # This attribute could be set by the subclasses in their + # start method if another component insists on owning the queue. 
+ self.results_passthrough: queue.Queue = queue.Queue() + + @abstractmethod + def start( + self, + *args, + **kwargs, + ) -> None: + raise NotImplementedError + + @abstractmethod + def get_status_report(self) -> EPStatusReport: + raise NotImplementedError + + def report_status(self): + status_report = self.get_status_report() + packed_status = messagepack.pack(status_report) + self.results_passthrough.put(packed_status) + + def _status_report( + self, shutdown_event: threading.Event, heartbeat_period_s: float + ): + while not shutdown_event.wait(timeout=heartbeat_period_s): + status_report = self.get_status_report() + packed = messagepack.pack(status_report) + self.results_passthrough.put(packed) + + def _future_done_callback(self, future: Future): + """Callback to post result to the passthrough queue + Parameters + ---------- + future: Future for which the callback is triggerd + """ + + if future.exception(): + code, user_message = get_result_error_details() + error_details = {"code": code, "user_message": user_message} + exec_end = TaskTransition( + timestamp=time.time_ns(), + state=TaskState.EXEC_END, + actor=ActorName.WORKER, + ) + result_message = dict( + task_id=future.task_id, # type: ignore + data=get_error_string(), + exception=get_error_string(), + error_details=error_details, + task_statuses=[exec_end], # We don't have any more info transitions + ) + packed_result = messagepack.pack(Result(**result_message)) + else: + packed_result = future.result() + + self.results_passthrough.put(packed_result) + + @abstractmethod + def _submit( + self, + func: t.Callable, + *args: t.Any, + **kwargs: t.Any, + ) -> Future: + """Subclass should use the internal execution system to implement this""" + raise NotImplementedError() + + def submit(self, task_id: uuid.UUID, packed_task: bytes) -> Future: + """GC Endpoints should submit tasks via this method so that tasks are + tracked properly. 
+ Parameters + ---------- + packed_task: messagepack bytes buffer + Returns + ------- + future + """ + + future: Future = self._submit(execute_task, packed_task) + + # Executors mark futures are failed in the event of faults + # We need to tie the task_id info into the future to identify + # which tasks have failed + future.task_id = task_id # type: ignore + future.add_done_callback(self._future_done_callback) + return future diff --git a/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py b/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py new file mode 100644 index 000000000..c901b82fc --- /dev/null +++ b/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py @@ -0,0 +1,113 @@ +import logging +import multiprocessing +import os +import typing as t +import uuid +from concurrent.futures import Future + +from globus_compute_common.messagepack.message_types import ( + EPStatusReport, + TaskTransition, +) +from globus_compute_endpoint.engines.base import ( + GlobusComputeEngineBase, + ReportingThread, +) +from parsl.executors.high_throughput.executor import HighThroughputExecutor + +logger = logging.getLogger(__name__) + + +class GlobusComputeEngine(GlobusComputeEngineBase): + def __init__( + self, + *args, + label: str = "GlobusComputeEngine", + address: t.Optional[str] = None, + heartbeat_period_s: float = 30.0, + **kwargs, + ): + self.address = address + self.run_dir = os.getcwd() + self.label = label + self._status_report_thread = ReportingThread( + target=self.report_status, args=[], reporting_period=heartbeat_period_s + ) + super().__init__(*args, heartbeat_period_s=heartbeat_period_s, **kwargs) + self.executor = HighThroughputExecutor( # type: ignore + *args, address=address, **kwargs + ) + + def start( + self, + *args, + endpoint_id: t.Optional[uuid.UUID] = None, + run_dir: t.Optional[str] = None, + results_passthrough: t.Optional[multiprocessing.Queue] = None, + **kwargs, + ): + assert run_dir, "GCExecutor requires 
kwarg:run_dir at start" + assert endpoint_id, "GCExecutor requires kwarg:endpoint_id at start" + self.run_dir = os.path.join(os.getcwd(), run_dir) + self.endpoint_id = endpoint_id + self.executor.provider.script_dir = os.path.join(self.run_dir, "submit_scripts") + os.makedirs(self.executor.provider.script_dir, exist_ok=True) + if results_passthrough: + # Only update the default queue in GCExecutorBase if + # a queue is passed in + self.results_passthrough = results_passthrough + self.executor.start() + self._status_report_thread.start() + + def _submit( + self, + func: t.Callable, + *args: t.Any, + **kwargs: t.Any, + ) -> Future: + return self.executor.submit(func, {}, *args, **kwargs) + + def get_status_report(self) -> EPStatusReport: + """ + endpoint_id: uuid.UUID + ep_status_report: t.Dict[str, t.Any] + task_statuses: t.Dict[str, t.List[TaskTransition]] + Returns + ------- + """ + executor_status: t.Dict[str, t.Any] = { + "task_id": -2, + "info": { + "total_cores": 0, + "total_mem": 0, + "new_core_hrs": 0, + "total_core_hrs": 0, + "managers": 0, + "active_managers": 0, + "total_workers": 0, + "idle_workers": 0, + "pending_tasks": 0, + "outstanding_tasks": 0, + "worker_mode": 0, + "scheduler_mode": 0, + "scaling_enabled": False, + "mem_per_worker": 0, + "cores_per_worker": 0, + "prefetch_capacity": 0, + "max_blocks": 1, + "min_blocks": 1, + "max_workers_per_node": 0, + "nodes_per_block": 1, + "heartbeat_period": self._heartbeat_period_s, + }, + } + task_status_deltas: t.Dict[str, t.List[TaskTransition]] = {} + return EPStatusReport( + endpoint_id=self.endpoint_id, + ep_status_report=executor_status, + task_statuses=task_status_deltas, + ) + + def shutdown(self): + self._status_report_thread.stop() + return self.executor.shutdown() diff --git a/compute_endpoint/globus_compute_endpoint/engines/helper.py b/compute_endpoint/globus_compute_endpoint/engines/helper.py new file mode 100644 index 000000000..bddf0d14c --- /dev/null +++ 
b/compute_endpoint/globus_compute_endpoint/engines/helper.py @@ -0,0 +1,128 @@ +import logging +import time +import typing as t +import uuid + +from globus_compute_common import messagepack +from globus_compute_common.messagepack.message_types import Result, Task, TaskTransition +from globus_compute_common.tasks import ActorName, TaskState +from globus_compute_endpoint.exception_handling import ( + get_error_string, + get_result_error_details, +) +from globus_compute_endpoint.exceptions import CouldNotExecuteUserTaskError +from globus_compute_endpoint.executors.high_throughput.messages import Message +from globus_compute_sdk.errors import MaxResultSizeExceeded +from globus_compute_sdk.serialize import ComputeSerializer + +log = logging.getLogger(__name__) + +serializer = ComputeSerializer() + + +def execute_task(task_body: bytes, result_size_limit: int = 10 * 1024 * 1024) -> bytes: + """Execute task is designed to enable any executor to execute a Task payload + and return a Result payload, where the payload follows the globus-compute protocols + This method is placed here to make serialization easy for executor classes + Parameters + ---------- + task_id: uuid string + task_body: packed message as bytes + result_size_limit: result size in bytes + Returns + ------- + messagepack packed Result + """ + exec_start = TaskTransition( + timestamp=time.time_ns(), state=TaskState.EXEC_START, actor=ActorName.WORKER + ) + + result_message: dict[ + str, + t.Union[uuid.UUID, str, tuple[str, str], list[TaskTransition], dict[str, str]], + ] = {} + + try: + task, task_buffer = _unpack_messagebody(task_body) + log.debug("executing task task_id='%s'", task.task_id) + result = _call_user_function(task_buffer, result_size_limit=result_size_limit) + log.debug("Execution completed without exception") + result_message = dict(task_id=task.task_id, data=result) + + except Exception: + log.exception("Caught an exception while executing user function") + code, user_message = 
get_result_error_details() + error_details = {"code": code, "user_message": user_message} + result_message = dict( + task_id=task.task_id, + data=get_error_string(), + exception=get_error_string(), + error_details=error_details, + ) + + exec_end = TaskTransition( + timestamp=time.time_ns(), + state=TaskState.EXEC_END, + actor=ActorName.WORKER, + ) + + result_message["task_statuses"] = [exec_start, exec_end] + + log.debug( + "task %s completed in %d ns", + task.task_id, + (exec_end.timestamp - exec_start.timestamp), + ) + + return messagepack.pack(Result(**result_message)) + + +def _unpack_messagebody(message: bytes) -> t.Tuple[Task, str]: + """Unpack messagebody as a messagepack message with + some legacy handling + Parameters + ---------- + message: messagepack'ed message body + Returns + ------- + tuple(task, task_buffer) + """ + try: + task = messagepack.unpack(message) + if not isinstance(task, messagepack.message_types.Task): + raise CouldNotExecuteUserTaskError( + f"wrong type of message in worker: {type(task)}" + ) + task_buffer = task.task_buffer + # on parse errors, failover to trying the "legacy" message reading + except ( + messagepack.InvalidMessageError, + messagepack.UnrecognizedProtocolVersion, + ): + task = Message.unpack(message) + assert isinstance(task, Task) + task_buffer = task.task_buffer.decode("utf-8") # type: ignore[attr-defined] + return task, task_buffer + + +def _call_user_function( + task_buffer: str, result_size_limit: int, serializer=serializer +) -> str: + """Deserialize the buffer and execute the task. + Parameters + ---------- + task_buffer: serialized buffer of (fn, args, kwargs) + result_size_limit: size limit in bytes for results + serializer: serializer for the buffers + Returns + ------- + Returns serialized result or throws exception. 
+ """ + f, args, kwargs = serializer.unpack_and_deserialize(task_buffer) + result_data = f(*args, **kwargs) + serialized_data = serializer.serialize(result_data) + + if len(serialized_data) > result_size_limit: + raise MaxResultSizeExceeded(len(serialized_data), result_size_limit) + + return serialized_data diff --git a/compute_endpoint/globus_compute_endpoint/engines/process_pool.py b/compute_endpoint/globus_compute_endpoint/engines/process_pool.py new file mode 100644 index 000000000..d35af81a8 --- /dev/null +++ b/compute_endpoint/globus_compute_endpoint/engines/process_pool.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +import multiprocessing +import typing as t +import uuid +from concurrent.futures import Future +from concurrent.futures import ProcessPoolExecutor as NativeExecutor +from multiprocessing.queues import Queue as mpQueue + +import psutil +from globus_compute_common.messagepack.message_types import ( + EPStatusReport, + TaskTransition, +) +from globus_compute_endpoint.engines.base import ( + GlobusComputeEngineBase, + ReportingThread, +) + +logger = logging.getLogger(__name__) + + +class ProcessPoolEngine(GlobusComputeEngineBase): + def __init__( + self, + *args, + label: str = "ProcessPoolEngine", + heartbeat_period_s: float = 30.0, + **kwargs, + ): + self.label = label + self.executor = NativeExecutor(*args, **kwargs) + self._status_report_thread = ReportingThread( + target=self.report_status, args=[], reporting_period=heartbeat_period_s + ) + super().__init__(*args, heartbeat_period_s=heartbeat_period_s, **kwargs) + + def start( + self, + *args, + endpoint_id: t.Optional[uuid.UUID] = None, + results_passthrough: t.Optional[mpQueue] = None, + **kwargs, + ) -> None: + """ + Parameters + ---------- + endpoint_id: Endpoint UUID + results_passthrough: Queue to which packed results will be posted + run_dir Not used + Returns + ------- + """ + assert endpoint_id, "ProcessPoolExecutor requires kwarg:endpoint_id at start" + 
self.endpoint_id = endpoint_id + if results_passthrough: + self.results_passthrough = results_passthrough + assert self.results_passthrough + + # mypy think the thread can be none + self._status_report_thread.start() + + def get_status_report(self) -> EPStatusReport: + """ + endpoint_id: uuid.UUID + ep_status_report: t.Dict[str, t.Any] + task_statuses: t.Dict[str, t.List[TaskTransition]] + Returns + ------- + """ + executor_status: t.Dict[str, t.Any] = { + "task_id": -2, + "info": { + "total_cores": multiprocessing.cpu_count(), + "total_mem": round(psutil.virtual_memory().available / (2**30), 1), + "total_core_hrs": 0, + "total_workers": self.executor._max_workers, # type: ignore + "pending_tasks": 0, + "outstanding_tasks": 0, + "scaling_enabled": False, + "max_blocks": 1, + "min_blocks": 1, + "max_workers_per_node": self.executor._max_workers, # type: ignore + "nodes_per_block": 1, + "heartbeat_period": self._heartbeat_period_s, + }, + } + task_status_deltas: t.Dict[str, t.List[TaskTransition]] = {} + + return EPStatusReport( + endpoint_id=self.endpoint_id, + ep_status_report=executor_status, + task_statuses=task_status_deltas, + ) + + def _submit( + self, + func: t.Callable, + *args: t.Any, + **kwargs: t.Any, + ) -> Future: + """We basically pass all params except the resource_specification + over to executor.submit + """ + logger.warning("Got task") + return self.executor.submit(func, *args, **kwargs) + + def status_polling_interval(self) -> int: + return 30 + + def scale_out(self, blocks: int) -> list[str]: + return [] + + def scale_in(self, blocks: int) -> list[str]: + return [] + + def status(self) -> dict: + return {} + + def shutdown(self): + self._status_report_thread.stop() + self.executor.shutdown() diff --git a/compute_endpoint/globus_compute_endpoint/engines/thread_pool.py b/compute_endpoint/globus_compute_endpoint/engines/thread_pool.py new file mode 100644 index 000000000..d3dab431f --- /dev/null +++ 
b/compute_endpoint/globus_compute_endpoint/engines/thread_pool.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +import multiprocessing +import typing as t +import uuid +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor as NativeExecutor +from multiprocessing.queues import Queue as mpQueue + +import psutil +from globus_compute_common.messagepack.message_types import ( + EPStatusReport, + TaskTransition, +) +from globus_compute_endpoint.engines.base import ( + GlobusComputeEngineBase, + ReportingThread, +) + +logger = logging.getLogger(__name__) + + +class ThreadPoolEngine(GlobusComputeEngineBase): + def __init__( + self, + *args, + label: str = "ThreadPoolEngine", + heartbeat_period_s: float = 30.0, + **kwargs, + ): + self.label = label + self.executor = NativeExecutor(*args, **kwargs) + self._status_report_thread = ReportingThread( + target=self.report_status, args=[], reporting_period=heartbeat_period_s + ) + super().__init__(*args, heartbeat_period_s=heartbeat_period_s, **kwargs) + + def start( + self, + *args, + endpoint_id: t.Optional[uuid.UUID] = None, + results_passthrough: t.Optional[mpQueue] = None, + **kwargs, + ) -> None: + """ + Parameters + ---------- + endpoint_id: Endpoint UUID + results_passthrough: Queue to which packed results will be posted + run_dir Not used + Returns + ------- + """ + assert endpoint_id, "ThreadPoolEngine requires kwarg:endpoint_id at start" + self.endpoint_id = endpoint_id + if results_passthrough: + self.results_passthrough = results_passthrough + assert self.results_passthrough + + # mypy think the thread can be none + self._status_report_thread.start() + + def get_status_report(self) -> EPStatusReport: + """ + endpoint_id: uuid.UUID + ep_status_report: t.Dict[str, t.Any] + task_statuses: t.Dict[str, t.List[TaskTransition]] + Returns + ------- + """ + executor_status: t.Dict[str, t.Any] = { + "task_id": -2, + "info": { + "total_cores": 
multiprocessing.cpu_count(), + "total_mem": round(psutil.virtual_memory().available / (2**30), 1), + "total_core_hrs": 0, + "total_workers": self.executor._max_workers, # type: ignore + "pending_tasks": 0, + "outstanding_tasks": 0, + "scaling_enabled": False, + "max_blocks": 1, + "min_blocks": 1, + "max_workers_per_node": self.executor._max_workers, # type: ignore + "nodes_per_block": 1, + "heartbeat_period": self._heartbeat_period_s, + }, + } + task_status_deltas: t.Dict[str, t.List[TaskTransition]] = {} + + return EPStatusReport( + endpoint_id=self.endpoint_id, + ep_status_report=executor_status, + task_statuses=task_status_deltas, + ) + + def _submit( + self, + func: t.Callable, + *args: t.Any, + **kwargs: t.Any, + ) -> Future: + """We basically pass all params except the resource_specification + over to executor.submit + """ + logger.warning("Got task") + return self.executor.submit(func, *args, **kwargs) + + def status_polling_interval(self) -> int: + return 30 + + def scale_out(self, blocks: int) -> list[str]: + return [] + + def scale_in(self, blocks: int) -> list[str]: + return [] + + def status(self) -> dict: + return {} + + def shutdown(self): + self._status_report_thread.stop() + self.executor.shutdown() diff --git a/compute_endpoint/tests/unit/test_engines.py b/compute_endpoint/tests/unit/test_engines.py new file mode 100644 index 000000000..99ddb2063 --- /dev/null +++ b/compute_endpoint/tests/unit/test_engines.py @@ -0,0 +1,207 @@ +import concurrent.futures +import logging +import multiprocessing +import os +import random +import shutil +import time +import uuid +from queue import Queue + +import pytest +from globus_compute_common import messagepack +from globus_compute_common.messagepack.message_types import TaskTransition +from globus_compute_common.tasks import ActorName, TaskState +from globus_compute_endpoint.engines import ( + GlobusComputeEngine, + ProcessPoolEngine, + ThreadPoolEngine, +) +from globus_compute_sdk.serialize import 
ComputeSerializer +from parsl.executors.high_throughput.interchange import ManagerLost +from tests.utils import double, ez_pack_function, slow_double + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def proc_pool_engine(): + ep_id = uuid.uuid4() + engine = ProcessPoolEngine( + label="ProcessPoolEngine", heartbeat_period_s=1, max_workers=2 + ) + queue = multiprocessing.Queue() + engine.start(endpoint_id=ep_id, run_dir="/tmp", results_passthrough=queue) + + yield engine + engine.shutdown() + + +@pytest.fixture +def thread_pool_engine(): + ep_id = uuid.uuid4() + engine = ThreadPoolEngine(heartbeat_period_s=1, max_workers=2) + + queue = Queue() + engine.start(endpoint_id=ep_id, run_dir="/tmp", results_passthrough=queue) + + yield engine + engine.shutdown() + + +@pytest.fixture +def gc_engine(): + ep_id = uuid.uuid4() + engine = GlobusComputeEngine( + address="127.0.0.1", heartbeat_period_s=1, heartbeat_threshold=1 + ) + queue = multiprocessing.Queue() + tempdir = "/tmp/HTEX_logs" + os.makedirs(tempdir, exist_ok=True) + engine.start(endpoint_id=ep_id, run_dir=tempdir, results_passthrough=queue) + + yield engine + engine.shutdown() + shutil.rmtree(tempdir, ignore_errors=True) + + +def test_result_message_packing(): + exec_start = TaskTransition( + timestamp=time.time_ns(), state=TaskState.EXEC_START, actor=ActorName.WORKER + ) + + serializer = ComputeSerializer() + task_id = uuid.uuid1() + result = random.randint(0, 1000) + + exec_end = TaskTransition( + timestamp=time.time_ns(), state=TaskState.EXEC_END, actor=ActorName.WORKER + ) + result_message = dict( + task_id=task_id, + data=serializer.serialize(result), + task_statuses=[exec_start, exec_end], + ) + + mResult = messagepack.message_types.Result(**result_message) + assert isinstance(mResult, messagepack.message_types.Result) + packed_result = messagepack.pack(mResult) + assert isinstance(packed_result, bytes) + + unpacked = messagepack.unpack(packed_result) + assert isinstance(unpacked, 
messagepack.message_types.Result) + # assert unpacked. + logger.warning(f"Type of unpacked : {unpacked}") + assert unpacked.task_id == task_id + assert serializer.deserialize(unpacked.data) == result + + +# Skipping "gc_engine" since parsl.htex.submit has an +# unreliable serialization method. +# @pytest.mark.parametrize("x", ["gc_engine"]) +@pytest.mark.parametrize("x", ["proc_pool_engine"]) +def test_engine_submit(x, gc_engine, proc_pool_engine): + "Test engine.submit with multiple engines" + if x == "gc_engine": + engine = gc_engine + else: + engine = proc_pool_engine + + param = random.randint(1, 100) + future = engine._submit(double, param) + assert isinstance(future, concurrent.futures.Future) + logger.warning(f"Got result: {future.result()}") + assert future.result() == param * 2 + + +@pytest.mark.parametrize("x", ["gc_engine", "proc_pool_engine"]) +def test_engine_submit_internal(x, gc_engine, proc_pool_engine): + if x == "gc_engine": + engine = gc_engine + else: + engine = proc_pool_engine + + q = engine.results_passthrough + task_id = uuid.uuid1() + serializer = ComputeSerializer() + task_body = ez_pack_function(serializer, double, (3,), {}) + task_message = messagepack.pack( + messagepack.message_types.Task( + task_id=task_id, container_id=uuid.uuid1(), task_buffer=task_body + ) + ) + future = engine.submit(task_id, task_message) + packed_result = future.result() + + # Confirm that the future got the right answer + assert isinstance(packed_result, bytes) + result = messagepack.unpack(packed_result) + assert isinstance(result, messagepack.message_types.Result) + assert result.task_id == task_id + + # Confirm that the same result got back though the queue + for _i in range(3): + packed_result_q = q.get(timeout=0.1) + assert isinstance( + packed_result_q, bytes + ), "Expected bytes from the passthrough_q" + + result = messagepack.unpack(packed_result_q) + # Handle a sneaky EPStatusReport that popped in ahead of the result + if isinstance(result, 
messagepack.message_types.EPStatusReport): + continue + + # At this point the message should be the result + assert ( + packed_result == packed_result_q + ), "Result from passthrough_q and future should match" + + assert result.task_id == task_id + final_result = serializer.deserialize(result.data) + assert final_result == 6, f"Expected 6, but got: {final_result}" + break + + +def test_gc_engine_system_failure(gc_engine): + """Test behavior of engine failure killing task""" + param = random.randint(1, 100) + q = gc_engine.results_passthrough + task_id = uuid.uuid1() + serializer = ComputeSerializer() + # We want the task to be running when we kill the manager + task_body = ez_pack_function( + serializer, + slow_double, + ( + param, + 5, + ), + {}, + ) + task_message = messagepack.pack( + messagepack.message_types.Task( + task_id=task_id, container_id=uuid.uuid1(), task_buffer=task_body + ) + ) + future = gc_engine.submit(task_id, task_message) + + assert isinstance(future, concurrent.futures.Future) + # Trigger a failure from managers scaling in. 
+ gc_engine.executor.scale_in(blocks=1) + # We need to scale out to make sure following tests will not be affected + gc_engine.executor.scale_out(blocks=1) + with pytest.raises(ManagerLost): + future.result() + + for _i in range(10): + packed_result = q.get(timeout=1) + assert packed_result + + result = messagepack.unpack(packed_result) + if isinstance(result, messagepack.message_types.EPStatusReport): + continue + else: + assert result.task_id == task_id + assert result.error_details + assert "ManagerLost" in result.data + break diff --git a/compute_endpoint/tests/unit/test_execute_task.py b/compute_endpoint/tests/unit/test_execute_task.py new file mode 100644 index 000000000..eb973c7fe --- /dev/null +++ b/compute_endpoint/tests/unit/test_execute_task.py @@ -0,0 +1,61 @@ +import logging +import uuid + +from globus_compute_common import messagepack +from globus_compute_endpoint.engines.helper import execute_task +from globus_compute_sdk.serialize import ComputeSerializer +from tests.utils import ez_pack_function + +logger = logging.getLogger(__name__) + + +def divide(x, y): + return x / y + + +def test_execute_task(): + serializer = ComputeSerializer() + task_id = uuid.uuid1() + input, output = (10, 2), 5 + task_body = ez_pack_function(serializer, divide, input, {}) + + task_message = messagepack.pack( + messagepack.message_types.Task( + task_id=task_id, container_id=uuid.uuid1(), task_buffer=task_body + ) + ) + + packed_result = execute_task(task_message) + assert isinstance(packed_result, bytes) + + result = messagepack.unpack(packed_result) + assert isinstance(result, messagepack.message_types.Result) + assert result.data + assert serializer.deserialize(result.data) == output + + +def test_execute_task_with_exception(): + serializer = ComputeSerializer() + task_id = uuid.uuid1() + task_body = ez_pack_function( + serializer, + divide, + ( + 10, + 0, + ), + {}, + ) + + task_message = messagepack.pack( + messagepack.message_types.Task( + task_id=task_id, 
# compute_endpoint/tests/unit/test_htex.py
import os
import queue
import uuid

import pytest
from globus_compute_common import messagepack
from globus_compute_endpoint.executors import HighThroughputExecutor
from globus_compute_sdk.serialize import ComputeSerializer
from tests.utils import div_zero, double, ez_pack_function, kill_manager


@pytest.fixture
def htex():
    """Yield a started ``HighThroughputExecutor`` and shut it down afterwards.

    The executor is bound to 127.0.0.1, writes under ``/tmp/HTEX_logs`` and
    publishes its output to an in-process ``queue.Queue`` passed as
    ``results_passthrough``.
    """
    ep_id = uuid.uuid4()
    executor = HighThroughputExecutor(
        address="127.0.0.1",
        heartbeat_period=1,
        heartbeat_threshold=1,
        worker_debug=True,
    )
    q = queue.Queue()
    tempdir = "/tmp/HTEX_logs"

    os.makedirs(tempdir, exist_ok=True)
    executor.start(endpoint_id=ep_id, run_dir=tempdir, results_passthrough=q)

    yield executor
    executor.shutdown()
    # Left disabled: the log directory is kept after the run.
    # shutil.rmtree(tempdir, ignore_errors=True)


def _pack_task(serializer, task_id, func, args):
    """Build the packed ``Task`` message that invokes ``func(*args)``."""
    task_body = ez_pack_function(serializer, func, args, {})
    return messagepack.pack(
        messagepack.message_types.Task(
            task_id=task_id, container_id=uuid.uuid1(), task_buffer=task_body
        )
    )


def _next_task_result(results_q, attempts, timeout=5):
    """Return the first message from ``results_q`` that is not a status report.

    At most ``attempts`` messages are read, each waiting up to ``timeout``
    seconds (``queue.Empty`` propagates if nothing arrives in time).  The test
    fails explicitly when every message read was an ``EPStatusReport``;
    otherwise the calling test would pass without checking any result.
    """
    for _i in range(attempts):
        packed_result_q = results_q.get(timeout=timeout)
        assert isinstance(
            packed_result_q, bytes
        ), "Expected bytes from the passthrough_q"

        result = messagepack.unpack(packed_result_q)
        # Handle a sneaky EPStatusReport that popped in ahead of the result
        if isinstance(result, messagepack.message_types.EPStatusReport):
            continue

        # At this point the message should be the result
        return result

    pytest.fail(f"No task result among the first {attempts} queue messages")


@pytest.mark.skip("Skip until HTEX has been fixed up")
def test_htex_submit_raw(htex):
    """Testing the HighThroughputExecutor/Engine"""
    engine = htex

    q = engine.results_passthrough
    task_id = uuid.uuid1()
    serializer = ComputeSerializer()

    # HTEX doesn't give you a future back
    engine.submit_raw(_pack_task(serializer, task_id, double, (3,)))

    # Confirm that the same result got back though the queue
    result = _next_task_result(q, attempts=3)
    assert result.task_id == task_id

    final_result = serializer.deserialize(result.data)
    assert final_result == 6, f"Expected 6, but got: {final_result}"


@pytest.mark.skip("Skip until HTEX has been fixed up")
def test_htex_submit_raw_exception(htex):
    """Testing the HighThroughputExecutor/Engine with a remote side exception"""
    engine = htex

    q = engine.results_passthrough
    task_id = uuid.uuid1()
    serializer = ComputeSerializer()

    # HTEX doesn't give you a future back
    engine.submit_raw(_pack_task(serializer, task_id, div_zero, (3,)))

    # Confirm that the same result got back though the queue
    result = _next_task_result(q, attempts=3)
    assert result.task_id == task_id
    assert result.error_details


@pytest.mark.skip("Skip until HTEX has been fixed up")
def test_htex_manager_lost(htex):
    """Testing the HighThroughputExecutor/Engine when the task kills its manager"""
    engine = htex

    q = engine.results_passthrough
    task_id = uuid.uuid1()
    serializer = ComputeSerializer()

    # HTEX doesn't give you a future back
    engine.submit_raw(_pack_task(serializer, task_id, kill_manager, ()))

    # Allow more messages than the other tests before giving up: status
    # reports may keep arriving while the manager failure is being detected.
    result = _next_task_result(q, attempts=10)
    assert result.task_id == task_id

    assert result.error_details.code == "RemoteExecutionError"
    assert "ManagerLost" in result.data


# compute_endpoint/tests/unit/test_reporting_thread.py
import logging
import queue
import time

import pytest
from globus_compute_endpoint.engines.base import ReportingThread

logger = logging.getLogger(__name__)


def test_reporting_thread():
    """A failing callback's exception must surface through ``rt.status``."""

    def callback(result_queue: queue.Queue):
        # Put an item into the queue on every call; once the queue is
        # full the non-blocking put raises queue.Full
        result_queue.put("42", block=False)

    # We expect 5 items to be put into the queue
    # before the callback throws an exception
    result_q = queue.Queue(maxsize=5)
    rt = ReportingThread(target=callback, args=[result_q], reporting_period=0.01)
    rt.start()
    # Give enough time for the callbacks to be executed N times
    time.sleep(0.2)
    rt.stop()

    assert result_q.qsize() == 5
    assert rt.status.exception()
    with pytest.raises(queue.Full):
        rt.status.result()
# compute_endpoint/tests/unit/test_status_reporting.py
import logging
import multiprocessing
import os
import shutil
import uuid

import pytest
from globus_compute_common import messagepack
from globus_compute_common.messagepack.message_types import EPStatusReport
from globus_compute_endpoint.engines import GlobusComputeEngine, ProcessPoolEngine
from pytest import fixture

logger = logging.getLogger(__name__)

HEARTBEAT_PERIOD = 0.1


@fixture
def proc_pool_engine():
    """Yield a started two-worker ``ProcessPoolEngine``; shut it down afterwards."""
    ep_id = uuid.uuid4()
    executor = ProcessPoolEngine(max_workers=2, heartbeat_period_s=HEARTBEAT_PERIOD)
    queue = multiprocessing.Queue()
    executor.start(endpoint_id=ep_id, run_dir="/tmp", results_passthrough=queue)

    yield executor
    executor.shutdown()


@fixture
def gc_engine():
    """Yield a started ``GlobusComputeEngine``; shut it down and remove its run dir."""
    ep_id = uuid.uuid4()
    executor = GlobusComputeEngine(
        address="127.0.0.1",
        heartbeat_period_s=HEARTBEAT_PERIOD,
        heartbeat_threshold=1,
    )
    queue = multiprocessing.Queue()
    tempdir = "/tmp/HTEX_logs"
    os.makedirs(tempdir, exist_ok=True)
    executor.start(endpoint_id=ep_id, run_dir=tempdir, results_passthrough=queue)

    yield executor
    executor.shutdown()
    shutil.rmtree(tempdir, ignore_errors=True)


@pytest.mark.parametrize("x", ["gc_engine", "proc_pool_engine"])
def test_status_reporting(x, request):
    """Check on-demand and periodic status reports for the engine fixture ``x``.

    The fixture is resolved by name at run time so that only the engine
    under test is started for each parametrization.
    """
    executor = request.getfixturevalue(x)

    report = executor.get_status_report()
    assert isinstance(report, EPStatusReport)

    results_q = executor.results_passthrough

    assert executor._status_report_thread.reporting_period == HEARTBEAT_PERIOD

    # Flush queue to start
    while not results_q.empty():
        results_q.get()

    # Confirm heartbeats in regular intervals; each one is given two
    # heartbeat periods (0.2s) to arrive
    for _i in range(3):
        message = results_q.get(timeout=2 * HEARTBEAT_PERIOD)
        assert isinstance(message, bytes)

        report = messagepack.unpack(message)
        assert isinstance(report, EPStatusReport)
def ez_pack_function(serializer, func, args, kwargs):
    """Serialize a call and pack it into a single task buffer.

    ``func``, ``args`` and ``kwargs`` are each passed through
    ``serializer.serialize`` (in that order) and the three resulting buffers
    are combined with ``serializer.pack_buffers``, whose return value is
    returned unchanged.
    """
    buffers = [serializer.serialize(part) for part in (func, args, kwargs)]
    return serializer.pack_buffers(buffers)


def double(x: int) -> int:
    """Return ``x`` multiplied by two."""
    return x * 2


def slow_double(x: int, sleep_duration_s: float) -> int:
    """Sleep for ``sleep_duration_s`` seconds, then return ``x`` doubled."""
    # Imported inside the function, and the doubling is written out rather
    # than delegated to double(): presumably so the function stays
    # self-contained when tests serialize it for execution elsewhere --
    # confirm before hoisting or refactoring.
    import time

    time.sleep(sleep_duration_s)
    return x * 2


def kill_manager():
    """Send SIGKILL to the parent of the process that runs this function.

    NOTE(review): the name implies the parent process is the worker's
    manager; that relationship is not visible here -- verify against the
    executor implementation.
    """
    import os
    import signal

    manager_pid = os.getppid()
    os.kill(manager_pid, signal.SIGKILL)


def div_zero(x: int):
    """Divide ``x`` by zero, raising ``ZeroDivisionError`` for numeric ``x``."""
    return x / 0