From cd4a752fe05f6bb15c2a949da6e2789030a2c8df Mon Sep 17 00:00:00 2001
From: Stephen Rosen
Date: Fri, 22 Oct 2021 20:04:01 +0000
Subject: [PATCH] Update linting config and fix

Remove flake8 ignores for various rules: only ignore the rules which must be
ignored to be black-compatible. Remove exclusions from pre-commit config.
---
 .flake8 | 29 +-
 .pre-commit-config.yaml | 12 -
 docs/conf.py | 4 +-
 docs/configs/bluewaters.py | 2 +-
 docs/configs/polaris.py | 1 -
 docs/configs/uchicago_ai_cluster.py | 17 +-
 .../funcx_endpoint/endpoint/config.py | 76 +-
 .../funcx_endpoint/endpoint/default_config.py | 23 +-
 .../funcx_endpoint/endpoint/endpoint.py | 146 +--
 .../endpoint/endpoint_manager.py | 288 +++--
 .../funcx_endpoint/endpoint/interchange.py | 590 +++++++----
 .../endpoint/register_endpoint.py | 36 +-
 .../funcx_endpoint/endpoint/results_ack.py | 39 +-
 .../funcx_endpoint/endpoint/taskqueue.py | 80 +-
 .../funcx_endpoint/endpoint/utils/config.py | 15 +-
 .../funcx_endpoint/executors/__init__.py | 2 +-
 .../high_throughput/container_sched.py | 26 +-
 .../executors/high_throughput/executor.py | 623 ++++++-----
 .../high_throughput/funcx_manager.py | 669 +++++++-----
 .../executors/high_throughput/funcx_worker.py | 161 +--
 .../high_throughput/global_config.py | 11 +-
 .../executors/high_throughput/interchange.py | 985 +++++++++++-------
 .../interchange_task_dispatch.py | 218 ++--
 .../high_throughput/mac_safe_queue.py | 5 +-
 .../executors/high_throughput/messages.py | 58 +-
 .../executors/high_throughput/worker_map.py | 347 +++---
 .../executors/high_throughput/zmq_pipes.py | 67 +-
 .../funcx_endpoint/providers/__init__.py | 2 +-
 .../providers/kubernetes/kube.py | 38 +-
 .../funcx_endpoint/strategies/__init__.py | 7 +-
 .../funcx_endpoint/strategies/base.py | 52 +-
 .../funcx_endpoint/strategies/kube_simple.py | 73 +-
 .../funcx_endpoint/strategies/simple.py | 62 +-
 .../funcx_endpoint/strategies/test.py | 15 +-
 .../tests/strategies/test_kube_simple.py | 47 +-
 funcx_endpoint/funcx_endpoint/version.py | 2 +-
 funcx_endpoint/setup.py | 9 +-
 .../funcx_endpoint/endpoint/test_endpoint.py | 17 +-
 .../endpoint/test_endpoint_manager.py | 375 ++++---
 .../endpoint/test_interchange.py | 137 +--
 .../endpoint/test_register_endpoint.py | 45 +-
 .../high_throughput/test_funcx_manager.py | 53 +-
 .../high_throughput/test_funcx_worker.py | 27 +-
 .../high_throughput/test_worker_map.py | 27 +-
 .../tests/integration/test_batch_submit.py | 29 +-
 .../tests/integration/test_config.py | 70 +-
 .../tests/integration/test_containers.py | 14 +-
 .../tests/integration/test_deserialization.py | 3 +-
 .../tests/integration/test_executor.py | 2 -
 .../integration/test_executor_passthrough.py | 27 +-
 .../tests/integration/test_interchange.py | 19 +-
 .../tests/integration/test_per_func_batch.py | 6 +-
 .../tests/integration/test_registration.py | 3 +-
 .../tests/integration/test_serialization.py | 7 +-
 .../tests/integration/test_status.py | 21 +-
 .../tests/integration/test_submits.py | 12 +-
 .../tests/integration/test_throttling.py | 33 +-
 .../tests/tutorial_ep/test_tutotial_ep.py | 95 +-
 funcx_sdk/funcx/__init__.py | 4 +-
 funcx_sdk/funcx/sdk/__init__.py | 2 +
 .../funcx/sdk/asynchronous/ws_polling_task.py | 39 +-
 funcx_sdk/funcx/sdk/client.py | 4 +-
 funcx_sdk/funcx/sdk/error_handling_client.py | 4 +-
 funcx_sdk/funcx/sdk/executor.py | 9 +-
 funcx_sdk/funcx/sdk/search.py | 5 +-
 funcx_sdk/funcx/sdk/utils/futures.py | 1 -
 funcx_sdk/funcx/utils/loggers.py | 3 +-
 funcx_sdk/funcx/utils/response_errors.py | 23 +-
 68 files changed, 3558 insertions(+), 2395 deletions(-)
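Note: the following is an illustrative sketch of the policy described in the
commit message, not part of the commit content. With the global ignore list
trimmed to only the black-compatibility rules, a rule that genuinely must be
violated is now waived on the specific line with a noqa comment. The snippet
mirrors the docs/conf.py hunk later in this patch; the ../funcx_sdk/ path comes
from that file and assumes this repository's docs layout.

    import os
    import sys

    # docs/conf.py must extend sys.path before the project package can be
    # imported, so the late import is intentional; E402 is suppressed only on
    # that line rather than being ignored repository-wide.
    sys.path.insert(0, os.path.abspath("../funcx_sdk/"))

    import funcx  # noqa:E402
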
diff --git a/.flake8 b/.flake8 index 39142f074..98f0dc50e 100644 --- a/.flake8 +++ b/.flake8 @@ -1,27 +1,4 @@ -[flake8] -# TODO: remove all of these ignores other than W503,W504,B008 -# `black` will handle enforcement of styling, and we will have no opinionated -# ignore rules -# any cases in which we actually need to ignore a rule (e.g. E402) we will mark -# the relevant segment with noqa comments as necessary -# -# D203: 1 blank line required before class docstring -# E124: closing bracket does not match visual indentation -# E126: continuation line over-indented for hanging indent -# This one is bad. Sometimes ordering matters, conditional imports -# setting env vars necessary etc. -# E402: module level import not at top of file -# E129: Visual indent to not match indent as next line, counter eg here: -# https://github.com/PyCQA/pycodestyle/issues/386 -# -# E203,W503,W504: conflict with black formatting sometimes -# B008: a flake8-bugbear rule which fails on idiomatic typer usage (consider -# re-enabling this once everything else is fixed and updating usage) -ignore = D203, E124, E126, E402, E129, W605, W503, W504, E203, F401, B008 +[flake8] # black-compatible +ignore = W503, W504, E203, B008 # TODO: reduce this to 88 once `black` is applied to all code -max-line-length = 160 -exclude = parsl/executors/serialize/, test_import_fail.py -# F632 is comparing constant literals with == instead of "is" -per-file-ignores = funcx_sdk/funcx/sdk/client.py:F632, - funcx_endpoint/funcx/endpoint/auth.py:F821, - funcx_endpoint/funcx/serialize/base.py:F821 +max-line-length = 88 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cabc51d2a..54ad96044 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,9 +8,6 @@ repos: hooks: - id: check-merge-conflict - id: trailing-whitespace - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://github.com/sirosen/check-jsonschema rev: 0.3.1 hooks: @@ -19,9 +16,6 @@ repos: rev: 21.5b1 hooks: - id: black - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://github.com/timothycrosley/isort rev: 5.8.0 hooks: @@ -29,17 +23,11 @@ repos: # explicitly pass settings file so that isort does not try to deduce # which settings to use based on a file's directory args: ["--settings-path", ".isort.cfg"] - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://github.com/asottile/pyupgrade rev: v2.17.0 hooks: - id: pyupgrade args: ["--py36-plus"] - # FIXME: temporary exclude to reduce conflicts, remove this - exclude: ^funcx_endpoint/ - # end FIXME - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: diff --git a/docs/conf.py b/docs/conf.py index e8eb9f11e..9dde0e7f3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,10 +16,8 @@ import os import sys -import requests - sys.path.insert(0, os.path.abspath("../funcx_sdk/")) -import funcx +import funcx # noqa:E402 # -- Project information ----------------------------------------------------- diff --git a/docs/configs/bluewaters.py b/docs/configs/bluewaters.py index bb1530836..eddd6588c 100644 --- a/docs/configs/bluewaters.py +++ b/docs/configs/bluewaters.py @@ -10,7 +10,7 @@ # PLEASE UPDATE user_opts BEFORE USE user_opts = { 'bluewaters': { - 'worker_init': 'module load bwpy;source anaconda3/etc/profile.d/conda.sh;conda activate funcx_testing_py3.7', + 'worker_init': 'module load bwpy;source 
anaconda3/etc/profile.d/conda.sh;conda activate funcx_testing_py3.7', # noqa: E501 'scheduler_options': '', } } diff --git a/docs/configs/polaris.py b/docs/configs/polaris.py index 07987a598..869e7b2c4 100644 --- a/docs/configs/polaris.py +++ b/docs/configs/polaris.py @@ -1,4 +1,3 @@ -from parsl.addresses import address_by_hostname from parsl.launchers import SingleNodeLauncher from parsl.providers import PBSProProvider diff --git a/docs/configs/uchicago_ai_cluster.py b/docs/configs/uchicago_ai_cluster.py index c4b69af0e..ba9a06b23 100644 --- a/docs/configs/uchicago_ai_cluster.py +++ b/docs/configs/uchicago_ai_cluster.py @@ -1,6 +1,6 @@ from parsl.addresses import address_by_hostname from parsl.launchers import SrunLauncher -from parsl.providers import LocalProvider, SlurmProvider +from parsl.providers import SlurmProvider from funcx_endpoint.endpoint.utils.config import Config from funcx_endpoint.executors import HighThroughputExecutor @@ -28,12 +28,17 @@ partition='general', # Launch 4 managers per node, each bound to 1 GPU - # This is a hack. We use hostname ; to terminate the srun command, and start our own + # This is a hack. We use hostname ; to terminate the srun command, and + # start our own + # # DO NOT MODIFY unless you know what you are doing. - launcher=SrunLauncher(overrides=(f'hostname; srun --ntasks={TOTAL_WORKERS} ' - f'--ntasks-per-node={WORKERS_PER_NODE} ' - f'--gpus-per-task=rtx2080ti:{GPUS_PER_WORKER} ' - f'--gpu-bind=map_gpu:{GPU_MAP}') + launcher=SrunLauncher( + overrides=( + f'hostname; srun --ntasks={TOTAL_WORKERS} ' + f'--ntasks-per-node={WORKERS_PER_NODE} ' + f'--gpus-per-task=rtx2080ti:{GPUS_PER_WORKER} ' + f'--gpu-bind=map_gpu:{GPU_MAP}' + ) ), # Scale between 0-1 blocks with 2 nodes per block diff --git a/funcx_endpoint/funcx_endpoint/endpoint/config.py b/funcx_endpoint/funcx_endpoint/endpoint/config.py index 432eb7bbd..22394a72c 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/config.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/config.py @@ -1,17 +1,16 @@ -import globus_sdk -import parsl import os -from parsl.config import Config +import globus_sdk +from parsl.addresses import address_by_route from parsl.channels import LocalChannel -from parsl.providers import LocalProvider, KubernetesProvider +from parsl.config import Config from parsl.executors import HighThroughputExecutor -from parsl.addresses import address_by_route +from parsl.providers import KubernetesProvider, LocalProvider # GlobusAuth-related secrets -SECRET_KEY = os.environ.get('secret_key') -GLOBUS_KEY = os.environ.get('globus_key') -GLOBUS_CLIENT = os.environ.get('globus_client') +SECRET_KEY = os.environ.get("secret_key") +GLOBUS_KEY = os.environ.get("globus_key") +GLOBUS_CLIENT = os.environ.get("globus_client") FUNCX_URL = "https://funcx.org/" FUNCX_HUB_URL = "3.88.81.131" @@ -31,10 +30,9 @@ def _load_auth_client(): _prod = True if _prod: - app = globus_sdk.ConfidentialAppAuthClient(GLOBUS_CLIENT, - GLOBUS_KEY) + app = globus_sdk.ConfidentialAppAuthClient(GLOBUS_CLIENT, GLOBUS_KEY) else: - app = globus_sdk.ConfidentialAppAuthClient('', '') + app = globus_sdk.ConfidentialAppAuthClient("", "") return app @@ -62,7 +60,7 @@ def _get_parsl_config(): ), ) ], - strategy=None + strategy=None, ) return config @@ -77,31 +75,35 @@ def _get_executor(container): """ executor = HighThroughputExecutor( - label=container['container_uuid'], - cores_per_worker=1, - max_workers=1, - poll_period=10, - # launch_cmd="ls; sleep 3600", - worker_logdir_root='runinfo', - # worker_debug=True, - 
address=address_by_route(), - provider=KubernetesProvider( - namespace="dlhub-privileged", - image=container['location'], - nodes_per_block=1, - init_blocks=1, - max_blocks=1, - parallelism=1, - worker_init="""pip install git+https://github.com/Parsl/parsl; + label=container["container_uuid"], + cores_per_worker=1, + max_workers=1, + poll_period=10, + # launch_cmd="ls; sleep 3600", + worker_logdir_root="runinfo", + # worker_debug=True, + address=address_by_route(), + provider=KubernetesProvider( + namespace="dlhub-privileged", + image=container["location"], + nodes_per_block=1, + init_blocks=1, + max_blocks=1, + parallelism=1, + worker_init="""pip install git+https://github.com/Parsl/parsl; pip install git+https://github.com/funcx-faas/funcX; export PYTHONPATH=$PYTHONPATH:/home/ubuntu:/app""", - # security=None, - secret="ryan-kube-secret", - pod_name=container['name'].replace('.', '-').replace("_", '-').replace('/', '-').lower(), - # secret="minikube-aws-ecr", - # user_id=32781, - # group_id=10253, - # run_as_non_root=True - ), - ) + # security=None, + secret="ryan-kube-secret", + pod_name=container["name"] + .replace(".", "-") + .replace("_", "-") + .replace("/", "-") + .lower(), + # secret="minikube-aws-ecr", + # user_id=32781, + # group_id=10253, + # run_as_non_root=True + ), + ) return [executor] diff --git a/funcx_endpoint/funcx_endpoint/endpoint/default_config.py b/funcx_endpoint/funcx_endpoint/endpoint/default_config.py index cd451ad5c..10e5553e3 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/default_config.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/default_config.py @@ -1,16 +1,19 @@ +from parsl.providers import LocalProvider + from funcx_endpoint.endpoint.utils.config import Config from funcx_endpoint.executors import HighThroughputExecutor -from parsl.providers import LocalProvider config = Config( - executors=[HighThroughputExecutor( - provider=LocalProvider( - init_blocks=1, - min_blocks=0, - max_blocks=1, - ), - )], - funcx_service_address='https://api2.funcx.org/v2' + executors=[ + HighThroughputExecutor( + provider=LocalProvider( + init_blocks=1, + min_blocks=0, + max_blocks=1, + ), + ) + ], + funcx_service_address="https://api2.funcx.org/v2", ) # For now, visible_to must be a list of URNs for globus auth users or groups, e.g.: @@ -22,5 +25,5 @@ "organization": "", "department": "", "public": False, - "visible_to": [] + "visible_to": [], } diff --git a/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py b/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py index a3fa83fb0..9cba7242a 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/endpoint.py @@ -1,32 +1,13 @@ import glob -from importlib.machinery import SourceFileLoader -import json import logging import os import pathlib -import random -import shutil -import signal -import sys -import time -import uuid -from string import Template - -import daemon -import daemon.pidfile -import psutil -import requests +from importlib.machinery import SourceFileLoader + import typer -from retry import retry import funcx -import zmq - -from funcx_endpoint.endpoint import default_config as endpoint_default_config -from funcx_endpoint.executors.high_throughput import global_config as funcx_default_config -from funcx_endpoint.endpoint.interchange import EndpointInterchange from funcx_endpoint.endpoint.endpoint_manager import EndpointManager -from funcx.sdk.client import FuncXClient app = typer.Typer() logger = None @@ -35,23 +16,28 @@ def version_callback(value): if value: 
import funcx_endpoint - typer.echo("FuncX endpoint version: {}".format(funcx_endpoint.__version__)) + + typer.echo(f"FuncX endpoint version: {funcx_endpoint.__version__}") raise typer.Exit() def complete_endpoint_name(): # Manager context is not initialized at this point, so we assume the default # the funcx_dir path of ~/.funcx - funcx_dir = os.path.join(pathlib.Path.home(), '.funcx') - config_files = glob.glob(os.path.join(funcx_dir, '*', 'config.py')) + funcx_dir = os.path.join(pathlib.Path.home(), ".funcx") + config_files = glob.glob(os.path.join(funcx_dir, "*", "config.py")) for config_file in config_files: yield os.path.basename(os.path.dirname(config_file)) @app.command(name="configure", help="Configure an endpoint") def configure_endpoint( - name: str = typer.Argument("default", help="endpoint name", autocompletion=complete_endpoint_name), - endpoint_config: str = typer.Option(None, "--endpoint-config", help="endpoint config file") + name: str = typer.Argument( + "default", help="endpoint name", autocompletion=complete_endpoint_name + ), + endpoint_config: str = typer.Option( + None, "--endpoint-config", help="endpoint config file" + ), ): """Configure an endpoint @@ -63,8 +49,10 @@ def configure_endpoint( @app.command(name="start", help="Start an endpoint by name") def start_endpoint( - name: str = typer.Argument("default", autocompletion=complete_endpoint_name), - endpoint_uuid: str = typer.Option(None, help="The UUID for the endpoint to register with") + name: str = typer.Argument("default", autocompletion=complete_endpoint_name), + endpoint_uuid: str = typer.Option( + None, help="The UUID for the endpoint to register with" + ), ): """Start an endpoint @@ -94,37 +82,45 @@ def start_endpoint( endpoint_dir = os.path.join(manager.funcx_dir, name) if not os.path.exists(endpoint_dir): - msg = (f'\nEndpoint {name} is not configured!\n' - '1. Please create a configuration template with:\n' - f'\tfuncx-endpoint configure {name}\n' - '2. Update the configuration\n' - '3. Start the endpoint\n') + msg = ( + f"\nEndpoint {name} is not configured!\n" + "1. Please create a configuration template with:\n" + f"\tfuncx-endpoint configure {name}\n" + "2. Update the configuration\n" + "3. Start the endpoint\n" + ) print(msg) return try: - endpoint_config = SourceFileLoader('config', - os.path.join(endpoint_dir, manager.funcx_config_file_name)).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(endpoint_dir, manager.funcx_config_file_name) + ).load_module() except Exception: - manager.logger.exception('funcX v0.2.0 made several non-backwards compatible changes to the config. ' - 'Your config might be out of date. ' - 'Refer to https://funcx.readthedocs.io/en/latest/endpoints.html#configuring-funcx') + manager.logger.exception( + "funcX v0.2.0 made several non-backwards compatible changes to the config. " + "Your config might be out of date. 
" + "Refer to " + "https://funcx.readthedocs.io/en/latest/endpoints.html#configuring-funcx" + ) raise manager.start_endpoint(name, endpoint_uuid, endpoint_config) @app.command(name="stop") -def stop_endpoint(name: str = typer.Argument("default", autocompletion=complete_endpoint_name)): - """ Stops an endpoint using the pidfile - - """ +def stop_endpoint( + name: str = typer.Argument("default", autocompletion=complete_endpoint_name) +): + """Stops an endpoint using the pidfile""" manager.stop_endpoint(name) @app.command(name="restart") -def restart_endpoint(name: str = typer.Argument("default", autocompletion=complete_endpoint_name)): +def restart_endpoint( + name: str = typer.Argument("default", autocompletion=complete_endpoint_name) +): """Restarts an endpoint""" stop_endpoint(name) start_endpoint(name) @@ -132,52 +128,72 @@ def restart_endpoint(name: str = typer.Argument("default", autocompletion=comple @app.command(name="list") def list_endpoints(): - """ List all available endpoints - """ + """List all available endpoints""" manager.list_endpoints() @app.command(name="delete") def delete_endpoint( - name: str = typer.Argument(..., autocompletion=complete_endpoint_name), - autoconfirm: bool = typer.Option(False, "-y", help="Do not ask for confirmation to delete.") + name: str = typer.Argument(..., autocompletion=complete_endpoint_name), + autoconfirm: bool = typer.Option( + False, "-y", help="Do not ask for confirmation to delete." + ), ): """Deletes an endpoint and its config.""" if not autoconfirm: - typer.confirm(f"Are you sure you want to delete the endpoint <{name}>?", abort=True) + typer.confirm( + f"Are you sure you want to delete the endpoint <{name}>?", abort=True + ) manager.delete_endpoint(name) @app.callback() def main( - ctx: typer.Context, - _: bool = typer.Option(None, "--version", "-v", callback=version_callback, is_eager=True), - debug: bool = typer.Option(False, "--debug", "-d"), - config_dir: str = typer.Option(os.path.join(pathlib.Path.home(), '.funcx'), "--config_dir", "-c", help="override default config dir") + ctx: typer.Context, + _: bool = typer.Option( + None, "--version", "-v", callback=version_callback, is_eager=True + ), + debug: bool = typer.Option(False, "--debug", "-d"), + config_dir: str = typer.Option( + os.path.join(pathlib.Path.home(), ".funcx"), + "--config_dir", + "-c", + help="override default config dir", + ), ): - # Note: no docstring here; the docstring for @app.callback is used as a help message for overall app. - # Sets up global variables in the State wrapper (debug flag, config dir, default config file). - # For commands other than `init`, we ensure the existence of the config directory and file. + # Note: no docstring here; the docstring for @app.callback is used as a help + # message for overall app. + # + # Sets up global variables in the State wrapper (debug flag, config dir, default + # config file). + # + # For commands other than `init`, we ensure the existence of the config directory + # and file. 
global logger - funcx.set_stream_logger(name='endpoint', - level=logging.DEBUG if debug else logging.INFO) - logger = logging.getLogger('endpoint') - logger.debug("Command: {}".format(ctx.invoked_subcommand)) + funcx.set_stream_logger( + name="endpoint", level=logging.DEBUG if debug else logging.INFO + ) + logger = logging.getLogger("endpoint") + logger.debug(f"Command: {ctx.invoked_subcommand}") global manager - manager = EndpointManager(funcx_dir=config_dir, - debug=debug) + manager = EndpointManager(funcx_dir=config_dir, debug=debug) # Otherwise, we ensure that configs exist if not os.path.exists(manager.funcx_config_file): - logger.info(f"No existing configuration found at {manager.funcx_config_file}. Initializing...") + logger.info( + "No existing configuration found at %s. Initializing...", + manager.funcx_config_file, + ) manager.init_endpoint() - logger.debug("Loading config files from {}".format(manager.funcx_dir)) + logger.debug(f"Loading config files from {manager.funcx_dir}") - funcx_config = SourceFileLoader('global_config', manager.funcx_config_file).load_module() + funcx_config = SourceFileLoader( + "global_config", manager.funcx_config_file + ).load_module() manager.funcx_config = funcx_config.global_options @@ -186,5 +202,5 @@ def cli_run(): app() -if __name__ == '__main__': +if __name__ == "__main__": app() diff --git a/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py b/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py index e4f196064..bcf0159ec 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/endpoint_manager.py @@ -6,7 +6,6 @@ import shutil import signal import sys -import time import uuid from string import Template @@ -15,30 +14,32 @@ import psutil import texttable import typer - -import funcx import zmq from globus_sdk import GlobusAPIError, NetworkError +from funcx.sdk.client import FuncXClient from funcx.utils.response_errors import FuncxResponseError from funcx_endpoint.endpoint import default_config as endpoint_default_config -from funcx_endpoint.executors.high_throughput import global_config as funcx_default_config from funcx_endpoint.endpoint.interchange import EndpointInterchange from funcx_endpoint.endpoint.register_endpoint import register_endpoint from funcx_endpoint.endpoint.results_ack import ResultsAckHandler -from funcx.sdk.client import FuncXClient +from funcx_endpoint.executors.high_throughput import ( + global_config as funcx_default_config, +) logger = logging.getLogger("endpoint.endpoint_manager") class EndpointManager: - """ EndpointManager is primarily responsible for configuring, launching and stopping the Endpoint. + """ + EndpointManager is primarily responsible for configuring, launching and stopping + the Endpoint. """ - def __init__(self, - funcx_dir=os.path.join(pathlib.Path.home(), '.funcx'), - debug=False): - """ Initialize the EndpointManager + def __init__( + self, funcx_dir=os.path.join(pathlib.Path.home(), ".funcx"), debug=False + ): + """Initialize the EndpointManager Parameters ---------- @@ -49,18 +50,20 @@ def __init__(self, debug: Bool Enable debug logging. 
Default: False """ - self.funcx_config_file_name = 'config.py' + self.funcx_config_file_name = "config.py" self.debug = debug self.funcx_dir = funcx_dir - self.funcx_config_file = os.path.join(self.funcx_dir, self.funcx_config_file_name) + self.funcx_config_file = os.path.join( + self.funcx_dir, self.funcx_config_file_name + ) self.funcx_default_config_template = funcx_default_config.__file__ self.funcx_config = {} - self.name = 'default' + self.name = "default" global logger self.logger = logger def init_endpoint_dir(self, endpoint_config=None): - """ Initialize a clean endpoint dir + """Initialize a clean endpoint dir Returns if an endpoint_dir already exists Parameters @@ -73,7 +76,9 @@ def init_endpoint_dir(self, endpoint_config=None): self.logger.debug(f"Creating endpoint dir {endpoint_dir}") os.makedirs(endpoint_dir, exist_ok=True) - endpoint_config_target_file = os.path.join(endpoint_dir, self.funcx_config_file_name) + endpoint_config_target_file = os.path.join( + endpoint_dir, self.funcx_config_file_name + ) if endpoint_config: shutil.copyfile(endpoint_config, endpoint_config_target_file) return endpoint_dir @@ -95,29 +100,35 @@ def configure_endpoint(self, name, endpoint_config): if not os.path.exists(endpoint_dir): self.init_endpoint_dir(endpoint_config=endpoint_config) - print(f'A default profile has been create for <{self.name}> at {new_config_file}') - print('Configure this file and try restarting with:') - print(f' $ funcx-endpoint start {self.name}') + print( + f"A default profile has been create for <{self.name}> " + f"at {new_config_file}" + ) + print("Configure this file and try restarting with:") + print(f" $ funcx-endpoint start {self.name}") else: - print(f'config dir <{self.name}> already exsits') - raise Exception('ConfigExists') + print(f"config dir <{self.name}> already exsits") + raise Exception("ConfigExists") def init_endpoint(self): """Setup funcx dirs and default endpoint config files TODO : Every mechanism that will update the config file, must be using a - locking mechanism, ideally something like fcntl https://docs.python.org/3/library/fcntl.html - to ensure that multiple endpoint invocations do not mangle the funcx config files - or the lockfile module. + locking mechanism, ideally something like fcntl [1] + to ensure that multiple endpoint invocations do not mangle the funcx config + files or the lockfile module. + + [1] https://docs.python.org/3/library/fcntl.html """ _ = FuncXClient() if os.path.exists(self.funcx_config_file): typer.confirm( "Are you sure you want to initialize this directory? 
" - f"This will erase everything in {self.funcx_dir}", abort=True + f"This will erase everything in {self.funcx_dir}", + abort=True, ) - self.logger.info("Wiping all current configs in {}".format(self.funcx_dir)) + self.logger.info(f"Wiping all current configs in {self.funcx_dir}") backup_dir = self.funcx_dir + ".bak" try: self.logger.debug(f"Removing old backups in {backup_dir}") @@ -127,22 +138,24 @@ def init_endpoint(self): os.renames(self.funcx_dir, backup_dir) if os.path.exists(self.funcx_config_file): - self.logger.debug("Config file exists at {}".format(self.funcx_config_file)) + self.logger.debug(f"Config file exists at {self.funcx_config_file}") return try: os.makedirs(self.funcx_dir, exist_ok=True) except Exception as e: - print("[FuncX] Caught exception during registration {}".format(e)) + print(f"[FuncX] Caught exception during registration {e}") shutil.copyfile(self.funcx_default_config_template, self.funcx_config_file) def check_endpoint_json(self, endpoint_json, endpoint_uuid): if os.path.exists(endpoint_json): - with open(endpoint_json, 'r') as fp: - self.logger.debug("Connection info loaded from prior registration record") + with open(endpoint_json) as fp: + self.logger.debug( + "Connection info loaded from prior registration record" + ) reg_info = json.load(fp) - endpoint_uuid = reg_info['endpoint_id'] + endpoint_uuid = reg_info["endpoint_id"] elif not endpoint_uuid: endpoint_uuid = str(uuid.uuid4()) return endpoint_uuid @@ -151,18 +164,23 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): self.name = name endpoint_dir = os.path.join(self.funcx_dir, self.name) - endpoint_json = os.path.join(endpoint_dir, 'endpoint.json') + endpoint_json = os.path.join(endpoint_dir, "endpoint.json") # These certs need to be recreated for every registration - keys_dir = os.path.join(endpoint_dir, 'certificates') + keys_dir = os.path.join(endpoint_dir, "certificates") os.makedirs(keys_dir, exist_ok=True) - client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "endpoint") + client_public_file, client_secret_file = zmq.auth.create_certificates( + keys_dir, "endpoint" + ) client_public_key, _ = zmq.auth.load_certificate(client_public_file) - client_public_key = client_public_key.decode('utf-8') + client_public_key = client_public_key.decode("utf-8") # This is to ensure that at least 1 executor is defined if not endpoint_config.config.executors: - raise Exception(f"Endpoint config file at {endpoint_dir} is missing executor definitions") + raise Exception( + f"Endpoint config file at {endpoint_dir} is missing " + "executor definitions" + ) funcx_client_options = { "funcx_service_address": endpoint_config.config.funcx_service_address, @@ -174,17 +192,20 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): self.logger.info(f"Starting endpoint with uuid: {endpoint_uuid}") - pid_file = os.path.join(endpoint_dir, 'daemon.pid') + pid_file = os.path.join(endpoint_dir, "daemon.pid") pid_check = self.check_pidfile(pid_file) # if the pidfile exists, we should return early because we don't # want to attempt to create a new daemon when one is already # potentially running with the existing pidfile - if pid_check['exists']: - if pid_check['active']: + if pid_check["exists"]: + if pid_check["active"]: self.logger.info("Endpoint is already active") sys.exit(-1) else: - self.logger.info("A prior Endpoint instance appears to have been terminated without proper cleanup. 
Cleaning up now.") + self.logger.info( + "A prior Endpoint instance appears to have been terminated without " + "proper cleanup. Cleaning up now." + ) self.pidfile_cleanup(pid_file) results_ack_handler = ResultsAckHandler(endpoint_dir=endpoint_dir) @@ -193,64 +214,93 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): results_ack_handler.load() results_ack_handler.persist() except Exception: - self.logger.exception("Caught exception while attempting load and persist of outstanding results") + self.logger.exception( + "Caught exception while attempting load and persist of outstanding " + "results" + ) sys.exit(-1) # Create a daemon context # If we are running a full detached daemon then we will send the output to # log files, otherwise we can piggy back on our stdout if endpoint_config.config.detach_endpoint: - stdout = open(os.path.join(endpoint_dir, endpoint_config.config.stdout), 'a+') - stderr = open(os.path.join(endpoint_dir, endpoint_config.config.stderr), 'a+') + stdout = open( + os.path.join(endpoint_dir, endpoint_config.config.stdout), "a+" + ) + stderr = open( + os.path.join(endpoint_dir, endpoint_config.config.stderr), "a+" + ) else: stdout = sys.stdout stderr = sys.stderr try: - context = daemon.DaemonContext(working_directory=endpoint_dir, - umask=0o002, - pidfile=daemon.pidfile.PIDLockFile(pid_file), - stdout=stdout, - stderr=stderr, - detach_process=endpoint_config.config.detach_endpoint) + context = daemon.DaemonContext( + working_directory=endpoint_dir, + umask=0o002, + pidfile=daemon.pidfile.PIDLockFile(pid_file), + stdout=stdout, + stderr=stderr, + detach_process=endpoint_config.config.detach_endpoint, + ) except Exception: - self.logger.exception("Caught exception while trying to setup endpoint context dirs") + self.logger.exception( + "Caught exception while trying to setup endpoint context dirs" + ) sys.exit(-1) # place registration after everything else so that the endpoint will # only be registered if everything else has been set up successfully reg_info = None try: - reg_info = register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, self.name, logger=self.logger) + reg_info = register_endpoint( + funcx_client, endpoint_uuid, endpoint_dir, self.name, logger=self.logger + ) # if the service sends back an error response, it will be a FuncxResponseError except FuncxResponseError as e: # an example of an error that could conceivably occur here would be # if the service could not register this endpoint with the forwarder # because the forwarder was unreachable if e.http_status_code >= 500: - self.logger.exception("Caught exception while attempting endpoint registration") - self.logger.critical("Endpoint registration will be retried in the new endpoint daemon " - "process. The endpoint will not work until it is successfully registered.") + self.logger.exception( + "Caught exception while attempting endpoint registration" + ) + self.logger.critical( + "Endpoint registration will be retried in the new endpoint daemon " + "process. The endpoint will not work until it is successfully " + "registered." + ) else: raise e # if the service has an unexpected internal error and is unable to send # back a FuncxResponseError except GlobusAPIError as e: if e.http_status >= 500: - self.logger.exception("Caught exception while attempting endpoint registration") - self.logger.critical("Endpoint registration will be retried in the new endpoint daemon " - "process. 
The endpoint will not work until it is successfully registered.") + self.logger.exception( + "Caught exception while attempting endpoint registration" + ) + self.logger.critical( + "Endpoint registration will be retried in the new endpoint daemon " + "process. The endpoint will not work until it is successfully " + "registered." + ) else: raise e # if the service is unreachable due to a timeout or connection error except NetworkError as e: # the output of a NetworkError exception is huge and unhelpful, so # it seems better to just stringify it here and get a concise error - self.logger.exception(f"Caught exception while attempting endpoint registration: {e}") - self.logger.critical("funcx-endpoint is unable to reach the funcX service due to a NetworkError \n" - "Please make sure that the funcX service address you provided is reachable \n" - "and then attempt restarting the endpoint") + self.logger.exception( + f"Caught exception while attempting endpoint registration: {e}" + ) + self.logger.critical( + "funcx-endpoint is unable to reach the funcX service due to a " + "NetworkError \n" + "Please make sure that the funcX service address you provided is " + "reachable \n" + "and then attempt restarting the endpoint" + ) exit(-1) except Exception: raise @@ -258,31 +308,52 @@ def start_endpoint(self, name, endpoint_uuid, endpoint_config): if reg_info: self.logger.info("Launching endpoint daemon process") else: - self.logger.critical("Launching endpoint daemon process with errors noted above") + self.logger.critical( + "Launching endpoint daemon process with errors noted above" + ) with context: - self.daemon_launch(endpoint_uuid, endpoint_dir, keys_dir, endpoint_config, reg_info, funcx_client_options, results_ack_handler) + self.daemon_launch( + endpoint_uuid, + endpoint_dir, + keys_dir, + endpoint_config, + reg_info, + funcx_client_options, + results_ack_handler, + ) - def daemon_launch(self, endpoint_uuid, endpoint_dir, keys_dir, endpoint_config, reg_info, funcx_client_options, results_ack_handler): + def daemon_launch( + self, + endpoint_uuid, + endpoint_dir, + keys_dir, + endpoint_config, + reg_info, + funcx_client_options, + results_ack_handler, + ): # Configure the parameters for the interchange optionals = {} - if 'endpoint_address' in self.funcx_config: - optionals['interchange_address'] = self.funcx_config['endpoint_address'] + if "endpoint_address" in self.funcx_config: + optionals["interchange_address"] = self.funcx_config["endpoint_address"] - optionals['logdir'] = endpoint_dir + optionals["logdir"] = endpoint_dir if self.debug: - optionals['logging_level'] = logging.DEBUG - - ic = EndpointInterchange(endpoint_config.config, - endpoint_id=endpoint_uuid, - keys_dir=keys_dir, - endpoint_dir=endpoint_dir, - endpoint_name=self.name, - reg_info=reg_info, - funcx_client_options=funcx_client_options, - results_ack_handler=results_ack_handler, - **optionals) + optionals["logging_level"] = logging.DEBUG + + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id=endpoint_uuid, + keys_dir=keys_dir, + endpoint_dir=endpoint_dir, + endpoint_name=self.name, + reg_info=reg_info, + funcx_client_options=funcx_client_options, + results_ack_handler=results_ack_handler, + **optionals, + ) ic.start() @@ -294,15 +365,16 @@ def stop_endpoint(self, name): pid_file = os.path.join(endpoint_dir, "daemon.pid") pid_check = self.check_pidfile(pid_file) - # The process is active if the PID file exists and the process it points to is a funcx-endpoint - if pid_check['active']: + # The process is active if 
the PID file exists and the process it points to is + # a funcx-endpoint + if pid_check["active"]: self.logger.debug(f"{self.name} has a daemon.pid file") pid = None - with open(pid_file, 'r') as f: + with open(pid_file) as f: pid = int(f.read()) # Attempt terminating try: - self.logger.debug("Signalling process: {}".format(pid)) + self.logger.debug(f"Signalling process: {pid}") # For all the processes, including the deamon and its child process tree # Send SIGTERM to the processes # Wait for 200ms @@ -322,25 +394,25 @@ def stop_endpoint(self, name): pass # Wait to confirm that the pid file disappears if not os.path.exists(pid_file): - self.logger.info("Endpoint <{}> is now stopped".format(self.name)) + self.logger.info(f"Endpoint <{self.name}> is now stopped") except OSError: - self.logger.warning("Endpoint <{}> could not be terminated".format(self.name)) - self.logger.warning("Attempting Endpoint <{}> cleanup".format(self.name)) + self.logger.warning(f"Endpoint <{self.name}> could not be terminated") + self.logger.warning(f"Attempting Endpoint <{self.name}> cleanup") os.remove(pid_file) sys.exit(-1) # The process is not active, but the PID file exists and needs to be deleted - elif pid_check['exists']: + elif pid_check["exists"]: self.pidfile_cleanup(pid_file) else: - self.logger.info("Endpoint <{}> is not active.".format(self.name)) + self.logger.info(f"Endpoint <{self.name}> is not active.") def delete_endpoint(self, name): self.name = name endpoint_dir = os.path.join(self.funcx_dir, self.name) if not os.path.exists(endpoint_dir): - self.logger.warning("Endpoint <{}> does not exist".format(self.name)) + self.logger.warning(f"Endpoint <{self.name}> does not exist") sys.exit(-1) # stopping the endpoint should handle all of the process cleanup before @@ -348,10 +420,10 @@ def delete_endpoint(self, name): self.stop_endpoint(self.name) shutil.rmtree(endpoint_dir) - self.logger.info("Endpoint <{}> has been deleted.".format(self.name)) + self.logger.info(f"Endpoint <{self.name}> has been deleted.") def check_pidfile(self, filepath): - """ Helper function to identify possible dead endpoints + """Helper function to identify possible dead endpoints Returns a record with 'exists' and 'active' fields indicating whether the pidfile exists, and whether the process is active if it does exist @@ -363,12 +435,9 @@ def check_pidfile(self, filepath): Path to the pidfile """ if not os.path.exists(filepath): - return { - "exists": False, - "active": False - } + return {"exists": False, "active": False} - pid = int(open(filepath, 'r').read().strip()) + pid = int(open(filepath).read().strip()) active = False try: @@ -380,40 +449,37 @@ def check_pidfile(self, filepath): # it means the endpoint has been terminated without proper cleanup active = True - return { - "exists": True, - "active": active - } + return {"exists": True, "active": active} def pidfile_cleanup(self, filepath): os.remove(filepath) - self.logger.info("Endpoint <{}> has been cleaned up.".format(self.name)) + self.logger.info(f"Endpoint <{self.name}> has been cleaned up.") def list_endpoints(self): table = texttable.Texttable() - headings = ['Endpoint Name', 'Status', 'Endpoint ID'] + headings = ["Endpoint Name", "Status", "Endpoint ID"] table.header(headings) - config_files = glob.glob(os.path.join(self.funcx_dir, '*', 'config.py')) + config_files = glob.glob(os.path.join(self.funcx_dir, "*", "config.py")) for config_file in config_files: endpoint_dir = os.path.dirname(config_file) endpoint_name = os.path.basename(endpoint_dir) - status = 
'Initialized' + status = "Initialized" endpoint_id = None - endpoint_json = os.path.join(endpoint_dir, 'endpoint.json') + endpoint_json = os.path.join(endpoint_dir, "endpoint.json") if os.path.exists(endpoint_json): - with open(endpoint_json, 'r') as f: + with open(endpoint_json) as f: endpoint_info = json.load(f) - endpoint_id = endpoint_info['endpoint_id'] - pid_check = self.check_pidfile(os.path.join(endpoint_dir, 'daemon.pid')) - if pid_check['active']: - status = 'Running' - elif pid_check['exists']: - status = 'Disconnected' + endpoint_id = endpoint_info["endpoint_id"] + pid_check = self.check_pidfile(os.path.join(endpoint_dir, "daemon.pid")) + if pid_check["active"]: + status = "Running" + elif pid_check["exists"]: + status = "Disconnected" else: - status = 'Stopped' + status = "Stopped" table.add_row([endpoint_name, status, endpoint_id]) diff --git a/funcx_endpoint/funcx_endpoint/endpoint/interchange.py b/funcx_endpoint/funcx_endpoint/endpoint/interchange.py index dbb24cb51..ff97ee8f3 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/interchange.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/interchange.py @@ -1,38 +1,39 @@ #!/usr/bin/env python import argparse -from typing import Tuple, Dict - -import zmq +import logging import os -import sys -import platform -import random -import time import pickle -import logging +import platform import queue -import threading -import json -import daemon -import collections -from retry.api import retry_call import signal -from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue +import sys +import threading +import time +from queue import Queue +from typing import Tuple +import zmq from parsl.executors.errors import ScalingFailed from parsl.version import VERSION as PARSL_VERSION +from retry.api import retry_call -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task -from funcx_endpoint.executors.high_throughput.messages import EPStatusReport, Heartbeat, TaskStatusCode, ResultsAck -from funcx.sdk.client import FuncXClient -from funcx import set_file_logger from funcx import __version__ as funcx_sdk_version -from funcx_endpoint import __version__ as funcx_endpoint_version -from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import naive_interchange_task_dispatch +from funcx import set_file_logger +from funcx.sdk.client import FuncXClient from funcx.serialize import FuncXSerializer -from funcx_endpoint.endpoint.taskqueue import TaskQueue +from funcx_endpoint import __version__ as funcx_endpoint_version from funcx_endpoint.endpoint.register_endpoint import register_endpoint -from queue import Queue +from funcx_endpoint.endpoint.taskqueue import TaskQueue +from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue +from funcx_endpoint.executors.high_throughput.messages import ( + COMMAND_TYPES, + Heartbeat, + Message, + MessageType, + ResultsAck, + Task, + TaskStatusCode, +) LOOP_SLOWDOWN = 0.0 # in seconds HEARTBEAT_CODE = (2 ** 32) - 1 @@ -40,18 +41,17 @@ class ShutdownRequest(Exception): - """ Exception raised when any async component receives a ShutdownRequest - """ + """Exception raised when any async component receives a ShutdownRequest""" def __init__(self): self.tstamp = time.time() def __repr__(self): - return "Shutdown request received at {}".format(self.tstamp) + return f"Shutdown request received at {self.tstamp}" class ManagerLost(Exception): - """ Task lost due to worker loss. 
Worker is considered lost when multiple heartbeats + """Task lost due to worker loss. Worker is considered lost when multiple heartbeats have been missed. """ @@ -60,12 +60,11 @@ def __init__(self, worker_id): self.tstamp = time.time() def __repr__(self): - return "Task failure due to loss of worker {}".format(self.worker_id) + return f"Task failure due to loss of worker {self.worker_id}" class BadRegistration(Exception): - ''' A new Manager tried to join the executor with a BadRegistration message - ''' + """A new Manager tried to join the executor with a BadRegistration message""" def __init__(self, worker_id, critical=False): self.worker_id = worker_id @@ -73,12 +72,11 @@ def __init__(self, worker_id, critical=False): self.handled = "critical" if critical else "suppressed" def __repr__(self): - return "Manager:{} caused a {} failure".format(self.worker_id, - self.handled) + return f"Manager:{self.worker_id} caused a {self.handled} failure" -class EndpointInterchange(object): - """ Interchange is a task orchestrator for distributed systems. +class EndpointInterchange: + """Interchange is a task orchestrator for distributed systems. 1. Asynchronously queue large volume of tasks (>100K) 2. Allow for workers to join and leave the union @@ -90,23 +88,24 @@ class EndpointInterchange(object): TODO: We most likely need a PUB channel to send out global commandzs, like shutdown """ - def __init__(self, - config, - client_address="127.0.0.1", - interchange_address="127.0.0.1", - client_ports: Tuple[int, int, int] = (50055, 50056, 50057), - launch_cmd=None, - logdir=".", - logging_level=logging.INFO, - endpoint_id=None, - keys_dir=".curve", - suppress_failure=True, - endpoint_dir=".", - endpoint_name="default", - reg_info=None, - funcx_client_options=None, - results_ack_handler=None, - ): + def __init__( + self, + config, + client_address="127.0.0.1", + interchange_address="127.0.0.1", + client_ports: Tuple[int, int, int] = (50055, 50056, 50057), + launch_cmd=None, + logdir=".", + logging_level=logging.INFO, + endpoint_id=None, + keys_dir=".curve", + suppress_failure=True, + endpoint_dir=".", + endpoint_name="default", + reg_info=None, + funcx_client_options=None, + results_ack_handler=None, + ): """ Parameters ---------- @@ -114,10 +113,12 @@ def __init__(self, Funcx config object that describes how compute should be provisioned client_address : str - The ip address at which the parsl client can be reached. Default: "127.0.0.1" + The ip address at which the parsl client can be reached. + Default: "127.0.0.1" interchange_address : str - The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" + The ip address at which the workers will be able to reach the Interchange. + Default: "127.0.0.1" client_ports : Tuple[int, int, int] The ports at which the client can be reached @@ -132,14 +133,15 @@ def __init__(self, Logging level as defined in the logging module. Default: logging.INFO (20) keys_dir : str - Directory from where keys used for communicating with the funcX service (forwarders) - are stored + Directory from where keys used for communicating with the funcX + service (forwarders) are stored endpoint_id : str Identity string that identifies the endpoint to the broker suppress_failure : Bool - When set to True, the interchange will attempt to suppress failures. Default: False + When set to True, the interchange will attempt to suppress failures. 
+ Default: False endpoint_dir : str Endpoint directory path to store registration info in @@ -148,7 +150,8 @@ def __init__(self, Name of endpoint reg_info : Dict - Registration info from initial registration on endpoint start, if it succeeded + Registration info from initial registration on endpoint start, if it + succeeded funcx_client_options : Dict FuncXClient initialization options @@ -161,10 +164,18 @@ def __init__(self, global logger - logger = set_file_logger(os.path.join(self.logdir, "endpoint.log"), name="funcx_endpoint", level=logging_level) - logger.info("Initializing EndpointInterchange process with Endpoint ID: {}".format(endpoint_id)) + logger = set_file_logger( + os.path.join(self.logdir, "endpoint.log"), + name="funcx_endpoint", + level=logging_level, + ) + logger.info( + "Initializing EndpointInterchange process with Endpoint ID: {}".format( + endpoint_id + ) + ) self.config = config - logger.info("Got config : {}".format(config)) + logger.info(f"Got config : {config}") self.client_address = client_address self.interchange_address = interchange_address @@ -199,24 +210,26 @@ def __init__(self, self.results_ack_handler = results_ack_handler - logger.info("Interchange address is {}".format(self.interchange_address)) + logger.info(f"Interchange address is {self.interchange_address}") self.endpoint_id = endpoint_id - self.current_platform = {'parsl_v': PARSL_VERSION, - 'python_v': "{}.{}.{}".format(sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro), - 'libzmq_v': zmq.zmq_version(), - 'pyzmq_v': zmq.pyzmq_version(), - 'os': platform.system(), - 'hname': platform.node(), - 'funcx_sdk_version': funcx_sdk_version, - 'funcx_endpoint_version': funcx_endpoint_version, - 'registration': self.endpoint_id, - 'dir': os.getcwd()} - - logger.info("Platform info: {}".format(self.current_platform)) + self.current_platform = { + "parsl_v": PARSL_VERSION, + "python_v": "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "libzmq_v": zmq.zmq_version(), + "pyzmq_v": zmq.pyzmq_version(), + "os": platform.system(), + "hname": platform.node(), + "funcx_sdk_version": funcx_sdk_version, + "funcx_endpoint_version": funcx_endpoint_version, + "registration": self.endpoint_id, + "dir": os.getcwd(), + } + + logger.info(f"Platform info: {self.current_platform}") try: self.load_config() except Exception: @@ -229,8 +242,7 @@ def __init__(self, self._test_start = False def load_config(self): - """ Load the config - """ + """Load the config""" logger.info("Loading endpoint local config") self.results_passthrough = mpQueue() @@ -242,7 +254,7 @@ def load_config(self): executor.endpoint_id = self.endpoint_id else: if not executor.endpoint_id == self.endpoint_id: - raise Exception('InconsistentEndpointId') + raise Exception("InconsistentEndpointId") self.executors[executor.label] = executor if executor.run_dir is None: executor.run_dir = self.logdir @@ -250,15 +262,21 @@ def load_config(self): def start_executors(self): logger.info("Starting Executors") for executor in self.config.executors: - if hasattr(executor, 'passthrough') and executor.passthrough is True: + if hasattr(executor, "passthrough") and executor.passthrough is True: executor.start(results_passthrough=self.results_passthrough) def apply_reg_info(self, reg_info): - self.client_address = reg_info['public_ip'] - self.client_ports = reg_info['tasks_port'], reg_info['results_port'], reg_info['commands_port'], + self.client_address = reg_info["public_ip"] + self.client_ports = ( + 
reg_info["tasks_port"], + reg_info["results_port"], + reg_info["commands_port"], + ) def register_endpoint(self): - reg_info = register_endpoint(self.funcx_client, self.endpoint_id, self.endpoint_dir, self.endpoint_name) + reg_info = register_endpoint( + self.funcx_client, self.endpoint_id, self.endpoint_dir, self.endpoint_name + ) self.apply_reg_info(reg_info) return reg_info @@ -286,17 +304,21 @@ def _task_puller_loop(self, quiesce_event): task_counter = 0 # Create the incoming queue in the thread to keep # zmq.context in the same thread. zmq.context is not thread-safe - self.task_incoming = TaskQueue(self.client_address, - port=self.client_ports[0], - identity=self.endpoint_id, - mode='client', - set_hwm=True, - keys_dir=self.keys_dir, - RCVTIMEO=1000, - linger=0) - - self.task_incoming.put('forwarder', pickle.dumps(self.current_platform)) - logger.info(f"Task incoming on tcp://{self.client_address}:{self.client_ports[0]}") + self.task_incoming = TaskQueue( + self.client_address, + port=self.client_ports[0], + identity=self.endpoint_id, + mode="client", + set_hwm=True, + keys_dir=self.keys_dir, + RCVTIMEO=1000, + linger=0, + ) + + self.task_incoming.put("forwarder", pickle.dumps(self.current_platform)) + logger.info( + f"Task incoming on tcp://{self.client_address}:{self.client_ports[0]}" + ) self.last_heartbeat = time.time() @@ -304,7 +326,10 @@ def _task_puller_loop(self, quiesce_event): try: if int(time.time() - self.last_heartbeat) > self.heartbeat_threshold: - logger.critical("[TASK_PULL_THREAD] Missed too many heartbeats. Setting quiesce event.") + logger.critical( + "[TASK_PULL_THREAD] Missed too many heartbeats. " + "Setting quiesce event." + ) quiesce_event.set() break @@ -314,19 +339,27 @@ def _task_puller_loop(self, quiesce_event): self.last_heartbeat = time.time() except zmq.Again: # We just timed out while attempting to receive - logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count)) + logger.debug( + "[TASK_PULL_THREAD] {} tasks in internal queue".format( + self.total_pending_task_count + ) + ) continue except Exception: - logger.exception("[TASK_PULL_THREAD] Unknown exception while waiting for tasks") + logger.exception( + "[TASK_PULL_THREAD] Unknown exception while waiting for tasks" + ) # YADU: TODO We need to do the routing here try: msg = Message.unpack(raw_msg) except Exception: - logger.exception("[TASK_PULL_THREAD] Failed to unpack message from forwarder") + logger.exception( + "[TASK_PULL_THREAD] Failed to unpack message from forwarder" + ) pass - if msg == 'STOP': + if msg == "STOP": self._kill_event.set() quiesce_event.set() break @@ -338,40 +371,53 @@ def _task_puller_loop(self, quiesce_event): logger.info(f"[TASK_PULL_THREAD] Received task:{msg.task_id}") self.pending_task_queue.put(msg) self.total_pending_task_count += 1 - self.task_status_deltas[msg.task_id] = TaskStatusCode.WAITING_FOR_NODES + self.task_status_deltas[ + msg.task_id + ] = TaskStatusCode.WAITING_FOR_NODES task_counter += 1 - logger.debug(f"[TASK_PULL_THREAD] Task counter:{task_counter} Pending Tasks: {self.total_pending_task_count}") + logger.debug( + "[TASK_PULL_THREAD] Task counter:%s Pending Tasks: %s", + task_counter, + self.total_pending_task_count, + ) elif isinstance(msg, ResultsAck): self.results_ack_handler.ack(msg.task_id) else: - logger.warning(f"[TASK_PULL_THREAD] Unknown message type received: {msg}") + logger.warning( + f"[TASK_PULL_THREAD] Unknown message type received: {msg}" + ) except Exception: 
logger.exception("[TASK_PULL_THREAD] Something really bad happened") continue def get_container(self, container_uuid): - """ Get the container image location if it is not known to the interchange""" + """Get the container image location if it is not known to the interchange""" if container_uuid not in self.containers: - if container_uuid == 'RAW' or not container_uuid: - self.containers[container_uuid] = 'RAW' + if container_uuid == "RAW" or not container_uuid: + self.containers[container_uuid] = "RAW" else: try: - container = self.funcx_client.get_container(container_uuid, self.config.container_type) + container = self.funcx_client.get_container( + container_uuid, self.config.container_type + ) except Exception: - logger.exception("[FETCH_CONTAINER] Unable to resolve container location") - self.containers[container_uuid] = 'RAW' + logger.exception( + "[FETCH_CONTAINER] Unable to resolve container location" + ) + self.containers[container_uuid] = "RAW" else: - logger.info("[FETCH_CONTAINER] Got container info: {}".format(container)) - self.containers[container_uuid] = container.get('location', 'RAW') + logger.info(f"[FETCH_CONTAINER] Got container info: {container}") + self.containers[container_uuid] = container.get("location", "RAW") return self.containers[container_uuid] def _command_server(self, quiesce_event): - """ Command server to run async command to the interchange + """Command server to run async command to the interchange - We want to be able to receive the following not yet implemented/updated commands: + We want to be able to receive the following not yet implemented/updated + commands: - OutstandingCount - ListManagers (get outstanding broken down by manager) - HoldWorker @@ -389,17 +435,21 @@ def _command_server(self, quiesce_event): logger.info("[COMMAND] Thread loop exiting") def _command_server_loop(self, quiesce_event): - self.command_channel = TaskQueue(self.client_address, - port=self.client_ports[2], - identity=self.endpoint_id, - mode='client', - RCVTIMEO=1000, # in milliseconds - keys_dir=self.keys_dir, - set_hwm=True, - linger=0) + self.command_channel = TaskQueue( + self.client_address, + port=self.client_ports[2], + identity=self.endpoint_id, + mode="client", + RCVTIMEO=1000, # in milliseconds + keys_dir=self.keys_dir, + set_hwm=True, + linger=0, + ) # TODO :Register all channels with the authentication string. - self.command_channel.put('forwarder', pickle.dumps({"registration": self.endpoint_id})) + self.command_channel.put( + "forwarder", pickle.dumps({"registration": self.endpoint_id}) + ) while not quiesce_event.is_set(): try: @@ -414,10 +464,12 @@ def _command_server_loop(self, quiesce_event): if command.type is MessageType.HEARTBEAT_REQ: logger.info("[COMMAND] Received synchonous HEARTBEAT_REQ from hub") - logger.info(f"[COMMAND] Replying with Heartbeat({self.endpoint_id})") + logger.info( + f"[COMMAND] Replying with Heartbeat({self.endpoint_id})" + ) reply = Heartbeat(self.endpoint_id) - logger.debug("[COMMAND] Reply: {}".format(reply)) + logger.debug(f"[COMMAND] Reply: {reply}") self.command_channel.put(reply.pack()) except zmq.Again: @@ -428,7 +480,9 @@ def quiesce(self): """Temporarily stop everything on the interchange in order to reach a consistent state before attempting to start again. 
This must be called on the main thread """ - logger.info("Interchange Quiesce in progress (stopping and joining all threads)") + logger.info( + "Interchange Quiesce in progress (stopping and joining all threads)" + ) self._quiesce_event.set() self._task_puller_thread.join() self._command_thread.join() @@ -449,8 +503,8 @@ def stop(self): # TODO: shut down executors gracefully - # kill_event must be set before quiesce_event because we need to guarantee that once - # the quiesce is complete, the interchange will not try to start again + # kill_event must be set before quiesce_event because we need to guarantee that + # once the quiesce is complete, the interchange will not try to start again self._kill_event.set() self._quiesce_event.set() @@ -465,8 +519,7 @@ def handle_sigterm(self, sig_num, curr_stack_frame): sys.exit(1) def start(self): - """ Start the Interchange - """ + """Start the Interchange""" logger.info("Starting EndpointInterchange") signal.signal(signal.SIGTERM, self.handle_sigterm) @@ -493,20 +546,32 @@ def _start_threads_and_main(self): if not self.initial_registration_complete: # Register the endpoint logger.info("Running endpoint registration retry loop") - reg_info = retry_call(self.register_endpoint, delay=10, max_delay=300, backoff=1.2) - logger.info("Endpoint registered with UUID: {}".format(reg_info['endpoint_id'])) + reg_info = retry_call( + self.register_endpoint, delay=10, max_delay=300, backoff=1.2 + ) + logger.info( + "Endpoint registered with UUID: {}".format(reg_info["endpoint_id"]) + ) self.initial_registration_complete = False - logger.info("Attempting connection to client at {} on ports: {},{},{}".format( - self.client_address, self.client_ports[0], self.client_ports[1], self.client_ports[2])) + logger.info( + "Attempting connection to client at {} on ports: {},{},{}".format( + self.client_address, + self.client_ports[0], + self.client_ports[1], + self.client_ports[2], + ) + ) - self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal, - args=(self._quiesce_event, )) + self._task_puller_thread = threading.Thread( + target=self.migrate_tasks_to_internal, args=(self._quiesce_event,) + ) self._task_puller_thread.start() - self._command_thread = threading.Thread(target=self._command_server, - args=(self._quiesce_event, )) + self._command_thread = threading.Thread( + target=self._command_server, args=(self._quiesce_event,) + ) self._command_thread.start() try: @@ -518,26 +583,33 @@ def _start_threads_and_main(self): logger.info("[MAIN] Thread loop exiting") def _main_loop(self): - self.results_outgoing = TaskQueue(self.client_address, - port=self.client_ports[1], - identity=self.endpoint_id, - mode='client', - keys_dir=self.keys_dir, - # Fail immediately if results cannot be sent back - SNDTIMEO=0, - set_hwm=True, - linger=0) - self.results_outgoing.put('forwarder', pickle.dumps({"registration": self.endpoint_id})) + self.results_outgoing = TaskQueue( + self.client_address, + port=self.client_ports[1], + identity=self.endpoint_id, + mode="client", + keys_dir=self.keys_dir, + # Fail immediately if results cannot be sent back + SNDTIMEO=0, + set_hwm=True, + linger=0, + ) + self.results_outgoing.put( + "forwarder", pickle.dumps({"registration": self.endpoint_id}) + ) # TODO: this resend must happen after any endpoint re-registration to # ensure there are not unacked results left resend_results_messages = self.results_ack_handler.get_unacked_results_list() if len(resend_results_messages) > 0: - logger.info(f"[MAIN] Resending 
{len(resend_results_messages)} previously unacked results") + logger.info( + "[MAIN] Resending %s previously unacked results", + len(resend_results_messages), + ) # TODO: this should be a multipart send rather than a loop for results in resend_results_messages: - self.results_outgoing.put('forwarder', results) + self.results_outgoing.put("forwarder", results) executor = list(self.executors.values())[0] last = time.time() @@ -549,9 +621,12 @@ def _main_loop(self): try: # Adding results heartbeat to essentially force a TCP keepalive # without meddling with OS TCP keepalive defaults - self.results_outgoing.put('forwarder', b'HEARTBEAT') + self.results_outgoing.put("forwarder", b"HEARTBEAT") except Exception: - logger.exception("[MAIN] Sending heartbeat to the forwarder over the results channel has failed") + logger.exception( + "[MAIN] Sending heartbeat to the forwarder over the results " + "channel has failed" + ) raise self.results_ack_handler.check_ack_counts() @@ -562,7 +637,9 @@ def _main_loop(self): except queue.Empty: pass except Exception: - logger.exception("[MAIN] Unhandled issue while waiting for pending tasks") + logger.exception( + "[MAIN] Unhandled issue while waiting for pending tasks" + ) pass try: @@ -573,19 +650,22 @@ def _main_loop(self): self.results_ack_handler.put(task_id, results["message"]) logger.info(f"Passing result to forwarder for task {task_id}") - # results will be a pickled dict with task_id, container_id, and results/exception - self.results_outgoing.put('forwarder', results["message"]) + # results will be a pickled dict with task_id, container_id, + # and results/exception + self.results_outgoing.put("forwarder", results["message"]) except queue.Empty: pass except Exception: - logger.exception("[MAIN] Something broke while forwarding results from executor to forwarder queues") + logger.exception( + "[MAIN] Something broke while forwarding results from executor " + "to forwarder queues" + ) continue def get_status_report(self): - """ Get utilization numbers - """ + """Get utilization numbers""" total_cores = 0 total_mem = 0 core_hrs = 0 @@ -597,36 +677,43 @@ def get_status_report(self): live_workers = self.get_total_live_workers() for manager in self._ready_manager_queue: - total_cores += self._ready_manager_queue[manager]['cores'] - total_mem += self._ready_manager_queue[manager]['mem'] - active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time']) + total_cores += self._ready_manager_queue[manager]["cores"] + total_mem += self._ready_manager_queue[manager]["mem"] + active_dur = abs( + time.time() - self._ready_manager_queue[manager]["reg_time"] + ) core_hrs += (active_dur * total_cores) / 3600 - if self._ready_manager_queue[manager]['active']: + if self._ready_manager_queue[manager]["active"]: active_managers += 1 - free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers'] - - result_package = {'task_id': -2, - 'info': {'total_cores': total_cores, - 'total_mem': total_mem, - 'new_core_hrs': core_hrs - self.last_core_hr_counter, - 'total_core_hrs': round(core_hrs, 2), - 'managers': num_managers, - 'active_managers': active_managers, - 'total_workers': live_workers, - 'idle_workers': free_capacity, - 'pending_tasks': pending_tasks, - 'outstanding_tasks': outstanding_tasks, - 'worker_mode': self.config.worker_mode, - 'scheduler_mode': self.config.scheduler_mode, - 'scaling_enabled': self.config.scaling_enabled, - 'mem_per_worker': self.config.mem_per_worker, - 'cores_per_worker': self.config.cores_per_worker, - 
'prefetch_capacity': self.config.prefetch_capacity, - 'max_blocks': self.config.provider.max_blocks, - 'min_blocks': self.config.provider.min_blocks, - 'max_workers_per_node': self.config.max_workers_per_node, - 'nodes_per_block': self.config.provider.nodes_per_block - }} + free_capacity += self._ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ] + + result_package = { + "task_id": -2, + "info": { + "total_cores": total_cores, + "total_mem": total_mem, + "new_core_hrs": core_hrs - self.last_core_hr_counter, + "total_core_hrs": round(core_hrs, 2), + "managers": num_managers, + "active_managers": active_managers, + "total_workers": live_workers, + "idle_workers": free_capacity, + "pending_tasks": pending_tasks, + "outstanding_tasks": outstanding_tasks, + "worker_mode": self.config.worker_mode, + "scheduler_mode": self.config.scheduler_mode, + "scaling_enabled": self.config.scaling_enabled, + "mem_per_worker": self.config.mem_per_worker, + "cores_per_worker": self.config.cores_per_worker, + "prefetch_capacity": self.config.prefetch_capacity, + "max_blocks": self.config.provider.max_blocks, + "min_blocks": self.config.provider.min_blocks, + "max_workers_per_node": self.config.max_workers_per_node, + "nodes_per_block": self.config.provider.nodes_per_block, + }, + } self.last_core_hr_counter = core_hrs return result_package @@ -642,18 +729,28 @@ def scale_out(self, blocks=1, task_type=None): if self.config.provider: self._block_counter += 1 external_block_id = str(self._block_counter) - if not task_type and self.config.scheduler_mode == 'hard': - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW') + if not task_type and self.config.scheduler_mode == "hard": + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type="RAW" + ) else: - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type) + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type=task_type + ) if not task_type: internal_block = self.config.provider.submit(launch_cmd, 1) else: - internal_block = self.config.provider.submit(launch_cmd, 1, task_type) - logger.debug("Launched block {}->{}".format(external_block_id, internal_block)) + internal_block = self.config.provider.submit( + launch_cmd, 1, task_type + ) + logger.debug(f"Launched block {external_block_id}->{internal_block}") if not internal_block: - raise(ScalingFailed(self.provider.label, - "Attempts to provision nodes via provider has failed")) + raise ( + ScalingFailed( + self.provider.label, + "Attempts to provision nodes via provider has failed", + ) + ) self.blocks[external_block_id] = internal_block self.block_id_map[internal_block] = external_block_id else: @@ -675,14 +772,22 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): if block_ids is None: block_ids = [] if task_type: - logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type)) + logger.info( + "Scaling in blocks of specific task type %s. 
" + "Let the provider decide which to kill", + task_type, + ) if self.config.scaling_enabled and self.config.provider: to_kill, r = self.config.provider.cancel(blocks, task_type) - logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r)) + logger.info(f"Get the killed blocks: {to_kill}, and status: {r}") for job in to_kill: - logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job)) + logger.info( + "[scale_in] Getting the block_id map {} for job {}".format( + self.block_id_map, job + ) + ) block_id = self.block_id_map[job] - logger.info("[scale_in] Holding block {}".format(block_id)) + logger.info(f"[scale_in] Holding block {block_id}") self._hold_block(block_id) self.blocks.pop(block_id) return r @@ -706,13 +811,16 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): return r def provider_status(self): - """ Get status of all blocks from the provider - """ + """Get status of all blocks from the provider""" status = [] if self.config.provider: - logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values()))) + logger.debug( + "[MAIN] Getting the status of {} blocks.".format( + list(self.blocks.values()) + ) + ) status = self.config.provider.status(list(self.blocks.values())) - logger.debug("[MAIN] The status is {}".format(status)) + logger.debug(f"[MAIN] The status is {status}") return status @@ -720,7 +828,8 @@ def provider_status(self): def starter(comm_q, *args, **kwargs): """Start the interchange process - The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__ + The executor is expected to call this function. + The args, kwargs match that of the Interchange.__init__ """ # logger = multiprocessing.get_logger() ic = EndpointInterchange(*args, **kwargs) @@ -732,42 +841,61 @@ def starter(comm_q, *args, **kwargs): def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--client_address", required=True, - help="Client address") - parser.add_argument("--client_ports", required=True, - help="client ports as a triple of outgoing,incoming,command") - parser.add_argument("--worker_port_range", - help="Worker port range as a tuple") - parser.add_argument("-l", "--logdir", default="./parsl_worker_logs", - help="Parsl worker log directory") - parser.add_argument("--worker_ports", default=None, - help="OPTIONAL, pair of workers ports to listen on, eg --worker_ports=50001,50005") - parser.add_argument("--suppress_failure", action='store_true', - help="Enables suppression of failures") - parser.add_argument("--endpoint_id", required=True, - help="Endpoint ID, used to identify the endpoint to the remote broker") - parser.add_argument("--hb_threshold", - help="Heartbeat threshold in seconds") - parser.add_argument("--config", default=None, - help="Configuration object that describes provisioning") - parser.add_argument("-d", "--debug", action='store_true', - help="Enables debug logging") + parser.add_argument("-c", "--client_address", required=True, help="Client address") + parser.add_argument( + "--client_ports", + required=True, + help="client ports as a triple of outgoing,incoming,command", + ) + parser.add_argument("--worker_port_range", help="Worker port range as a tuple") + parser.add_argument( + "-l", + "--logdir", + default="./parsl_worker_logs", + help="Parsl worker log directory", + ) + parser.add_argument( + "--worker_ports", + default=None, + help="OPTIONAL, pair of workers ports to listen on, " + "e.g. 
--worker_ports=50001,50005", + ) + parser.add_argument( + "--suppress_failure", + action="store_true", + help="Enables suppression of failures", + ) + parser.add_argument( + "--endpoint_id", + required=True, + help="Endpoint ID, used to identify the endpoint to the remote broker", + ) + parser.add_argument("--hb_threshold", help="Heartbeat threshold in seconds") + parser.add_argument( + "--config", + default=None, + help="Configuration object that describes provisioning", + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="Enables debug logging" + ) print("Starting HTEX Intechange") args = parser.parse_args() optionals = {} - optionals['suppress_failure'] = args.suppress_failure - optionals['logdir'] = os.path.abspath(args.logdir) - optionals['client_address'] = args.client_address - optionals['client_ports'] = [int(i) for i in args.client_ports.split(',')] - optionals['endpoint_id'] = args.endpoint_id + optionals["suppress_failure"] = args.suppress_failure + optionals["logdir"] = os.path.abspath(args.logdir) + optionals["client_address"] = args.client_address + optionals["client_ports"] = [int(i) for i in args.client_ports.split(",")] + optionals["endpoint_id"] = args.endpoint_id # DEBUG ONLY : TODO: FIX if args.config is None: - from funcx_endpoint.endpoint.utils.config import Config from parsl.providers import LocalProvider + from funcx_endpoint.endpoint.utils.config import Config + config = Config( worker_debug=True, scaling_enabled=True, @@ -777,18 +905,20 @@ def cli_run(): max_blocks=1, ), max_workers_per_node=2, - funcx_service_address='http://127.0.0.1:8080' + funcx_service_address="http://127.0.0.1:8080", ) - optionals['config'] = config + optionals["config"] = config else: - optionals['config'] = args.config + optionals["config"] = args.config if args.debug: - optionals['logging_level'] = logging.DEBUG + optionals["logging_level"] = logging.DEBUG if args.worker_ports: - optionals['worker_ports'] = [int(i) for i in args.worker_ports.split(',')] + optionals["worker_ports"] = [int(i) for i in args.worker_ports.split(",")] if args.worker_port_range: - optionals['worker_port_range'] = [int(i) for i in args.worker_port_range.split(',')] + optionals["worker_port_range"] = [ + int(i) for i in args.worker_port_range.split(",") + ] ic = EndpointInterchange(**optionals) ic.start() diff --git a/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py b/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py index 1b36f8d34..e32348bce 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/register_endpoint.py @@ -1,13 +1,15 @@ -import os import json import logging +import os import funcx_endpoint namespace_logger = logging.getLogger(__name__) -def register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, endpoint_name, logger=None): +def register_endpoint( + funcx_client, endpoint_uuid, endpoint_dir, endpoint_name, logger=None +): """Register the endpoint and return the registration info. 
This function needs to be isolated (including the logger which is passed in) so that the function can both be called from the endpoint start process as well as the daemon process @@ -35,31 +37,37 @@ def register_endpoint(funcx_client, endpoint_uuid, endpoint_dir, endpoint_name, logger.debug("Attempting registration") logger.debug(f"Trying with eid : {endpoint_uuid}") - reg_info = funcx_client.register_endpoint(endpoint_name, - endpoint_uuid, - endpoint_version=funcx_endpoint.__version__) + reg_info = funcx_client.register_endpoint( + endpoint_name, endpoint_uuid, endpoint_version=funcx_endpoint.__version__ + ) # this is a backup error handler in case an endpoint ID is not sent back # from the service or a bad ID is sent back - if 'endpoint_id' not in reg_info: - raise Exception("Endpoint ID was not included in the service's registration response.") - elif not isinstance(reg_info['endpoint_id'], str): + if "endpoint_id" not in reg_info: + raise Exception( + "Endpoint ID was not included in the service's registration response." + ) + elif not isinstance(reg_info["endpoint_id"], str): raise Exception("Endpoint ID sent by the service was not a string.") # NOTE: While all registration info is saved to endpoint.json, only the # endpoint UUID is reused from this file. The latest forwarder URI is used # every time we fetch registration info and register - with open(os.path.join(endpoint_dir, 'endpoint.json'), 'w+') as fp: + with open(os.path.join(endpoint_dir, "endpoint.json"), "w+") as fp: json.dump(reg_info, fp) - logger.debug("Registration info written to {}".format(os.path.join(endpoint_dir, 'endpoint.json'))) + logger.debug( + "Registration info written to {}".format( + os.path.join(endpoint_dir, "endpoint.json") + ) + ) - certs_dir = os.path.join(endpoint_dir, 'certificates') + certs_dir = os.path.join(endpoint_dir, "certificates") os.makedirs(certs_dir, exist_ok=True) - server_keyfile = os.path.join(certs_dir, 'server.key') + server_keyfile = os.path.join(certs_dir, "server.key") logger.debug(f"Writing server key to {server_keyfile}") try: - with open(server_keyfile, 'w') as f: - f.write(reg_info['forwarder_pubkey']) + with open(server_keyfile, "w") as f: + f.write(reg_info["forwarder_pubkey"]) os.chmod(server_keyfile, 0o600) except Exception: logger.exception("Failed to write server certificate") diff --git a/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py b/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py index e222445e0..15abb92a8 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/results_ack.py @@ -1,23 +1,22 @@ import logging -import time import os import pickle +import time # The logger path needs to start with endpoint. while the current path # start with funcx_endpoint.endpoint. 
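A minimal, self-contained sketch of the persistence pattern used by register_endpoint() in the hunk above: the registration response is cached to endpoint.json, and the forwarder public key is written under certificates/ with 0o600 permissions. The helper name save_registration() and the sample reg_info values are illustrative only and are not part of this patch.

import json
import os


def save_registration(endpoint_dir, reg_info):
    # Cache the full registration response; only endpoint_id is reused from it later.
    with open(os.path.join(endpoint_dir, "endpoint.json"), "w+") as fp:
        json.dump(reg_info, fp)

    # Store the forwarder public key next to the other certificates and restrict
    # permissions, since the key participates in connection authentication.
    certs_dir = os.path.join(endpoint_dir, "certificates")
    os.makedirs(certs_dir, exist_ok=True)
    keyfile = os.path.join(certs_dir, "server.key")
    with open(keyfile, "w") as f:
        f.write(reg_info["forwarder_pubkey"])
    os.chmod(keyfile, 0o600)


if __name__ == "__main__":
    # Hypothetical values for demonstration only.
    save_registration(".", {"endpoint_id": "example-id", "forwarder_pubkey": "PUBKEY"})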
logger = logging.getLogger("endpoint.results_ack") -class ResultsAckHandler(): +class ResultsAckHandler: """ Tracks task results by task ID, discarding results after they have been ack'ed """ def __init__(self, endpoint_dir): - """ Initialize results storage and timing for log updates - """ + """Initialize results storage and timing for log updates""" self.endpoint_dir = endpoint_dir - self.data_path = os.path.join(self.endpoint_dir, 'unacked_results.p') + self.data_path = os.path.join(self.endpoint_dir, "unacked_results.p") self.unacked_results = {} # how frequently to log info about acked and unacked results @@ -26,7 +25,7 @@ def __init__(self, endpoint_dir): self.acked_count = 0 def put(self, task_id, message): - """ Put sent task result into Unacked Dict + """Put sent task result into Unacked Dict Parameters ---------- @@ -39,7 +38,7 @@ def put(self, task_id, message): self.unacked_results[task_id] = message def ack(self, task_id): - """ Ack a task result that was sent. Nothing happens if the task ID is not + """Ack a task result that was sent. Nothing happens if the task ID is not present in the Unacked Dict Parameters @@ -54,18 +53,22 @@ def ack(self, task_id): logger.debug(f"Acked task {task_id}, Unacked count: {unacked_count}") def check_ack_counts(self): - """ Log the number of currently Unacked tasks and the tasks Acked since + """Log the number of currently Unacked tasks and the tasks Acked since the last check """ now = time.time() if now - self.last_log_timestamp > self.log_period: unacked_count = len(self.unacked_results) - logger.info(f"Unacked count: {unacked_count}, Acked results since last check {self.acked_count}") + logger.info( + "Unacked count: %s, Acked results since last check %s", + unacked_count, + self.acked_count, + ) self.acked_count = 0 self.last_log_timestamp = now def get_unacked_results_list(self): - """ Get a list of unacked results messages that can be used for resending + """Get a list of unacked results messages that can be used for resending Returns ------- @@ -75,17 +78,19 @@ def get_unacked_results_list(self): return list(self.unacked_results.values()) def persist(self): - """ Save unacked results to disk - """ - with open(self.data_path, 'wb') as fp: + """Save unacked results to disk""" + with open(self.data_path, "wb") as fp: pickle.dump(self.unacked_results, fp) def load(self): - """ Load unacked results from disk - """ + """Load unacked results from disk""" try: if os.path.exists(self.data_path): - with open(self.data_path, 'rb') as fp: + with open(self.data_path, "rb") as fp: self.unacked_results = pickle.load(fp) except pickle.UnpicklingError: - logger.warning(f"Cached results {self.data_path} appear to be corrupt. Proceeding without loading cached results") + logger.warning( + "Cached results %s appear to be corrupt. 
" + "Proceeding without loading cached results", + self.data_path, + ) diff --git a/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py b/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py index ab1ab69c2..29195239b 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/taskqueue.py @@ -1,30 +1,31 @@ +import logging import os +import uuid + import zmq import zmq.auth from zmq.auth.thread import ThreadAuthenticator -import uuid -import logging -import time logger = logging.getLogger(__name__) -class TaskQueue(object): - """ Outgoing task queue from the executor to the Interchange - """ - - def __init__(self, - address: str, - port: int = 55001, - identity: str = str(uuid.uuid4()), - zmq_context=None, - set_hwm=False, - RCVTIMEO=None, - SNDTIMEO=None, - linger=None, - ironhouse: bool = False, - keys_dir: str = os.path.abspath('.curve'), - mode: str = 'client'): +class TaskQueue: + """Outgoing task queue from the executor to the Interchange""" + + def __init__( + self, + address: str, + port: int = 55001, + identity: str = str(uuid.uuid4()), + zmq_context=None, + set_hwm=False, + RCVTIMEO=None, + SNDTIMEO=None, + linger=None, + ironhouse: bool = False, + keys_dir: str = os.path.abspath(".curve"), + mode: str = "client", + ): """ Parameters ---------- @@ -37,7 +38,8 @@ def __init__(self, identity : str Applies only to clients, where the identity must match the endpoint uuid. - This will be utf-8 encoded on the wire. A random uuid4 string is set by default. + This will be utf-8 encoded on the wire. A random uuid4 string is set by + default. mode: string Either 'client' or 'server' @@ -59,21 +61,26 @@ def __init__(self, self.ironhouse = ironhouse self.keys_dir = keys_dir - assert self.mode in ['client', 'server'], "Only two modes are supported: client, server" + assert self.mode in [ + "client", + "server", + ], "Only two modes are supported: client, server" - if self.mode == 'server': + if self.mode == "server": print("Configuring server") self.zmq_socket = self.context.socket(zmq.ROUTER) self.zmq_socket.set(zmq.ROUTER_MANDATORY, 1) self.zmq_socket.set(zmq.ROUTER_HANDOVER, 1) print("Setting up auth-server") self.setup_server_auth() - elif self.mode == 'client': + elif self.mode == "client": self.zmq_socket = self.context.socket(zmq.DEALER) self.setup_client_auth() - self.zmq_socket.setsockopt(zmq.IDENTITY, identity.encode('utf-8')) + self.zmq_socket.setsockopt(zmq.IDENTITY, identity.encode("utf-8")) else: - raise ValueError("TaskQueue must be initialized with mode set to 'server' or 'client'") + raise ValueError( + "TaskQueue must be initialized with mode set to 'server' or 'client'" + ) if set_hwm: self.zmq_socket.set_hwm(0) @@ -85,10 +92,10 @@ def __init__(self, self.zmq_socket.setsockopt(zmq.LINGER, linger) # all zmq setsockopt calls must be done before bind/connect is called - if self.mode == 'server': - self.zmq_socket.bind("tcp://*:{}".format(port)) - elif self.mode == 'client': - self.zmq_socket.connect("tcp://{}:{}".format(address, port)) + if self.mode == "server": + self.zmq_socket.bind(f"tcp://*:{port}") + elif self.mode == "client": + self.zmq_socket.connect(f"tcp://{address}:{port}") self.poller = zmq.Poller() self.poller.register(self.zmq_socket) @@ -102,10 +109,10 @@ def add_client_key(self, endpoint_id, client_key): logger.info("Adding client key") if self.ironhouse: # Use the ironhouse ZMQ pattern: http://hintjens.com/blog:49#toc6 - with open(os.path.join(self.keys_dir, f'{endpoint_id}.key'), 'w') as f: + with 
open(os.path.join(self.keys_dir, f"{endpoint_id}.key"), "w") as f: f.write(client_key) try: - self.auth.configure_curve(domain='*', location=self.keys_dir) + self.auth.configure_curve(domain="*", location=self.keys_dir) except Exception: logger.exception("Failed to load keys from {self.keys_dir}") return @@ -114,12 +121,12 @@ def setup_server_auth(self): # Start an authenticator for this context. self.auth = ThreadAuthenticator(self.context) self.auth.start() - self.auth.allow('127.0.0.1') + self.auth.allow("127.0.0.1") # Tell the authenticator how to handle CURVE requests if not self.ironhouse: # Use the stonehouse ZMQ pattern: http://hintjens.com/blog:49#toc5 - self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY) server_secret_file = os.path.join(self.keys_dir, "server.key_secret") server_public, server_secret = zmq.auth.load_certificate(server_secret_file) @@ -166,7 +173,7 @@ def register_client(self, message): return self.zmq_socket.send_multipart([message]) def put(self, dest, message, max_timeout=1000): - """ This function needs to be fast at the same time aware of the possibility of + """This function needs to be fast at the same time aware of the possibility of ZMQ pipes overflowing. The timeout increases slowly if contention is detected on ZMQ pipes. @@ -183,7 +190,8 @@ def put(self, dest, message, max_timeout=1000): Python object to send max_timeout : int - Max timeout in milliseconds that we will wait for before raising an exception + Max timeout in milliseconds that we will wait for before raising an + exception Raises ------ @@ -192,7 +200,7 @@ def put(self, dest, message, max_timeout=1000): zmq.error.ZMQError: Host unreachable (if client disconnects?) """ - if self.mode == 'client': + if self.mode == "client": return self.zmq_socket.send_multipart([message]) else: return self.zmq_socket.send_multipart([dest, message]) diff --git a/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py b/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py index 00588ba62..0f24000ac 100644 --- a/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py +++ b/funcx_endpoint/funcx_endpoint/endpoint/utils/config.py @@ -1,7 +1,7 @@ -from funcx_endpoint.executors import HighThroughputExecutor -from funcx_endpoint.strategies.simple import SimpleStrategy from parsl.utils import RepresentationMixin +from funcx_endpoint.executors import HighThroughputExecutor + _DEFAULT_EXECUTORS = [HighThroughputExecutor()] @@ -12,8 +12,8 @@ class Config(RepresentationMixin): ---------- executors : list of Executors - A list of executors which serve as the backend for function execution. As of 0.2.2, - this list should contain only one executor. + A list of executors which serve as the backend for function execution. + As of 0.2.2, this list should contain only one executor. Default: [HighThroughtputExecutor()] funcx_service_address: str @@ -21,12 +21,13 @@ class Config(RepresentationMixin): Default: 'https://api2.funcx.org/v2' heartbeat_period: int (seconds) - The interval at which heartbeat messages are sent from the endpoint to the funcx-web-service + The interval at which heartbeat messages are sent from the endpoint to the + funcx-web-service Default: 30s heartbeat_threshold: int (seconds) - Seconds since the last hearbeat message from the funcx-web-service after which the connection - is assumed to be disconnected. 
+ Seconds since the last hearbeat message from the funcx-web-service after which + the connection is assumed to be disconnected. Default: 120s stdout : str diff --git a/funcx_endpoint/funcx_endpoint/executors/__init__.py b/funcx_endpoint/funcx_endpoint/executors/__init__.py index 6fbfebe5a..61d8fa594 100644 --- a/funcx_endpoint/funcx_endpoint/executors/__init__.py +++ b/funcx_endpoint/funcx_endpoint/executors/__init__.py @@ -1,3 +1,3 @@ from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor -__all__ = ['HighThroughputExecutor'] +__all__ = ["HighThroughputExecutor"] diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py index 591bca56f..bdf4e4291 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/container_sched.py @@ -1,18 +1,20 @@ - import math import random -def naive_scheduler(task_qs, outstanding_task_count, max_workers, old_worker_map, to_die_list, logger): - """ Return two items (as one tuple) dict kill_list :: KILL [(worker_type, num_kill), ...] - dict create_list :: CREATE [(worker_type, num_create), ...] - - In this scheduler model, there is minimum 1 instance of each nonempty task queue. +def naive_scheduler( + task_qs, outstanding_task_count, max_workers, old_worker_map, to_die_list, logger +): + """ + Return two items (as one tuple) + dict kill_list :: KILL [(worker_type, num_kill), ...] + dict create_list :: CREATE [(worker_type, num_create), ...] + In this scheduler model, there is minimum 1 instance of each nonempty task queue. """ logger.debug("Entering scheduler...") - logger.debug("old_worker_map: {}".format(old_worker_map)) + logger.debug(f"old_worker_map: {old_worker_map}") q_sizes = {} q_types = [] new_worker_map = {} @@ -26,12 +28,14 @@ def naive_scheduler(task_qs, outstanding_task_count, max_workers, old_worker_map q_sizes[q_type] = q_size if sum_q_size > 0: - logger.info("[SCHEDULER] Total number of tasks is {}".format(sum_q_size)) + logger.info(f"[SCHEDULER] Total number of tasks is {sum_q_size}") # Set proportions of workers equal to the proportion of queue size. for q_type in q_sizes: ratio = q_sizes[q_type] / sum_q_size - new_worker_map[q_type] = min(int(math.floor(ratio * max_workers)), q_sizes[q_type]) + new_worker_map[q_type] = min( + int(math.floor(ratio * max_workers)), q_sizes[q_type] + ) # CLEANUP: Assign the difference here to any random worker. Should be small. 
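The proportional split that naive_scheduler() performs above can be read in isolation as the sketch below: give each task type a share of max_workers proportional to its backlog, never more workers than it has queued tasks, then hand the rounding leftover to random nonempty types. This is an illustration only; the queue names and counts are made up, and the helper name proportional_worker_map() is not part of this patch.

import math
import random


def proportional_worker_map(q_sizes, max_workers):
    total = sum(q_sizes.values())
    new_map = {q: 0 for q in q_sizes}
    if total == 0:
        return new_map
    # Proportional allocation, capped at each queue's backlog.
    for q_type, size in q_sizes.items():
        ratio = size / total
        new_map[q_type] = min(int(math.floor(ratio * max_workers)), size)
    # Distribute the rounding leftover (the "offset difference" in the scheduler)
    # to random queue types that can still absorb a worker.
    allocated = sum(new_map.values())
    leftover = min(max_workers - allocated, total - allocated)
    nonempty = [q for q, size in q_sizes.items() if size > new_map[q]]
    while leftover > 0 and nonempty:
        q = random.choice(nonempty)
        new_map[q] += 1
        leftover -= 1
        if new_map[q] >= q_sizes[q]:
            nonempty.remove(q)
    return new_map


if __name__ == "__main__":
    print(proportional_worker_map({"RAW": 7, "gpu": 3}, max_workers=4))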
# logger.debug("Temporary new worker map: {}".format(new_worker_map)) @@ -41,8 +45,8 @@ def naive_scheduler(task_qs, outstanding_task_count, max_workers, old_worker_map difference = 0 if sum_q_size > tmp_sum_q_size: difference = min(max_workers - tmp_sum_q_size, sum_q_size - tmp_sum_q_size) - logger.debug("[SCHEDULER] Offset difference: {}".format(difference)) - logger.debug("[SCHEDULER] Queue Types: {}".format(q_types)) + logger.debug(f"[SCHEDULER] Offset difference: {difference}") + logger.debug(f"[SCHEDULER] Queue Types: {q_types}") if len(q_types) > 0: while difference > 0: diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py index 51f97d8b8..ac0a23d9f 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/executor.py @@ -1,43 +1,46 @@ -"""HighThroughputExecutor builds on the Swift/T EMEWS architecture to use MPI for fast task distribution +"""HighThroughputExecutor builds on the Swift/T EMEWS architecture to use MPI for fast +task distribution There's a slow but sure deviation from Parsl's Executor interface here, that needs to be addressed. """ -from concurrent.futures import Future -import os -import time import logging -import threading -import queue +import os import pickle -import daemon -import uuid +import queue +import threading +import time +from concurrent.futures import Future from multiprocessing import Process -from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue - -from funcx_endpoint.executors.high_throughput.messages import HeartbeatReq, EPStatusReport, Heartbeat -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, Task -from funcx.serialize import FuncXSerializer -from funcx_endpoint.strategies.simple import SimpleStrategy -fx_serializer = FuncXSerializer() - -# from parsl.executors.high_throughput import interchange -from funcx_endpoint.executors.high_throughput import interchange +import daemon +from parsl.dataflow.error import ConfigurationError from parsl.executors.errors import BadMessage, ScalingFailed + # from parsl.executors.base import ParslExecutor from parsl.executors.status_handling import StatusHandlingExecutor -from parsl.dataflow.error import ConfigurationError - -from parsl.utils import RepresentationMixin from parsl.providers import LocalProvider +from parsl.utils import RepresentationMixin - -from funcx_endpoint.executors.high_throughput import zmq_pipes from funcx import set_file_logger +from funcx.serialize import FuncXSerializer -# TODO: YADU There's a bug here which causes some of the log messages to write out to stderr +# from parsl.executors.high_throughput import interchange +from funcx_endpoint.executors.high_throughput import interchange, zmq_pipes +from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue +from funcx_endpoint.executors.high_throughput.messages import ( + EPStatusReport, + Heartbeat, + HeartbeatReq, + Task, +) +from funcx_endpoint.strategies.simple import SimpleStrategy + +fx_serializer = FuncXSerializer() + +# TODO: YADU There's a bug here which causes some of the log messages to write out to +# stderr # "logging" python3 self.stream.flush() OSError: [Errno 9] Bad file descriptor logger = logging.getLogger(__name__) @@ -54,10 +57,12 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): The HighThroughputExecutor system has the following 
components: 1. The HighThroughputExecutor instance which is run as part of the Parsl script. - 2. The Interchange which is acts as a load-balancing proxy between workers and Parsl - 3. The multiprocessing based worker pool which coordinates task execution over several - cores on a node. - 4. ZeroMQ pipes connect the HighThroughputExecutor, Interchange and the process_worker_pool + 2. The Interchange which is acts as a load-balancing proxy between workers and + Parsl + 3. The multiprocessing based worker pool which coordinates task execution over + several cores on a node. + 4. ZeroMQ pipes connect the HighThroughputExecutor, Interchange and the + process_worker_pool Here is a diagram @@ -82,7 +87,8 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): ---------- provider : :class:`~parsl.providers.provider_base.ExecutionProvider` - Provider to access computation resources. Can be one of :class:`~parsl.providers.aws.aws.EC2Provider`, + Provider to access computation resources. Can be one of + :class:`~parsl.providers.aws.aws.EC2Provider`, :class:`~parsl.providers.cobalt.cobalt.Cobalt`, :class:`~parsl.providers.condor.condor.Condor`, :class:`~parsl.providers.googlecloud.googlecloud.GoogleCloud`, @@ -97,21 +103,33 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): Label for this executor instance. launch_cmd : str - Command line string to launch the process_worker_pool from the provider. The command line string - will be formatted with appropriate values for the following values (debug, task_url, result_url, - cores_per_worker, nodes_per_block, heartbeat_period ,heartbeat_threshold, logdir). For eg: - launch_cmd="process_worker_pool.py {debug} -c {cores_per_worker} --task_url={task_url} --result_url={result_url}" + Command line string to launch the process_worker_pool from the provider. The + command line string will be formatted with appropriate values for the following + values: ( + debug, + task_url, + result_url, + cores_per_worker, + nodes_per_block, + heartbeat_period, + heartbeat_threshold, + logdir, + ). + For example: + launch_cmd="process_worker_pool.py {debug} -c {cores_per_worker} \ + --task_url={task_url} --result_url={result_url}" address : string - An address of the host on which the executor runs, which is reachable from the network in which - workers will be running. This can be either a hostname as returned by `hostname` or an - IP address. Most login nodes on clusters have several network interfaces available, only - some of which can be reached from the compute nodes. Some trial and error might be - necessary to indentify what addresses are reachable from compute nodes. + An address of the host on which the executor runs, which is reachable from the + network in which workers will be running. This can be either a hostname as + returned by `hostname` or an IP address. Most login nodes on clusters have + several network interfaces available, only some of which can be reached + from the compute nodes. Some trial and error might be necessary to + indentify what addresses are reachable from compute nodes. worker_ports : (int, int) - Specify the ports to be used by workers to connect to Parsl. If this option is specified, - worker_port_range will not be honored. + Specify the ports to be used by workers to connect to Parsl. If this + option is specified, worker_port_range will not be honored. worker_port_range : (int, int) Worker ports will be chosen between the two integers provided. 
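As a quick illustration of the launch_cmd contract documented above (a str.format template over the named fields), here is a standalone expansion of the docstring's own example template; the URLs and values below are placeholders, not defaults of this executor.

launch_cmd = (
    "process_worker_pool.py {debug} -c {cores_per_worker} "
    "--task_url={task_url} --result_url={result_url}"
)

if __name__ == "__main__":
    # Example values only; real values are filled in by initialize_scaling().
    print(
        launch_cmd.format(
            debug="--debug",
            cores_per_worker=1.0,
            task_url="tcp://127.0.0.1:54000",
            result_url="tcp://127.0.0.1:54001",
        )
    )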
@@ -139,19 +157,22 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): Caps the number of workers launched by the manager. Default: infinity suppress_failure : Bool - If set, the interchange will suppress failures rather than terminate early. Default: False + If set, the interchange will suppress failures rather than terminate early. + Default: False heartbeat_threshold : int Seconds since the last message from the counterpart in the communication pair: - (interchange, manager) after which the counterpart is assumed to be un-available. Default:120s + (interchange, manager) after which the counterpart is assumed to be unavailable. + Default:120s heartbeat_period : int - Number of seconds after which a heartbeat message indicating liveness is sent to the endpoint + Number of seconds after which a heartbeat message indicating liveness is sent to + the endpoint counterpart (interchange, manager). Default:30s poll_period : int - Timeout period to be used by the executor components in milliseconds. Increasing poll_periods - trades performance for cpu efficiency. Default: 10ms + Timeout period to be used by the executor components in milliseconds. + Increasing poll_periods trades performance for cpu efficiency. Default: 10ms container_image : str Path or identfier to the container image to be used by the workers @@ -162,7 +183,8 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): 'soft' -> managers can replace unused worker's containers based on demand worker_mode : str - Select the mode of operation from no_container, singularity_reuse, singularity_single_use + Select the mode of operation from no_container, singularity_reuse, + singularity_single_use Default: singularity_reuse container_cmd_options: str @@ -176,66 +198,65 @@ class HighThroughputExecutor(StatusHandlingExecutor, RepresentationMixin): Specify the scaling strategy to use for this executor. launch_cmd: str - Specify the launch command as using f-string format that will be used to specify command to - launch managers. Default: None + Specify the launch command as using f-string format that will be used to specify + command to launch managers. Default: None prefetch_capacity: int - Number of tasks that can be fetched by managers in excess of available workers is a - prefetching optimization. This option can cause poor load-balancing for long running functions. + Number of tasks that can be fetched by managers in excess of available + workers is a prefetching optimization. This option can cause poor + load-balancing for long running functions. Default: 10 provider: Provider object - Provider determines how managers can be provisioned, say LocalProvider offers forked processes, - and SlurmProvider interfaces to request resources from the Slurm batch scheduler. + Provider determines how managers can be provisioned, say LocalProvider + offers forked processes, and SlurmProvider interfaces to request + resources from the Slurm batch scheduler. Default: LocalProvider funcx_service_address: str - Override funcx_service_address used by the FuncXClient. If no address is specified, - the FuncXClient's default funcx_service_address is used. + Override funcx_service_address used by the FuncXClient. If no address + is specified, the FuncXClient's default funcx_service_address is used. 
Default: None """ - def __init__(self, - label='HighThroughputExecutor', - - - # NEW - strategy=SimpleStrategy(), - max_workers_per_node=float('inf'), - mem_per_worker=None, - launch_cmd=None, - - # Container specific - worker_mode='no_container', - scheduler_mode='hard', - container_type=None, - container_cmd_options='', - cold_routing_interval=10.0, - - # Tuning info - prefetch_capacity=10, - - provider=LocalProvider(), - address="127.0.0.1", - worker_ports=None, - worker_port_range=(54000, 55000), - interchange_port_range=(55000, 56000), - storage_access=None, - working_dir=None, - worker_debug=False, - cores_per_worker=1.0, - heartbeat_threshold=120, - heartbeat_period=30, - poll_period=10, - container_image=None, - suppress_failure=False, - run_dir=None, - endpoint_id=None, - managed=True, - interchange_local=True, - passthrough=True, - funcx_service_address=None, - task_status_queue=None): + def __init__( + self, + label="HighThroughputExecutor", + # NEW + strategy=SimpleStrategy(), + max_workers_per_node=float("inf"), + mem_per_worker=None, + launch_cmd=None, + # Container specific + worker_mode="no_container", + scheduler_mode="hard", + container_type=None, + container_cmd_options="", + cold_routing_interval=10.0, + # Tuning info + prefetch_capacity=10, + provider=LocalProvider(), + address="127.0.0.1", + worker_ports=None, + worker_port_range=(54000, 55000), + interchange_port_range=(55000, 56000), + storage_access=None, + working_dir=None, + worker_debug=False, + cores_per_worker=1.0, + heartbeat_threshold=120, + heartbeat_period=30, + poll_period=10, + container_image=None, + suppress_failure=False, + run_dir=None, + endpoint_id=None, + managed=True, + interchange_local=True, + passthrough=True, + funcx_service_address=None, + task_status_queue=None, + ): logger.debug("Initializing HighThroughputExecutor") StatusHandlingExecutor.__init__(self, provider) @@ -261,7 +282,9 @@ def __init__(self, self.storage_access = storage_access if storage_access is not None else [] if len(self.storage_access) > 1: - raise ConfigurationError('Multiple storage access schemes are not supported') + raise ConfigurationError( + "Multiple storage access schemes are not supported" + ) self.working_dir = working_dir self.managed = managed self.blocks = [] @@ -289,70 +312,88 @@ def __init__(self, self.last_response_time = time.time() if not launch_cmd: - self.launch_cmd = ("process_worker_pool.py {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--mode={worker_mode} " - "--container_image={container_image} ") - - self.ix_launch_cmd = ("funcx-interchange {debug} -c={client_address} " - "--client_ports={client_ports} " - "--worker_port_range={worker_port_range} " - "--logdir={logdir} " - "{suppress_failure} " - ) + self.launch_cmd = ( + "process_worker_pool.py {debug} {max_workers} " + "-c {cores_per_worker} " + "--poll {poll_period} " + "--task_url={task_url} " + "--result_url={result_url} " + "--logdir={logdir} " + "--hb_period={heartbeat_period} " + "--hb_threshold={heartbeat_threshold} " + "--mode={worker_mode} " + "--container_image={container_image} " + ) + + self.ix_launch_cmd = ( + "funcx-interchange {debug} -c={client_address} " + "--client_ports={client_ports} " + "--worker_port_range={worker_port_range} " + "--logdir={logdir} " + "{suppress_failure} " + ) def initialize_scaling(self): - """ Compose the launch 
command and call the scale_out + """Compose the launch command and call the scale_out This should be implemented in the child classes to take care of executor specific oddities. """ debug_opts = "--debug" if self.worker_debug else "" - max_workers = "" if self.max_workers == float('inf') else "--max_workers={}".format(self.max_workers) - - l_cmd = self.launch_cmd.format(debug=debug_opts, - task_url=self.worker_task_url, - result_url=self.worker_result_url, - cores_per_worker=self.cores_per_worker, - max_workers=max_workers, - nodes_per_block=self.provider.nodes_per_block, - heartbeat_period=self.heartbeat_period, - heartbeat_threshold=self.heartbeat_threshold, - poll_period=self.poll_period, - logdir=os.path.join(self.run_dir, self.label), - worker_mode=self.worker_mode, - container_image=self.container_image) + max_workers = ( + "" + if self.max_workers == float("inf") + else f"--max_workers={self.max_workers}" + ) + + l_cmd = self.launch_cmd.format( + debug=debug_opts, + task_url=self.worker_task_url, + result_url=self.worker_result_url, + cores_per_worker=self.cores_per_worker, + max_workers=max_workers, + nodes_per_block=self.provider.nodes_per_block, + heartbeat_period=self.heartbeat_period, + heartbeat_threshold=self.heartbeat_threshold, + poll_period=self.poll_period, + logdir=os.path.join(self.run_dir, self.label), + worker_mode=self.worker_mode, + container_image=self.container_image, + ) self.launch_cmd = l_cmd - logger.debug("Launch command: {}".format(self.launch_cmd)) + logger.debug(f"Launch command: {self.launch_cmd}") self._scaling_enabled = self.provider.scaling_enabled - logger.debug("Starting HighThroughputExecutor with provider:\n%s", self.provider) - if hasattr(self.provider, 'init_blocks'): + logger.debug( + "Starting HighThroughputExecutor with provider:\n%s", self.provider + ) + if hasattr(self.provider, "init_blocks"): try: self.scale_out(blocks=self.provider.init_blocks) except Exception as e: - logger.error("Scaling out failed: {}".format(e)) + logger.error(f"Scaling out failed: {e}") raise e def start(self, results_passthrough=None): - """Create the Interchange process and connect to it. - """ - self.outgoing_q = zmq_pipes.TasksOutgoing("0.0.0.0", self.interchange_port_range) - self.incoming_q = zmq_pipes.ResultsIncoming("0.0.0.0", self.interchange_port_range) - self.command_client = zmq_pipes.CommandClient("0.0.0.0", self.interchange_port_range) + """Create the Interchange process and connect to it.""" + self.outgoing_q = zmq_pipes.TasksOutgoing( + "0.0.0.0", self.interchange_port_range + ) + self.incoming_q = zmq_pipes.ResultsIncoming( + "0.0.0.0", self.interchange_port_range + ) + self.command_client = zmq_pipes.CommandClient( + "0.0.0.0", self.interchange_port_range + ) self.is_alive = True if self.passthrough is True: if results_passthrough is None: - raise Exception("Executors configured in passthrough mode, must be started with" - "a multiprocessing queue for results_passthrough") + raise Exception( + "Executors configured in passthrough mode, must be started with" + "a multiprocessing queue for results_passthrough" + ) self.results_passthrough = results_passthrough logger.debug(f"Executor:{self.label} starting in results_passthrough mode") @@ -364,9 +405,13 @@ def start(self, results_passthrough=None): if self.interchange_local is True: logger.info("Attempting local interchange start") self._start_local_interchange_process() - logger.info(f"Started local interchange with ports: {self.worker_task_port}. 
{self.worker_result_port}") + logger.info( + "Started local interchange with ports: %s. %s", + self.worker_task_port, + self.worker_result_port, + ) - logger.debug("Created management thread: {}".format(self._queue_management_thread)) + logger.debug(f"Created management thread: {self._queue_management_thread}") if self.provider: # self.initialize_scaling() @@ -378,83 +423,105 @@ def start(self, results_passthrough=None): return (self.outgoing_q.port, self.incoming_q.port, self.command_client.port) def _start_local_interchange_process(self): - """ Starts the interchange process locally + """Starts the interchange process locally Starts the interchange process locally and uses an internal command queue to get the worker task and result ports that the interchange has bound to. """ comm_q = mpQueue(maxsize=10) print(f"Starting local interchange with endpoint id: {self.endpoint_id}") - self.queue_proc = Process(target=interchange.starter, - args=(comm_q,), - kwargs={"client_address": self.address, - "client_ports": (self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port), - "provider": self.provider, - "strategy": self.strategy, - "poll_period": self.poll_period, - "heartbeat_period": self.heartbeat_period, - "heartbeat_threshold": self.heartbeat_threshold, - "working_dir": self.working_dir, - "worker_debug": self.worker_debug, - "max_workers_per_node": self.max_workers_per_node, - "mem_per_worker": self.mem_per_worker, - "cores_per_worker": self.cores_per_worker, - "prefetch_capacity": self.prefetch_capacity, - # "log_max_bytes": self.log_max_bytes, - # "log_backup_count": self.log_backup_count, - "scheduler_mode": self.scheduler_mode, - "worker_mode": self.worker_mode, - "container_type": self.container_type, - "container_cmd_options": self.container_cmd_options, - "cold_routing_interval": self.cold_routing_interval, - "funcx_service_address": self.funcx_service_address, - "interchange_address": self.address, - "worker_ports": self.worker_ports, - "worker_port_range": self.worker_port_range, - "logdir": os.path.join(self.run_dir, self.label), - "suppress_failure": self.suppress_failure, - "endpoint_id": self.endpoint_id, - "logging_level": logging.DEBUG if self.worker_debug else logging.INFO - }, + self.queue_proc = Process( + target=interchange.starter, + args=(comm_q,), + kwargs={ + "client_address": self.address, + "client_ports": ( + self.outgoing_q.port, + self.incoming_q.port, + self.command_client.port, + ), + "provider": self.provider, + "strategy": self.strategy, + "poll_period": self.poll_period, + "heartbeat_period": self.heartbeat_period, + "heartbeat_threshold": self.heartbeat_threshold, + "working_dir": self.working_dir, + "worker_debug": self.worker_debug, + "max_workers_per_node": self.max_workers_per_node, + "mem_per_worker": self.mem_per_worker, + "cores_per_worker": self.cores_per_worker, + "prefetch_capacity": self.prefetch_capacity, + # "log_max_bytes": self.log_max_bytes, + # "log_backup_count": self.log_backup_count, + "scheduler_mode": self.scheduler_mode, + "worker_mode": self.worker_mode, + "container_type": self.container_type, + "container_cmd_options": self.container_cmd_options, + "cold_routing_interval": self.cold_routing_interval, + "funcx_service_address": self.funcx_service_address, + "interchange_address": self.address, + "worker_ports": self.worker_ports, + "worker_port_range": self.worker_port_range, + "logdir": os.path.join(self.run_dir, self.label), + "suppress_failure": self.suppress_failure, + "endpoint_id": self.endpoint_id, + 
"logging_level": logging.DEBUG if self.worker_debug else logging.INFO, + }, ) self.queue_proc.start() try: - (self.worker_task_port, self.worker_result_port) = comm_q.get(block=True, timeout=120) + (self.worker_task_port, self.worker_result_port) = comm_q.get( + block=True, timeout=120 + ) except queue.Empty: - logger.error("Interchange has not completed initialization in 120s. Aborting") + logger.error( + "Interchange has not completed initialization in 120s. Aborting" + ) raise Exception("Interchange failed to start") - self.worker_task_url = "tcp://{}:{}".format(self.address, self.worker_task_port) - self.worker_result_url = "tcp://{}:{}".format(self.address, self.worker_result_port) + self.worker_task_url = f"tcp://{self.address}:{self.worker_task_port}" + self.worker_result_url = "tcp://{}:{}".format( + self.address, self.worker_result_port + ) def _start_remote_interchange_process(self): - """ Starts the interchange process locally + """Starts the interchange process locally - Starts the interchange process remotely via the provider.channel and uses the command channel - to request worker urls that the interchange has selected. + Starts the interchange process remotely via the provider.channel and + uses the command channel to request worker urls that the interchange + has selected. """ - logger.debug("Attempting Interchange deployment via channel: {}".format(self.provider.channel)) + logger.debug( + "Attempting Interchange deployment via channel: {}".format( + self.provider.channel + ) + ) debug_opts = "--debug" if self.worker_debug else "" suppress_failure = "--suppress_failure" if self.suppress_failure else "" - logger.debug("Before : \n{}\n".format(self.ix_launch_cmd)) - launch_command = self.ix_launch_cmd.format(debug=debug_opts, - client_address=self.address, - client_ports="{},{},{}".format(self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port), - worker_port_range="{},{}".format(self.worker_port_range[0], - self.worker_port_range[1]), - logdir=os.path.join(self.provider.channel.script_dir, 'runinfo', - os.path.basename(self.run_dir), self.label), - suppress_failure=suppress_failure) + logger.debug(f"Before : \n{self.ix_launch_cmd}\n") + launch_command = self.ix_launch_cmd.format( + debug=debug_opts, + client_address=self.address, + client_ports="{},{},{}".format( + self.outgoing_q.port, self.incoming_q.port, self.command_client.port + ), + worker_port_range="{},{}".format( + self.worker_port_range[0], self.worker_port_range[1] + ), + logdir=os.path.join( + self.provider.channel.script_dir, + "runinfo", + os.path.basename(self.run_dir), + self.label, + ), + suppress_failure=suppress_failure, + ) if self.provider.worker_init: - launch_command = self.provider.worker_init + '\n' + launch_command + launch_command = self.provider.worker_init + "\n" + launch_command - logger.debug("Launch command : \n{}\n".format(launch_command)) + logger.debug(f"Launch command : \n{launch_command}\n") return def _queue_management_worker(self): @@ -502,12 +569,16 @@ def _queue_management_worker(self): # Timed out. 
pass - except IOError as e: - logger.exception("[MTHREAD] Caught broken queue with exception code {}: {}".format(e.errno, e)) + except OSError as e: + logger.exception( + "[MTHREAD] Caught broken queue with exception code {}: {}".format( + e.errno, e + ) + ) return except Exception as e: - logger.exception("[MTHREAD] Caught unknown exception: {}".format(e)) + logger.exception(f"[MTHREAD] Caught unknown exception: {e}") return else: @@ -517,65 +588,86 @@ def _queue_management_worker(self): return elif isinstance(msgs, EPStatusReport): - logger.debug("[MTHREAD] Received EPStatusReport {}".format(msgs)) + logger.debug(f"[MTHREAD] Received EPStatusReport {msgs}") if self.passthrough: - self.results_passthrough.put({ - "task_id": None, - "message": pickle.dumps(msgs) - }) + self.results_passthrough.put( + {"task_id": None, "message": pickle.dumps(msgs)} + ) else: logger.debug("[MTHREAD] Unpacking results") for serialized_msg in msgs: try: msg = pickle.loads(serialized_msg) - tid = msg['task_id'] + tid = msg["task_id"] except pickle.UnpicklingError: raise BadMessage("Message received could not be unpickled") except Exception: - raise BadMessage("Message received does not contain 'task_id' field") - - if tid == -2 and 'info' in msg: - logger.warning("[MTHREAD[ Received info response : {}".format(msg['info'])) - - if tid == -1 and 'exception' in msg: - # TODO: This could be handled better we are essentially shutting down the - # client with little indication to the user. - logger.warning("[MTHREAD] Executor shutting down due to fatal exception from interchange") - self._executor_exception = fx_serializer.deserialize(msg['exception']) - logger.exception("[MTHREAD] Exception: {}".format(self._executor_exception)) + raise BadMessage( + "Message received does not contain 'task_id' field" + ) + + if tid == -2 and "info" in msg: + logger.warning( + "[MTHREAD[ Received info response : {}".format( + msg["info"] + ) + ) + + if tid == -1 and "exception" in msg: + # TODO: This could be handled better we are + # essentially shutting down the client with little + # indication to the user. + logger.warning( + "[MTHREAD] Executor shutting down due to fatal " + "exception from interchange" + ) + self._executor_exception = fx_serializer.deserialize( + msg["exception"] + ) + logger.exception( + "[MTHREAD] Exception: {}".format( + self._executor_exception + ) + ) # Set bad state to prevent new tasks from being submitted self._executor_bad_state.set() - # We set all current tasks to this exception to make sure that - # this is raised in the main context. + # We set all current tasks to this exception to make sure + # that this is raised in the main context. 
for task in self.tasks: self.tasks[task].set_exception(self._executor_exception) break if self.passthrough is True: logger.debug(f"[MTHREAD] Pushing results for task:{tid}") - # we are only interested in actual task ids here, not identifiers - # for other message types + # we are only interested in actual task ids here, not + # identifiers for other message types sent_task_id = tid if isinstance(tid, str) else None - x = self.results_passthrough.put({ - "task_id": sent_task_id, - "message": serialized_msg - }) + x = self.results_passthrough.put( + {"task_id": sent_task_id, "message": serialized_msg} + ) logger.debug(f"[MTHREAD] task:{tid} ret value: {x}") - logger.debug(f"[MTHREAD] task:{tid} items in queue: {self.results_passthrough.qsize()}") + logger.debug( + "[MTHREAD] task:%s items in queue: %s", + tid, + self.results_passthrough.qsize(), + ) continue task_fut = self.tasks.pop(tid) - if 'result' in msg: - result = fx_serializer.deserialize(msg['result']) + if "result" in msg: + result = fx_serializer.deserialize(msg["result"]) task_fut.set_result(result) - elif 'exception' in msg: - exception = fx_serializer.deserialize(msg['exception']) + elif "exception" in msg: + exception = fx_serializer.deserialize(msg["exception"]) task_fut.set_result(exception) else: - raise BadMessage("[MTHREAD] Message received is neither result or exception") + raise BadMessage( + "[MTHREAD] Message received is neither result or " + "exception" + ) if not self.is_alive: break @@ -595,7 +687,9 @@ def _start_queue_management_thread(self): """ if self._queue_management_thread is None: logger.debug("Starting queue management thread") - self._queue_management_thread = threading.Thread(target=self._queue_management_worker) + self._queue_management_thread = threading.Thread( + target=self._queue_management_worker + ) self._queue_management_thread.daemon = True self._queue_management_thread.start() logger.debug("Started queue management thread") @@ -615,8 +709,8 @@ def hold_worker(self, worker_id): worker_id : str Worker id to be put on hold """ - c = self.command_client.run("HOLD_WORKER;{}".format(worker_id)) - logger.debug("Sent hold request to worker: {}".format(worker_id)) + c = self.command_client.run(f"HOLD_WORKER;{worker_id}") + logger.debug(f"Sent hold request to worker: {worker_id}") return c def send_heartbeat(self): @@ -632,18 +726,19 @@ def wait_for_endpoint(self): @property def outstanding(self): outstanding_c = self.command_client.run("OUTSTANDING_C") - logger.debug("Got outstanding count: {}".format(outstanding_c)) + logger.debug(f"Got outstanding count: {outstanding_c}") return outstanding_c @property def connected_workers(self): workers = self.command_client.run("MANAGERS") - logger.debug("Got managers: {}".format(workers)) + logger.debug(f"Got managers: {workers}") return workers - def submit(self, func, *args, container_id: str = 'RAW', task_id: str = None, **kwargs): - """ Submits the function and it's params for execution. 
- """ + def submit( + self, func, *args, container_id: str = "RAW", task_id: str = None, **kwargs + ): + """Submits the function and it's params for execution.""" self._task_counter += 1 if task_id is None: task_id = self._task_counter @@ -651,11 +746,10 @@ def submit(self, func, *args, container_id: str = 'RAW', task_id: str = None, ** fn_code = fx_serializer.serialize(func) ser_code = fx_serializer.pack_buffers([fn_code]) - ser_params = fx_serializer.pack_buffers([fx_serializer.serialize(args), - fx_serializer.serialize(kwargs)]) - payload = Task(task_id, - container_id, - ser_code + ser_params) + ser_params = fx_serializer.pack_buffers( + [fx_serializer.serialize(args), fx_serializer.serialize(kwargs)] + ) + payload = Task(task_id, container_id, ser_code + ser_params) self.submit_raw(payload.pack()) self.tasks[task_id] = Future() @@ -667,11 +761,13 @@ def submit_raw(self, packed_task): The outgoing_q is an external process listens on this queue for new work. This method behaves like a - submit call as described here `Python docs: `_ + submit call as described in the `Python docs \ + `_ Parameters ---------- - Packed Task (messages.Task) - A packed Task object which contains task_id, container_id, and serialized fn, args, kwargs packages. + Packed Task (messages.Task) - A packed Task object which contains task_id, + container_id, and serialized fn, args, kwargs packages. Returns: Submit status @@ -695,17 +791,18 @@ def _get_block_and_job_ids(self): @property def connection_info(self): - """ All connection info necessary for the endpoint to connect back + """All connection info necessary for the endpoint to connect back Returns: Dict with connection info """ - return {'address': self.address, - # A memorial to the ungodly amount of time and effort spent, - # troubleshooting the order of these ports. - 'client_ports': '{},{},{}'.format(self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port) + return { + "address": self.address, + # A memorial to the ungodly amount of time and effort spent, + # troubleshooting the order of these ports. + "client_ports": "{},{},{}".format( + self.outgoing_q.port, self.incoming_q.port, self.command_client.port + ), } @property @@ -722,10 +819,14 @@ def scale_out(self, blocks=1): for i in range(blocks): if self.provider: block = self.provider.submit(self.launch_cmd, 1, 1) - logger.debug("Launched block {}:{}".format(i, block)) + logger.debug(f"Launched block {i}:{block}") if not block: - raise(ScalingFailed(self.provider.label, - "Attempts to provision nodes via provider has failed")) + raise ( + ScalingFailed( + self.provider.label, + "Attempts to provision nodes via provider has failed", + ) + ) self.blocks.extend([block]) else: logger.error("No execution provider available") @@ -759,7 +860,7 @@ def status(self): return status - def shutdown(self, hub=True, targets='all', block=False): + def shutdown(self, hub=True, targets="all", block=False): """Shutdown the executor, including all workers and controllers. This is not implemented. 
@@ -784,15 +885,17 @@ def shutdown(self, hub=True, targets='all', block=False): def executor_starter(htex, logdir, endpoint_id, logging_level=logging.DEBUG): - stdout = open(os.path.join(logdir, "executor.{}.stdout".format(endpoint_id)), 'w') - stderr = open(os.path.join(logdir, "executor.{}.stderr".format(endpoint_id)), 'w') + stdout = open(os.path.join(logdir, f"executor.{endpoint_id}.stdout"), "w") + stderr = open(os.path.join(logdir, f"executor.{endpoint_id}.stderr"), "w") logdir = os.path.abspath(logdir) with daemon.DaemonContext(stdout=stdout, stderr=stderr): global logger print("cwd: ", os.getcwd()) - logger = set_file_logger(os.path.join(logdir, "executor.{}.log".format(endpoint_id)), - level=logging_level) + logger = set_file_logger( + os.path.join(logdir, f"executor.{endpoint_id}.log"), + level=logging_level, + ) htex.start() stdout.close() diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py index 1ef16eed9..f1894bdde 100755 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_manager.py @@ -1,38 +1,35 @@ #!/usr/bin/env python3 import argparse +import json import logging +import math +import multiprocessing import os -import sys +import pickle import platform +import queue +import subprocess +import sys import threading -import pickle import time -import queue import uuid -import zmq -import math -import json -import multiprocessing + import psutil -import subprocess +import zmq +from parsl.version import VERSION as PARSL_VERSION +from funcx import set_file_logger +from funcx.serialize import FuncXSerializer from funcx_endpoint.executors.high_throughput.container_sched import naive_scheduler +from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue from funcx_endpoint.executors.high_throughput.messages import ( - EPStatusReport, - Heartbeat, ManagerStatusReport, - TaskStatusCode + Message, + Task, + TaskStatusCode, ) from funcx_endpoint.executors.high_throughput.worker_map import WorkerMap -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task -from funcx.serialize import FuncXSerializer -from funcx_endpoint.executors.high_throughput.mac_safe_queue import mpQueue - -from parsl.version import VERSION as PARSL_VERSION - -from funcx import set_file_logger - RESULT_TAG = 10 TASK_REQUEST_TAG = 11 @@ -42,8 +39,8 @@ logger = None -class Manager(object): - """ Manager manages task execution by the workers +class Manager: + """Manager manages task execution by the workers | 0mq | Manager | Worker Processes | | | @@ -59,26 +56,28 @@ class Manager(object): """ - def __init__(self, - task_q_url="tcp://127.0.0.1:50097", - result_q_url="tcp://127.0.0.1:50098", - max_queue_size=10, - cores_per_worker=1, - max_workers=float('inf'), - uid=None, - heartbeat_threshold=120, - heartbeat_period=30, - logdir=None, - debug=False, - block_id=None, - internal_worker_port_range=(50000, 60000), - worker_mode="singularity_reuse", - container_cmd_options="", - scheduler_mode="hard", - worker_type=None, - worker_max_idletime=60, - # TODO : This should be 10ms - poll_period=100): + def __init__( + self, + task_q_url="tcp://127.0.0.1:50097", + result_q_url="tcp://127.0.0.1:50098", + max_queue_size=10, + cores_per_worker=1, + max_workers=float("inf"), + uid=None, + heartbeat_threshold=120, + heartbeat_period=30, + logdir=None, + debug=False, + 
block_id=None, + internal_worker_port_range=(50000, 60000), + worker_mode="singularity_reuse", + container_cmd_options="", + scheduler_mode="hard", + worker_type=None, + worker_max_idletime=60, + # TODO : This should be 10ms + poll_period=100, + ): """ Parameters ---------- @@ -98,23 +97,29 @@ def __init__(self, heartbeat_threshold : int Seconds since the last message from the interchange after which the - interchange is assumed to be un-available, and the manager initiates shutdown. Default:120s + interchange is assumed to be un-available, and the manager initiates + shutdown. Default:120s - Number of seconds since the last message from the interchange after which the worker - assumes that the interchange is lost and the manager shuts down. Default:120 + Number of seconds since the last message from the interchange after which + the worker assumes that the interchange is lost and the manager shuts down. + Default:120 heartbeat_period : int - Number of seconds after which a heartbeat message is sent to the interchange + Number of seconds after which a heartbeat message is sent to the + interchange internal_worker_port_range : tuple(int, int) - Port range from which the port(s) for the workers to connect to the manager is picked. + Port range from which the port(s) for the workers to connect to the manager + is picked. Default: (50000,60000) worker_mode : str Pick between 3 supported modes for the worker: 1. no_container : Worker launched without containers - 2. singularity_reuse : Worker launched inside a singularity container that will be reused - 3. singularity_single_use : Each worker and task runs inside a new container instance. + 2. singularity_reuse : Worker launched inside a singularity container that + will be reused + 3. singularity_single_use : Each worker and task runs inside a new + container instance. container_cmd_options: str Container command strings to be added to associated container command. 
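Putting the parameters documented above together, a hedged example of constructing a Manager directly (values and paths are illustrative; in practice cli_run() builds this from argparse and the interchange supplies the URLs):

    from funcx_endpoint.executors.high_throughput.funcx_manager import Manager

    manager = Manager(
        task_q_url="tcp://127.0.0.1:50097",
        result_q_url="tcp://127.0.0.1:50098",
        uid="manager-0",
        logdir="/tmp/funcx_manager_logs",
        worker_mode="no_container",  # or singularity_reuse / singularity_single_use
        scheduler_mode="hard",
        worker_type=None,
        heartbeat_period=30,
        heartbeat_threshold=120,
        internal_worker_port_range=(50000, 60000),
    )
    # manager.start() is omitted here: it registers with a live interchange
    # and begins pulling tasks.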
@@ -135,15 +140,17 @@ def __init__(self, global logger # This is expected to be used only in unit test if logger is None: - logger = set_file_logger(os.path.join(logdir, uid, 'manager.log'), - name='funcx_manager', - level=logging.DEBUG) + logger = set_file_logger( + os.path.join(logdir, uid, "manager.log"), + name="funcx_manager", + level=logging.DEBUG, + ) logger.info("Manager started") self.context = zmq.Context() self.task_incoming = self.context.socket(zmq.DEALER) - self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) + self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode("utf-8")) # Linger is set to 0, so that the manager can exit even when there might be # messages in the pipe self.task_incoming.setsockopt(zmq.LINGER, 0) @@ -153,7 +160,7 @@ def __init__(self, self.debug = debug self.block_id = block_id self.result_outgoing = self.context.socket(zmq.DEALER) - self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8')) + self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode("utf-8")) self.result_outgoing.setsockopt(zmq.LINGER, 0) self.result_outgoing.connect(result_q_url) @@ -169,22 +176,30 @@ def __init__(self, self.cores_on_node = multiprocessing.cpu_count() self.max_workers = max_workers self.cores_per_workers = cores_per_worker - self.available_mem_on_node = round(psutil.virtual_memory().available / (2**30), 1) - self.max_worker_count = min(max_workers, - math.floor(self.cores_on_node / cores_per_worker)) + self.available_mem_on_node = round( + psutil.virtual_memory().available / (2 ** 30), 1 + ) + self.max_worker_count = min( + max_workers, math.floor(self.cores_on_node / cores_per_worker) + ) self.worker_map = WorkerMap(self.max_worker_count) self.internal_worker_port_range = internal_worker_port_range self.funcx_task_socket = self.context.socket(zmq.ROUTER) self.funcx_task_socket.set_hwm(0) - self.address = '127.0.0.1' + self.address = "127.0.0.1" self.worker_port = self.funcx_task_socket.bind_to_random_port( "tcp://*", min_port=self.internal_worker_port_range[0], - max_port=self.internal_worker_port_range[1]) + max_port=self.internal_worker_port_range[1], + ) - logger.info("Manager listening on {} port for incoming worker connections".format(self.worker_port)) + logger.info( + "Manager listening on {} port for incoming worker connections".format( + self.worker_port + ) + ) self.task_queues = {} if worker_type: @@ -208,12 +223,10 @@ def __init__(self, self._kill_event = threading.Event() self._result_pusher_thread = threading.Thread( - target=self.push_results, - args=(self._kill_event,) + target=self.push_results, args=(self._kill_event,) ) self._status_report_thread = threading.Thread( - target=self._status_report_loop, - args=(self._kill_event,) + target=self._status_report_loop, args=(self._kill_event,) ) self.container_switch_count = 0 @@ -224,26 +237,26 @@ def __init__(self, self.task_done_counter = 0 def create_reg_message(self): - """ Creates a registration message to identify the worker to the interchange - """ - msg = {'parsl_v': PARSL_VERSION, - 'python_v': "{}.{}.{}".format(sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro), - 'max_worker_count': self.max_worker_count, - 'cores': self.cores_on_node, - 'mem': self.available_mem_on_node, - 'block_id': self.block_id, - 'worker_type': self.worker_type, - 'os': platform.system(), - 'hname': platform.node(), - 'dir': os.getcwd(), + """Creates a registration message to identify the worker to the interchange""" + msg = { + "parsl_v": PARSL_VERSION, + "python_v": 
"{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "max_worker_count": self.max_worker_count, + "cores": self.cores_on_node, + "mem": self.available_mem_on_node, + "block_id": self.block_id, + "worker_type": self.worker_type, + "os": platform.system(), + "hname": platform.node(), + "dir": os.getcwd(), } - b_msg = json.dumps(msg).encode('utf-8') + b_msg = json.dumps(msg).encode("utf-8") return b_msg def pull_tasks(self, kill_event): - """ Pull tasks from the incoming tasks 0mq pipe onto the internal + """Pull tasks from the incoming tasks 0mq pipe onto the internal pending task queue @@ -265,7 +278,7 @@ def pull_tasks(self, kill_event): # Send a registration message msg = self.create_reg_message() - logger.debug("Sending registration message: {}".format(msg)) + logger.debug(f"Sending registration message: {msg}") self.task_incoming.send(msg) last_interchange_contact = time.time() task_recv_counter = 0 @@ -278,32 +291,44 @@ def pull_tasks(self, kill_event): logger.debug("[TASK_PULL_THREAD] Loop start") pending_task_count = task_recv_counter - self.task_done_counter ready_worker_count = self.worker_map.ready_worker_count() - logger.debug("[TASK_PULL_THREAD pending_task_count: {}, Ready_worker_count: {}".format( - pending_task_count, ready_worker_count)) + logger.debug( + "[TASK_PULL_THREAD pending_task_count: %s, Ready_worker_count: %s", + pending_task_count, + ready_worker_count, + ) if pending_task_count < self.max_queue_size and ready_worker_count > 0: ads = self.worker_map.advertisement() - logger.debug("[TASK_PULL_THREAD] Requesting tasks: {}".format(ads)) + logger.debug(f"[TASK_PULL_THREAD] Requesting tasks: {ads}") msg = pickle.dumps(ads) self.task_incoming.send(msg) # Receive results from the workers, if any socks = dict(self.poller.poll(timeout=poll_timer)) - if self.funcx_task_socket in socks and socks[self.funcx_task_socket] == zmq.POLLIN: + if ( + self.funcx_task_socket in socks + and socks[self.funcx_task_socket] == zmq.POLLIN + ): self.poll_funcx_task_socket() # Receive task batches from Interchange and forward to workers if self.task_incoming in socks and socks[self.task_incoming] == zmq.POLLIN: - # If we want to wrap the task_incoming polling into a separate function, we need to - # self.poll_task_incoming(poll_timer, last_interchange_contact, kill_event, task_revc_counter) + # If we want to wrap the task_incoming polling into a separate function, + # we need to + # self.poll_task_incoming( + # poll_timer, + # last_interchange_contact, + # kill_event, + # task_revc_counter + # ) poll_timer = 0 _, pkl_msg = self.task_incoming.recv_multipart() message = pickle.loads(pkl_msg) last_interchange_contact = time.time() - if message == 'STOP': + if message == "STOP": logger.critical("[TASK_PULL_THREAD] Received stop request") kill_event.set() break @@ -312,14 +337,20 @@ def pull_tasks(self, kill_event): logger.debug("Got heartbeat from interchange") else: - tasks = [(rt['local_container'], Message.unpack(rt['raw_buffer'])) for rt in message] + tasks = [ + (rt["local_container"], Message.unpack(rt["raw_buffer"])) + for rt in message + ] task_recv_counter += len(tasks) - logger.debug("[TASK_PULL_THREAD] Got tasks: {} of {}".format([t[1].task_id for t in tasks], - task_recv_counter)) + logger.debug( + "[TASK_PULL_THREAD] Got tasks: {} of {}".format( + [t[1].task_id for t in tasks], task_recv_counter + ) + ) for task_type, task in tasks: - logger.debug("[TASK DEBUG] Task is of type: {}".format(task_type)) + logger.debug(f"[TASK DEBUG] Task is 
of type: {task_type}") if task_type not in self.task_queues: self.task_queues[task_type] = queue.Queue() @@ -328,8 +359,12 @@ def pull_tasks(self, kill_event): self.task_queues[task_type].put(task) self.outstanding_task_count[task_type] += 1 self.task_type_mapping[task.task_id] = task_type - logger.debug("Got task: Outstanding task counts: {}".format(self.outstanding_task_count)) - logger.debug("Task {} pushed to a task queue {}".format(task, task_type)) + logger.debug( + "Got task: Outstanding task counts: {}".format( + self.outstanding_task_count + ) + ) + logger.debug(f"Task {task} pushed to a task queue {task_type}") else: logger.debug("[TASK_PULL_THREAD] No incoming tasks") @@ -341,7 +376,10 @@ def pull_tasks(self, kill_event): # Only check if no messages were received. if time.time() > last_interchange_contact + self.heartbeat_threshold: - logger.critical("[TASK_PULL_THREAD] Missing contact with interchange beyond heartbeat_threshold") + logger.critical( + "[TASK_PULL_THREAD] Missing contact with interchange beyond " + "heartbeat_threshold" + ) kill_event.set() logger.critical("Killing all workers") for proc in self.worker_procs.values(): @@ -349,97 +387,147 @@ def pull_tasks(self, kill_event): logger.critical("[TASK_PULL_THREAD] Exiting") break - logger.debug("To-Die Counts: {}".format(self.worker_map.to_die_count)) - logger.debug("Alive worker counts: {}".format(self.worker_map.total_worker_type_counts)) + logger.debug(f"To-Die Counts: {self.worker_map.to_die_count}") + logger.debug( + "Alive worker counts: {}".format( + self.worker_map.total_worker_type_counts + ) + ) - new_worker_map = naive_scheduler(self.task_queues, - self.outstanding_task_count, - self.max_worker_count, - new_worker_map, - self.worker_map.to_die_count, - logger=logger) - logger.debug("[SCHEDULER] New worker map: {}".format(new_worker_map)) + new_worker_map = naive_scheduler( + self.task_queues, + self.outstanding_task_count, + self.max_worker_count, + new_worker_map, + self.worker_map.to_die_count, + logger=logger, + ) + logger.debug(f"[SCHEDULER] New worker map: {new_worker_map}") - # NOTE: Wipes the queue -- previous scheduling loops don't affect what's needed now. - self.next_worker_q, need_more = self.worker_map.get_next_worker_q(new_worker_map) + # NOTE: Wipes the queue -- previous scheduling loops don't affect what's + # needed now. + self.next_worker_q, need_more = self.worker_map.get_next_worker_q( + new_worker_map + ) # Spin up any new workers according to the worker queue. # Returns the total number of containers that have spun up. 
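The scheduling loop above passes a handful of per-task-type structures around; their shapes (as they are indexed in this loop, not a formal API contract) look roughly like this:

    import queue

    task_queues = {"RAW": queue.Queue()}          # task_type -> queue of pending Task messages
    outstanding_task_count = {"RAW": 3}           # task_type -> tasks received but not yet done
    current_worker_map = {"unused": 1, "RAW": 2}  # worker counts per type, from get_worker_counts()

    # naive_scheduler() produces a new target map of the same per-type shape,
    # e.g. {"RAW": 3}; get_next_worker_q() turns it into the queue of worker
    # types to launch, and spin_up_workers()/spin_down_workers() reconcile
    # that target against the processes that already exist.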
- self.worker_procs.update(self.worker_map.spin_up_workers(self.next_worker_q, - mode=self.worker_mode, - debug=self.debug, - container_cmd_options=self.container_cmd_options, - address=self.address, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port)) + self.worker_procs.update( + self.worker_map.spin_up_workers( + self.next_worker_q, + mode=self.worker_mode, + debug=self.debug, + container_cmd_options=self.container_cmd_options, + address=self.address, + uid=self.uid, + logdir=self.logdir, + worker_port=self.worker_port, + ) + ) logger.debug(f"[SPIN UP] Worker processes: {self.worker_procs}") # Count the workers of each type that need to be removed - spin_downs, container_switch_count = self.worker_map.spin_down_workers(new_worker_map, - worker_max_idletime=self.worker_max_idletime, - need_more=need_more, - scheduler_mode=self.scheduler_mode) + spin_downs, container_switch_count = self.worker_map.spin_down_workers( + new_worker_map, + worker_max_idletime=self.worker_max_idletime, + need_more=need_more, + scheduler_mode=self.scheduler_mode, + ) self.container_switch_count += container_switch_count - logger.debug("Container switch count: total {}, cur {}".format(self.container_switch_count, container_switch_count)) + logger.debug( + "Container switch count: total {}, cur {}".format( + self.container_switch_count, container_switch_count + ) + ) for w_type in spin_downs: self.remove_worker_init(w_type) current_worker_map = self.worker_map.get_worker_counts() for task_type in current_worker_map: - if task_type == 'unused': + if task_type == "unused": continue # *** Match tasks to workers *** # else: available_workers = current_worker_map[task_type] - logger.debug("Available workers of type {}: {}".format(task_type, - available_workers)) + logger.debug( + "Available workers of type {}: {}".format( + task_type, available_workers + ) + ) for _i in range(available_workers): - if task_type in self.task_queues and not self.task_queues[task_type].qsize() == 0 \ - and not self.worker_map.worker_queues[task_type].qsize() == 0: - - logger.debug("Task type {} has task queue size {}" - .format(task_type, self.task_queues[task_type].qsize())) - logger.debug("... and available workers: {}" - .format(self.worker_map.worker_queues[task_type].qsize())) + if ( + task_type in self.task_queues + and not self.task_queues[task_type].qsize() == 0 + and not self.worker_map.worker_queues[task_type].qsize() + == 0 + ): + + logger.debug( + "Task type {} has task queue size {}".format( + task_type, self.task_queues[task_type].qsize() + ) + ) + logger.debug( + "... 
and available workers: {}".format( + self.worker_map.worker_queues[task_type].qsize() + ) + ) self.send_task_to_worker(task_type) def poll_funcx_task_socket(self, test=False): try: w_id, m_type, message = self.funcx_task_socket.recv_multipart() - if m_type == b'REGISTER': + if m_type == b"REGISTER": reg_info = pickle.loads(message) - logger.debug("Registration received from worker:{} {}".format(w_id, reg_info)) - self.worker_map.register_worker(w_id, reg_info['worker_type']) + logger.debug(f"Registration received from worker:{w_id} {reg_info}") + self.worker_map.register_worker(w_id, reg_info["worker_type"]) - elif m_type == b'TASK_RET': - logger.debug("Result received from worker: {}".format(w_id)) - logger.debug("[TASK_PULL_THREAD] Got result: {}".format(message)) + elif m_type == b"TASK_RET": + logger.debug(f"Result received from worker: {w_id}") + logger.debug(f"[TASK_PULL_THREAD] Got result: {message}") self.pending_result_queue.put(message) self.worker_map.put_worker(w_id) self.task_done_counter += 1 - task_id = pickle.loads(message)['task_id'] + task_id = pickle.loads(message)["task_id"] task_type = self.task_type_mapping.pop(task_id) self.task_status_deltas.pop(task_id, None) - logger.debug("Task type: {}".format(task_type)) + logger.debug(f"Task type: {task_type}") self.outstanding_task_count[task_type] -= 1 - logger.debug("Got result: Outstanding task counts: {}".format(self.outstanding_task_count)) - - elif m_type == b'WRKR_DIE': - logger.debug("[WORKER_REMOVE] Removing worker {} from worker_map...".format(w_id)) - logger.debug("Ready worker counts: {}".format(self.worker_map.ready_worker_type_counts)) - logger.debug("Total worker counts: {}".format(self.worker_map.total_worker_type_counts)) + logger.debug( + "Got result: Outstanding task counts: {}".format( + self.outstanding_task_count + ) + ) + + elif m_type == b"WRKR_DIE": + logger.debug( + f"[WORKER_REMOVE] Removing worker {w_id} from worker_map..." 
+ ) + logger.debug( + "Ready worker counts: {}".format( + self.worker_map.ready_worker_type_counts + ) + ) + logger.debug( + "Total worker counts: {}".format( + self.worker_map.total_worker_type_counts + ) + ) self.worker_map.remove_worker(w_id) proc = self.worker_procs.pop(w_id.decode()) if not proc.poll(): try: proc.wait(timeout=1) except subprocess.TimeoutExpired: - logger.warning(f"[WORKER_REMOVE] Timeout waiting for worker {w_id} process to terminate") + logger.warning( + "[WORKER_REMOVE] Timeout waiting for worker %s process to " + "terminate", + w_id, + ) logger.debug(f"[WORKER_REMOVE] Removing worker {w_id} process object") logger.debug(f"[WORKER_REMOVE] Worker processes: {self.worker_procs}") @@ -447,15 +535,20 @@ def poll_funcx_task_socket(self, test=False): return pickle.loads(message) except Exception as e: - logger.exception("[TASK_PULL_THREAD] FUNCX : caught {}".format(e)) + logger.exception(f"[TASK_PULL_THREAD] FUNCX : caught {e}") def send_task_to_worker(self, task_type): task = self.task_queues[task_type].get() worker_id = self.worker_map.get_worker(task_type) - logger.debug("Sending task {} to {}".format(task.task_id, worker_id)) + logger.debug(f"Sending task {task.task_id} to {worker_id}") # TODO: Some duplication of work could be avoided here - to_send = [worker_id, pickle.dumps(task.task_id), pickle.dumps(task.container_id), task.pack()] + to_send = [ + worker_id, + pickle.dumps(task.task_id), + pickle.dumps(task.container_id), + task.pack(), + ] self.funcx_task_socket.send_multipart(to_send) self.worker_map.update_worker_idle(task_type) if task.task_id != "KILL": @@ -471,14 +564,16 @@ def _status_report_loop(self, kill_event): self.task_status_deltas, self.container_switch_count, ) - logger.info(f"[STATUS] Sending status report to interchange: {msg.task_statuses}") + logger.info( + f"[STATUS] Sending status report to interchange: {msg.task_statuses}" + ) self.pending_result_queue.put(msg) logger.info("[STATUS] Clearing task deltas") self.task_status_deltas.clear() time.sleep(self.heartbeat_period) def push_results(self, kill_event, max_result_batch_size=1): - """ Listens on the pending_result_queue and sends out results via 0mq + """Listens on the pending_result_queue and sends out results via 0mq Parameters: ----------- @@ -488,8 +583,10 @@ def push_results(self, kill_event, max_result_batch_size=1): logger.debug("[RESULT_PUSH_THREAD] Starting thread") - push_poll_period = max(10, self.poll_period) / 1000 # push_poll_period must be atleast 10 ms - logger.debug("[RESULT_PUSH_THREAD] push poll period: {}".format(push_poll_period)) + push_poll_period = ( + max(10, self.poll_period) / 1000 + ) # push_poll_period must be atleast 10 ms + logger.debug(f"[RESULT_PUSH_THREAD] push poll period: {push_poll_period}") last_beat = time.time() items = [] @@ -497,8 +594,10 @@ def push_results(self, kill_event, max_result_batch_size=1): while not kill_event.is_set(): try: r = self.pending_result_queue.get(block=True, timeout=push_poll_period) - # This avoids the interchange searching and attempting to unpack every message in case it's a - # status report. (Would be better to use Task Messages eventually to make this more uniform) + # This avoids the interchange searching and attempting to unpack every + # message in case it's a status report. 
+ # (It would be better to use Task Messages eventually to make this more + # uniform) # TODO: use task messages, and don't have to prepend if isinstance(r, ManagerStatusReport): items.insert(0, r.pack()) @@ -507,10 +606,14 @@ def push_results(self, kill_event, max_result_batch_size=1): except queue.Empty: pass except Exception as e: - logger.exception("[RESULT_PUSH_THREAD] Got an exception: {}".format(e)) - - # If we have reached poll_period duration or timer has expired, we send results - if len(items) >= self.max_queue_size or time.time() > last_beat + push_poll_period: + logger.exception(f"[RESULT_PUSH_THREAD] Got an exception: {e}") + + # If we have reached poll_period duration or timer has expired, we send + # results + if ( + len(items) >= self.max_queue_size + or time.time() > last_beat + push_poll_period + ): last_beat = time.time() if items: self.result_outgoing.send_multipart(items) @@ -520,18 +623,21 @@ def push_results(self, kill_event, max_result_batch_size=1): def remove_worker_init(self, worker_type): """ - Kill/Remove a worker of a given worker_type. + Kill/Remove a worker of a given worker_type. - Add a kill message to the task_type queue. + Add a kill message to the task_type queue. - Assumption : All workers of the same type are uniform, and therefore don't discriminate when killing. + Assumption : All workers of the same type are uniform, and therefore don't + discriminate when killing. """ - logger.debug("[WORKER_REMOVE] Appending KILL message to worker queue {}".format(worker_type)) + logger.debug( + "[WORKER_REMOVE] Appending KILL message to worker queue {}".format( + worker_type + ) + ) self.worker_map.to_die_count[worker_type] += 1 - task = Task(task_id='KILL', - container_id='RAW', - task_buffer='KILL') + task = Task(task_id="KILL", container_id="RAW", task_buffer="KILL") self.task_queues[worker_type].put(task) def start(self): @@ -542,16 +648,24 @@ def start(self): Forward results """ - if self.worker_type and self.scheduler_mode == 'hard': - logger.debug("[MANAGER] Start an initial worker with worker type {}".format(self.worker_type)) - self.worker_procs.update(self.worker_map.add_worker(worker_id=str(self.worker_map.worker_id_counter), - worker_type=self.worker_type, - container_cmd_options=self.container_cmd_options, - address=self.address, - debug=self.debug, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port)) + if self.worker_type and self.scheduler_mode == "hard": + logger.debug( + "[MANAGER] Start an initial worker with worker type {}".format( + self.worker_type + ) + ) + self.worker_procs.update( + self.worker_map.add_worker( + worker_id=str(self.worker_map.worker_id_counter), + worker_type=self.worker_type, + container_cmd_options=self.container_cmd_options, + address=self.address, + debug=self.debug, + uid=self.uid, + logdir=self.logdir, + worker_port=self.worker_port, + ) + ) logger.debug("Initial workers launched") self._result_pusher_thread.start() @@ -563,42 +677,86 @@ def start(self): def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-d", "--debug", action='store_true', - help="Count of apps to launch") - parser.add_argument("-l", "--logdir", default="process_worker_pool_logs", - help="Process worker pool log directory") - parser.add_argument("-u", "--uid", default=str(uuid.uuid4()).split('-')[-1], - help="Unique identifier string for Manager") - parser.add_argument("-b", "--block_id", default=None, - help="Block identifier string for Manager") - parser.add_argument("-c", "--cores_per_worker", 
default="1.0", - help="Number of cores assigned to each worker process. Default=1.0") - parser.add_argument("-t", "--task_url", required=True, - help="REQUIRED: ZMQ url for receiving tasks") - parser.add_argument("--max_workers", default=float('inf'), - help="Caps the maximum workers that can be launched, default:infinity") - parser.add_argument("--hb_period", default=30, - help="Heartbeat period in seconds. Uses manager default unless set") - parser.add_argument("--hb_threshold", default=120, - help="Heartbeat threshold in seconds. Uses manager default unless set") - parser.add_argument("--poll", default=10, - help="Poll period used in milliseconds") - parser.add_argument("--worker_type", default=None, - help="Fixed worker type of manager") - parser.add_argument("--worker_mode", default="singularity_reuse", - help=("Choose the mode of operation from " - "(no_container, singularity_reuse, singularity_single_use")) - parser.add_argument("--container_cmd_options", default="", - help=("Container cmd options to add to container startup cmd")) - parser.add_argument("--scheduler_mode", default="soft", - help=("Choose the mode of scheduler " - "(hard, soft")) - parser.add_argument("-r", "--result_url", required=True, - help="REQUIRED: ZMQ url for posting results") - parser.add_argument("--log_max_bytes", default=256 * 1024 * 1024, - help="The maximum bytes per logger file in bytes") - parser.add_argument("--log_backup_count", default=1, - help="The number of backup (must be non-zero) per logger file") + parser.add_argument( + "-d", "--debug", action="store_true", help="Count of apps to launch" + ) + parser.add_argument( + "-l", + "--logdir", + default="process_worker_pool_logs", + help="Process worker pool log directory", + ) + parser.add_argument( + "-u", + "--uid", + default=str(uuid.uuid4()).split("-")[-1], + help="Unique identifier string for Manager", + ) + parser.add_argument( + "-b", "--block_id", default=None, help="Block identifier string for Manager" + ) + parser.add_argument( + "-c", + "--cores_per_worker", + default="1.0", + help="Number of cores assigned to each worker process. Default=1.0", + ) + parser.add_argument( + "-t", "--task_url", required=True, help="REQUIRED: ZMQ url for receiving tasks" + ) + parser.add_argument( + "--max_workers", + default=float("inf"), + help="Caps the maximum workers that can be launched, default:infinity", + ) + parser.add_argument( + "--hb_period", + default=30, + help="Heartbeat period in seconds. Uses manager default unless set", + ) + parser.add_argument( + "--hb_threshold", + default=120, + help="Heartbeat threshold in seconds. 
Uses manager default unless set", + ) + parser.add_argument("--poll", default=10, help="Poll period used in milliseconds") + parser.add_argument( + "--worker_type", default=None, help="Fixed worker type of manager" + ) + parser.add_argument( + "--worker_mode", + default="singularity_reuse", + help=( + "Choose the mode of operation from " + "(no_container, singularity_reuse, singularity_single_use" + ), + ) + parser.add_argument( + "--container_cmd_options", + default="", + help=("Container cmd options to add to container startup cmd"), + ) + parser.add_argument( + "--scheduler_mode", + default="soft", + help=("Choose the mode of scheduler " "(hard, soft"), + ) + parser.add_argument( + "-r", + "--result_url", + required=True, + help="REQUIRED: ZMQ url for posting results", + ) + parser.add_argument( + "--log_max_bytes", + default=256 * 1024 * 1024, + help="The maximum bytes per logger file in bytes", + ) + parser.add_argument( + "--log_backup_count", + default=1, + help="The number of backup (must be non-zero) per logger file", + ) args = parser.parse_args() @@ -609,52 +767,61 @@ def cli_run(): try: global logger - # TODO The config options for the rotatingfilehandler need to be implemented and checked so that it is user configurable - logger = set_file_logger(os.path.join(args.logdir, args.uid, 'manager.log'), - name='funcx_manager', - level=logging.DEBUG if args.debug is True else logging.INFO, - max_bytes=float(args.log_max_bytes), # TODO: Test if this still works on forwarder_rearch_1 - backup_count=int(args.log_backup_count)) # TODO: Test if this still works on forwarder_rearch_1 - - logger.info("Python version: {}".format(sys.version)) - logger.info("Debug logging: {}".format(args.debug)) - logger.info("Log dir: {}".format(args.logdir)) - logger.info("Manager ID: {}".format(args.uid)) - logger.info("Block ID: {}".format(args.block_id)) - logger.info("cores_per_worker: {}".format(args.cores_per_worker)) - logger.info("task_url: {}".format(args.task_url)) - logger.info("result_url: {}".format(args.result_url)) - logger.info("hb_period: {}".format(args.hb_period)) - logger.info("hb_threshold: {}".format(args.hb_threshold)) - logger.info("max_workers: {}".format(args.max_workers)) - logger.info("poll_period: {}".format(args.poll)) - logger.info("worker_mode: {}".format(args.worker_mode)) - logger.info("container_cmd_options: {}".format(args.container_cmd_options)) - logger.info("scheduler_mode: {}".format(args.scheduler_mode)) - logger.info("worker_type: {}".format(args.worker_type)) - logger.info("log_max_bytes: {}".format(args.log_max_bytes)) - logger.info("log_backup_count: {}".format(args.log_backup_count)) - - manager = Manager(task_q_url=args.task_url, - result_q_url=args.result_url, - uid=args.uid, - block_id=args.block_id, - cores_per_worker=float(args.cores_per_worker), - max_workers=args.max_workers if args.max_workers == float('inf') else int(args.max_workers), - heartbeat_threshold=int(args.hb_threshold), - heartbeat_period=int(args.hb_period), - logdir=args.logdir, - debug=args.debug, - worker_mode=args.worker_mode, - container_cmd_options=args.container_cmd_options, - scheduler_mode=args.scheduler_mode, - worker_type=args.worker_type, - poll_period=int(args.poll)) + # TODO The config options for the rotatingfilehandler need to be implemented + # and checked so that it is user configurable + logger = set_file_logger( + os.path.join(args.logdir, args.uid, "manager.log"), + name="funcx_manager", + level=logging.DEBUG if args.debug is True else logging.INFO, + max_bytes=float( + 
args.log_max_bytes + ), # TODO: Test if this still works on forwarder_rearch_1 + backup_count=int(args.log_backup_count), + ) # TODO: Test if this still works on forwarder_rearch_1 + + logger.info(f"Python version: {sys.version}") + logger.info(f"Debug logging: {args.debug}") + logger.info(f"Log dir: {args.logdir}") + logger.info(f"Manager ID: {args.uid}") + logger.info(f"Block ID: {args.block_id}") + logger.info(f"cores_per_worker: {args.cores_per_worker}") + logger.info(f"task_url: {args.task_url}") + logger.info(f"result_url: {args.result_url}") + logger.info(f"hb_period: {args.hb_period}") + logger.info(f"hb_threshold: {args.hb_threshold}") + logger.info(f"max_workers: {args.max_workers}") + logger.info(f"poll_period: {args.poll}") + logger.info(f"worker_mode: {args.worker_mode}") + logger.info(f"container_cmd_options: {args.container_cmd_options}") + logger.info(f"scheduler_mode: {args.scheduler_mode}") + logger.info(f"worker_type: {args.worker_type}") + logger.info(f"log_max_bytes: {args.log_max_bytes}") + logger.info(f"log_backup_count: {args.log_backup_count}") + + manager = Manager( + task_q_url=args.task_url, + result_q_url=args.result_url, + uid=args.uid, + block_id=args.block_id, + cores_per_worker=float(args.cores_per_worker), + max_workers=args.max_workers + if args.max_workers == float("inf") + else int(args.max_workers), + heartbeat_threshold=int(args.hb_threshold), + heartbeat_period=int(args.hb_period), + logdir=args.logdir, + debug=args.debug, + worker_mode=args.worker_mode, + container_cmd_options=args.container_cmd_options, + scheduler_mode=args.scheduler_mode, + worker_type=args.worker_type, + poll_period=int(args.poll), + ) manager.start() except Exception as e: logger.critical("process_worker_pool exiting from an exception") - logger.exception("Caught error: {}".format(e)) + logger.exception(f"Caught error: {e}") raise else: logger.info("process_worker_pool main event loop exiting normally") diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py index 75c250fb4..c3d41fc4c 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/funcx_worker.py @@ -1,32 +1,37 @@ #!/usr/bin/env python3 -import logging import argparse -import zmq -import sys -import pickle +import logging import os +import pickle +import sys +import zmq from parsl.app.errors import RemoteExceptionWrapper from funcx import set_file_logger from funcx.serialize import FuncXSerializer -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task +from funcx_endpoint.executors.high_throughput.messages import Message class MaxResultSizeExceeded(Exception): - """ Result produced by the function exceeds the maximum supported result size threshold of 512000B """ + Result produced by the function exceeds the maximum supported result size + threshold of 512000B""" + def __init__(self, result_size, result_size_limit): self.result_size = result_size self.result_size_limit = result_size_limit def __str__(self): - return f"Task result of {self.result_size}B exceeded current limit of {self.result_size_limit}B" + return ( + f"Task result of {self.result_size}B exceeded current " + f"limit of {self.result_size_limit}B" + ) -class FuncXWorker(object): - """ The FuncX worker +class FuncXWorker: + """The FuncX worker Parameters ---------- @@ -56,7 +61,16 @@ class FuncXWorker(object): 
send(result) """ - def __init__(self, worker_id, address, port, logdir, debug=False, worker_type='RAW', result_size_limit=512000): + def __init__( + self, + worker_id, + address, + port, + logdir, + debug=False, + worker_type="RAW", + result_size_limit=512000, + ): self.worker_id = worker_id self.address = address @@ -70,15 +84,17 @@ def __init__(self, worker_id, address, port, logdir, debug=False, worker_type='R self.result_size_limit = result_size_limit global logger - logger = set_file_logger(os.path.join(logdir, f'funcx_worker_{worker_id}.log'), - name="worker_log", - level=logging.DEBUG if debug else logging.INFO) + logger = set_file_logger( + os.path.join(logdir, f"funcx_worker_{worker_id}.log"), + name="worker_log", + level=logging.DEBUG if debug else logging.INFO, + ) - logger.info('Initializing worker {}'.format(worker_id)) - logger.info('Worker is of type: {}'.format(worker_type)) + logger.info(f"Initializing worker {worker_id}") + logger.info(f"Worker is of type: {worker_type}") if debug: - logger.debug('Debug logging enabled') + logger.debug("Debug logging enabled") self.context = zmq.Context() self.poller = zmq.Poller() @@ -87,23 +103,23 @@ def __init__(self, worker_id, address, port, logdir, debug=False, worker_type='R self.task_socket = self.context.socket(zmq.DEALER) self.task_socket.setsockopt(zmq.IDENTITY, self.identity) - logger.info('Trying to connect to : tcp://{}:{}'.format(self.address, self.port)) - self.task_socket.connect('tcp://{}:{}'.format(self.address, self.port)) + logger.info(f"Trying to connect to : tcp://{self.address}:{self.port}") + self.task_socket.connect(f"tcp://{self.address}:{self.port}") self.poller.register(self.task_socket, zmq.POLLIN) def registration_message(self): - return {'worker_id': self.worker_id, - 'worker_type': self.worker_type} + return {"worker_id": self.worker_id, "worker_type": self.worker_type} def start(self): logger.info("Starting worker") result = self.registration_message() - task_type = b'REGISTER' + task_type = b"REGISTER" logger.debug("Sending registration") - self.task_socket.send_multipart([task_type, # Byte encoded - pickle.dumps(result)]) + self.task_socket.send_multipart( + [task_type, pickle.dumps(result)] # Byte encoded + ) while True: @@ -111,17 +127,19 @@ def start(self): p_task_id, p_container_id, msg = self.task_socket.recv_multipart() task_id = pickle.loads(p_task_id) container_id = pickle.loads(p_container_id) - logger.debug("Received task_id:{} with task:{}".format(task_id, msg)) + logger.debug(f"Received task_id:{task_id} with task:{msg}") result = None task_type = None if task_id == "KILL": task = Message.unpack(msg) - if task.task_buffer.decode('utf-8') == "KILL": + if task.task_buffer.decode("utf-8") == "KILL": logger.info("[KILL] -- Worker KILL message received! 
") - task_type = b'WRKR_DIE' + task_type = b"WRKR_DIE" else: - logger.exception("Caught an exception of non-KILL message for KILL task") + logger.exception( + "Caught an exception of non-KILL message for KILL task" + ) continue else: logger.debug("Executing task...") @@ -131,29 +149,36 @@ def start(self): serialized_result = self.serialize(result) if sys.getsizeof(serialized_result) > self.result_size_limit: - raise MaxResultSizeExceeded(sys.getsizeof(serialized_result), - self.result_size_limit) + raise MaxResultSizeExceeded( + sys.getsizeof(serialized_result), self.result_size_limit + ) except Exception as e: logger.exception(f"Caught an exception {e}") - result_package = {'task_id': task_id, - 'container_id': container_id, - 'exception': self.serialize( - RemoteExceptionWrapper(*sys.exc_info()))} + result_package = { + "task_id": task_id, + "container_id": container_id, + "exception": self.serialize( + RemoteExceptionWrapper(*sys.exc_info()) + ), + } else: logger.debug("Execution completed without exception") - result_package = {'task_id': task_id, - 'container_id': container_id, - 'result': serialized_result} + result_package = { + "task_id": task_id, + "container_id": container_id, + "result": serialized_result, + } result = result_package - task_type = b'TASK_RET' + task_type = b"TASK_RET" logger.debug("Sending result") - self.task_socket.send_multipart([task_type, # Byte encoded - pickle.dumps(result)]) + self.task_socket.send_multipart( + [task_type, pickle.dumps(result)] # Byte encoded + ) - if task_type == b'WRKR_DIE': - logger.info("*** WORKER {} ABOUT TO DIE ***".format(self.worker_id)) + if task_type == b"WRKR_DIE": + logger.info(f"*** WORKER {self.worker_id} ABOUT TO DIE ***") # Kill the worker after accepting death in message to manager. sys.exit() # We need to return here to allow for sys.exit mocking in tests @@ -167,32 +192,48 @@ def execute_task(self, message): Returns the result or throws exception. 
""" task = Message.unpack(message) - f, args, kwargs = self.serializer.unpack_and_deserialize(task.task_buffer.decode('utf-8')) + f, args, kwargs = self.serializer.unpack_and_deserialize( + task.task_buffer.decode("utf-8") + ) return f(*args, **kwargs) def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-w", "--worker_id", required=True, - help="ID of worker from process_worker_pool") - parser.add_argument("-t", "--type", required=False, - help="Container type of worker", default="RAW") - parser.add_argument("-a", "--address", required=True, - help="Address for the manager, eg X,Y,") - parser.add_argument("-p", "--port", required=True, - help="Internal port at which the worker connects to the manager") - parser.add_argument("--logdir", required=True, - help="Directory path where worker log files written") - parser.add_argument("-d", "--debug", action='store_true', - help="Directory path where worker log files written") + parser.add_argument( + "-w", "--worker_id", required=True, help="ID of worker from process_worker_pool" + ) + parser.add_argument( + "-t", "--type", required=False, help="Container type of worker", default="RAW" + ) + parser.add_argument( + "-a", "--address", required=True, help="Address for the manager, eg X,Y," + ) + parser.add_argument( + "-p", + "--port", + required=True, + help="Internal port at which the worker connects to the manager", + ) + parser.add_argument( + "--logdir", required=True, help="Directory path where worker log files written" + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + help="Directory path where worker log files written", + ) args = parser.parse_args() - worker = FuncXWorker(args.worker_id, - args.address, - int(args.port), - args.logdir, - worker_type=args.type, - debug=args.debug, ) + worker = FuncXWorker( + args.worker_id, + args.address, + int(args.port), + args.logdir, + worker_type=args.type, + debug=args.debug, + ) worker.start() return diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py index 3ea43e94d..10d43462d 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/global_config.py @@ -1,10 +1,11 @@ import getpass + from parsl.addresses import address_by_hostname global_options = { - 'username': getpass.getuser(), - 'email': 'USER@USERDOMAIN.COM', - 'broker_address': '127.0.0.1', - 'broker_port': 8088, - 'endpoint_address': address_by_hostname(), + "username": getpass.getuser(), + "email": "USER@USERDOMAIN.COM", + "broker_address": "127.0.0.1", + "broker_port": 8088, + "endpoint_address": address_by_hostname(), } diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py index d98bc8727..b29b9fe1e 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange.py @@ -1,33 +1,37 @@ #!/usr/bin/env python import argparse -from typing import Tuple, Dict - -import zmq +import collections +import json +import logging import os -import sys -import platform -import random -import time import pickle -import logging +import platform import queue +import sys import threading -import json -import daemon -import collections - -from logging.handlers import RotatingFileHandler +import time +from typing import Dict, 
Tuple +import daemon +import zmq +from parsl.app.errors import RemoteExceptionWrapper from parsl.executors.errors import ScalingFailed from parsl.version import VERSION as PARSL_VERSION -from parsl.app.errors import RemoteExceptionWrapper -from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task -from funcx_endpoint.executors.high_throughput.messages import EPStatusReport, Heartbeat, TaskStatusCode -from funcx.sdk.client import FuncXClient from funcx import set_file_logger -from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import naive_interchange_task_dispatch +from funcx.sdk.client import FuncXClient from funcx.serialize import FuncXSerializer +from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import ( + naive_interchange_task_dispatch, +) +from funcx_endpoint.executors.high_throughput.messages import ( + COMMAND_TYPES, + EPStatusReport, + Heartbeat, + Message, + MessageType, + TaskStatusCode, +) LOOP_SLOWDOWN = 0.0 # in seconds HEARTBEAT_CODE = (2 ** 32) - 1 @@ -35,21 +39,20 @@ class ShutdownRequest(Exception): - """ Exception raised when any async component receives a ShutdownRequest - """ + """Exception raised when any async component receives a ShutdownRequest""" def __init__(self): self.tstamp = time.time() def __repr__(self): - return "Shutdown request received at {}".format(self.tstamp) + return f"Shutdown request received at {self.tstamp}" def __str__(self): return self.__repr__() class ManagerLost(Exception): - """ Task lost due to worker loss. Worker is considered lost when multiple heartbeats + """Task lost due to worker loss. Worker is considered lost when multiple heartbeats have been missed. """ @@ -58,31 +61,32 @@ def __init__(self, worker_id): self.tstamp = time.time() def __repr__(self): - return "Task failure due to loss of manager {}".format(self.worker_id) + return f"Task failure due to loss of manager {self.worker_id}" def __str__(self): return self.__repr__() class BadRegistration(Exception): - ''' A new Manager tried to join the executor with a BadRegistration message - ''' + """A new Manager tried to join the executor with a BadRegistration message""" + def __init__(self, worker_id, critical=False): self.worker_id = worker_id self.tstamp = time.time() self.handled = "critical" if critical else "suppressed" def __repr__(self): - return "Manager {} attempted to register with a bad registration message. Caused a {} failure".format( - self.worker_id, - self.handled) + return ( + f"Manager {self.worker_id} attempted to register with a bad " + f"registration message. Caused a {self.handled} failure" + ) def __str__(self): return self.__repr__() -class Interchange(object): - """ Interchange is a task orchestrator for distributed systems. +class Interchange: + """Interchange is a task orchestrator for distributed systems. 1. Asynchronously queue large volume of tasks (>100K) 2. 
Allow for workers to join and leave the union @@ -94,42 +98,41 @@ class Interchange(object): TODO: We most likely need a PUB channel to send out global commands, like shutdown """ - def __init__(self, - # - strategy=None, - poll_period=None, - heartbeat_period=None, - heartbeat_threshold=None, - working_dir=None, - provider=None, - max_workers_per_node=None, - mem_per_worker=None, - prefetch_capacity=None, - - scheduler_mode=None, - container_type=None, - container_cmd_options='', - worker_mode=None, - cold_routing_interval=10.0, - - funcx_service_address=None, - scaling_enabled=True, - # - client_address="127.0.0.1", - interchange_address="127.0.0.1", - client_ports: Tuple[int, int, int] = (50055, 50056, 50057), - worker_ports=None, - worker_port_range=(54000, 55000), - cores_per_worker=1.0, - worker_debug=False, - launch_cmd=None, - logdir=".", - logging_level=logging.INFO, - endpoint_id=None, - suppress_failure=False, - log_max_bytes=256 * 1024 * 1024, - log_backup_count=1, - ): + def __init__( + self, + # + strategy=None, + poll_period=None, + heartbeat_period=None, + heartbeat_threshold=None, + working_dir=None, + provider=None, + max_workers_per_node=None, + mem_per_worker=None, + prefetch_capacity=None, + scheduler_mode=None, + container_type=None, + container_cmd_options="", + worker_mode=None, + cold_routing_interval=10.0, + funcx_service_address=None, + scaling_enabled=True, + # + client_address="127.0.0.1", + interchange_address="127.0.0.1", + client_ports: Tuple[int, int, int] = (50055, 50056, 50057), + worker_ports=None, + worker_port_range=(54000, 55000), + cores_per_worker=1.0, + worker_debug=False, + launch_cmd=None, + logdir=".", + logging_level=logging.INFO, + endpoint_id=None, + suppress_failure=False, + log_max_bytes=256 * 1024 * 1024, + log_backup_count=1, + ): """ Parameters ---------- @@ -137,10 +140,12 @@ def __init__(self, Funcx config object that describes how compute should be provisioned client_address : str - The ip address at which the parsl client can be reached. Default: "127.0.0.1" + The ip address at which the parsl client can be reached. + Default: "127.0.0.1" interchange_address : str - The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" + The ip address at which the workers will be able to reach the Interchange. + Default: "127.0.0.1" client_ports : Tuple[int, int, int] The ports at which the client can be reached @@ -149,11 +154,13 @@ def __init__(self, TODO : update worker_ports : tuple(int, int) - The specific two ports at which workers will connect to the Interchange. Default: None + The specific two ports at which workers will connect to the Interchange. + Default: None worker_port_range : tuple(int, int) - The interchange picks ports at random from the range which will be used by workers. - This is overridden when the worker_ports option is set. Defauls: (54000, 55000) + The interchange picks ports at random from the range which will be used by + workers. This is overridden when the worker_ports option is set. + Default: (54000, 55000) cores_per_worker : float cores to be assigned to each worker. Oversubscription is possible @@ -164,7 +171,8 @@ def __init__(self, For example, singularity exec {container_cmd_options} cold_routing_interval: float - The time interval between warm and cold function routing in SOFT scheduler_mode. + The time interval between warm and cold function routing in SOFT + scheduler_mode. It is ONLY used when using soft scheduler_mode. 
We need this to avoid container workers being idle for too long. But we dont't want this cold routing to occur too often, @@ -184,11 +192,12 @@ def __init__(self, Identity string that identifies the endpoint to the broker suppress_failure : Bool - When set to True, the interchange will attempt to suppress failures. Default: False + When set to True, the interchange will attempt to suppress failures. + Default: False funcx_service_address: str - Override funcx_service_address used by the FuncXClient. If no address is specified, - the FuncXClient's default funcx_service_address is used. + Override funcx_service_address used by the FuncXClient. If no address is + specified, the FuncXClient's default funcx_service_address is used. Default: None """ @@ -196,17 +205,21 @@ def __init__(self, os.makedirs(self.logdir, exist_ok=True) global logger - logger = set_file_logger(os.path.join(self.logdir, 'interchange.log'), - name="interchange", - level=logging_level, - max_bytes=log_max_bytes, - backup_count=log_backup_count) - - logger.info("logger location {}, logger filesize: {}, logger backup count: {}".format(logger.handlers, - log_max_bytes, - log_backup_count)) + logger = set_file_logger( + os.path.join(self.logdir, "interchange.log"), + name="interchange", + level=logging_level, + max_bytes=log_max_bytes, + backup_count=log_backup_count, + ) + + logger.info( + "logger location {}, logger filesize: {}, logger backup count: {}".format( + logger.handlers, log_max_bytes, log_backup_count + ) + ) - logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id)) + logger.info(f"Initializing Interchange process with Endpoint ID: {endpoint_id}") # self.max_workers_per_node = max_workers_per_node @@ -240,25 +253,28 @@ def __init__(self, self.last_heartbeat = time.time() self.serializer = FuncXSerializer() - logger.info("Attempting connection to forwarder at {} on ports: {},{},{}".format( - client_address, client_ports[0], client_ports[1], client_ports[2])) + logger.info( + "Attempting connection to forwarder at {} on ports: {},{},{}".format( + client_address, client_ports[0], client_ports[1], client_ports[2] + ) + ) self.context = zmq.Context() self.task_incoming = self.context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.RCVTIMEO = 10 # in milliseconds - logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0])) - self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) + logger.info(f"Task incoming on tcp://{client_address}:{client_ports[0]}") + self.task_incoming.connect(f"tcp://{client_address}:{client_ports[0]}") self.results_outgoing = self.context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) - logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1])) - self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) + logger.info(f"Results outgoing on tcp://{client_address}:{client_ports[1]}") + self.results_outgoing.connect(f"tcp://{client_address}:{client_ports[1]}") self.command_channel = self.context.socket(zmq.DEALER) self.command_channel.RCVTIMEO = 1000 # in milliseconds # self.command_channel.set_hwm(0) - logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2])) - self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) + logger.info(f"Command channel on tcp://{client_address}:{client_ports[2]}") + self.command_channel.connect(f"tcp://{client_address}:{client_ports[2]}") 
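For context, the socket setup in the hunk above boils down to three DEALER connections, one per entry in the client_ports triple (task intake, result publishing, command channel). A minimal standalone sketch of that wiring follows; the address and ports are just the defaults from the signature used as placeholders, and the snippet is an illustration, not code from the patch.

# Illustrative sketch: mirrors the three forwarder connections shown above.
# Address and ports are the defaults from Interchange.__init__, used as placeholders.
import zmq

client_address = "127.0.0.1"
client_ports = (50055, 50056, 50057)  # (tasks in, results out, commands)

context = zmq.Context()

task_incoming = context.socket(zmq.DEALER)
task_incoming.set_hwm(0)              # HWM of 0 means unbounded queues, nothing dropped
task_incoming.RCVTIMEO = 10           # recv() raises zmq.Again after 10 ms of silence
task_incoming.connect(f"tcp://{client_address}:{client_ports[0]}")

results_outgoing = context.socket(zmq.DEALER)
results_outgoing.set_hwm(0)
results_outgoing.connect(f"tcp://{client_address}:{client_ports[1]}")

command_channel = context.socket(zmq.DEALER)
command_channel.RCVTIMEO = 1000       # command replies are waited on for at most 1 s
command_channel.connect(f"tcp://{client_address}:{client_ports[2]}")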
logger.info("Connected to forwarder") self.pending_task_queue = {} @@ -269,7 +285,7 @@ def __init__(self, else: self.fxs = FuncXClient() - logger.info("Interchange address is {}".format(self.interchange_address)) + logger.info(f"Interchange address is {self.interchange_address}") self.worker_ports = worker_ports self.worker_port_range = worker_port_range @@ -283,19 +299,28 @@ def __init__(self, self.worker_task_port = self.worker_ports[0] self.worker_result_port = self.worker_ports[1] - self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port)) - self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port)) + self.task_outgoing.bind(f"tcp://*:{self.worker_task_port}") + self.results_incoming.bind(f"tcp://*:{self.worker_result_port}") else: - self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*', - min_port=worker_port_range[0], - max_port=worker_port_range[1], max_tries=100) - self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*', - min_port=worker_port_range[0], - max_port=worker_port_range[1], max_tries=100) + self.worker_task_port = self.task_outgoing.bind_to_random_port( + "tcp://*", + min_port=worker_port_range[0], + max_port=worker_port_range[1], + max_tries=100, + ) + self.worker_result_port = self.results_incoming.bind_to_random_port( + "tcp://*", + min_port=worker_port_range[0], + max_port=worker_port_range[1], + max_tries=100, + ) - logger.info("Bound to ports {},{} for incoming worker connections".format( - self.worker_task_port, self.worker_result_port)) + logger.info( + "Bound to ports {},{} for incoming worker connections".format( + self.worker_task_port, self.worker_result_port + ) + ) self._ready_manager_queue = {} @@ -304,31 +329,35 @@ def __init__(self, self.launch_cmd = launch_cmd self.last_core_hr_counter = 0 if not launch_cmd: - self.launch_cmd = ("funcx-manager {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--block_id={{block_id}} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--worker_mode={worker_mode} " - "--container_cmd_options='{container_cmd_options}' " - "--scheduler_mode={scheduler_mode} " - "--log_max_bytes={log_max_bytes} " - "--log_backup_count={log_backup_count} " - "--worker_type={{worker_type}} ") - - self.current_platform = {'parsl_v': PARSL_VERSION, - 'python_v': "{}.{}.{}".format(sys.version_info.major, - sys.version_info.minor, - sys.version_info.micro), - 'os': platform.system(), - 'hname': platform.node(), - 'dir': os.getcwd()} - - logger.info("Platform info: {}".format(self.current_platform)) + self.launch_cmd = ( + "funcx-manager {debug} {max_workers} " + "-c {cores_per_worker} " + "--poll {poll_period} " + "--task_url={task_url} " + "--result_url={result_url} " + "--logdir={logdir} " + "--block_id={{block_id}} " + "--hb_period={heartbeat_period} " + "--hb_threshold={heartbeat_threshold} " + "--worker_mode={worker_mode} " + "--container_cmd_options='{container_cmd_options}' " + "--scheduler_mode={scheduler_mode} " + "--log_max_bytes={log_max_bytes} " + "--log_backup_count={log_backup_count} " + "--worker_type={{worker_type}} " + ) + + self.current_platform = { + "parsl_v": PARSL_VERSION, + "python_v": "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "os": platform.system(), + "hname": platform.node(), + "dir": os.getcwd(), + } + + logger.info(f"Platform info: {self.current_platform}") 
self._block_counter = 0 try: self.load_config() @@ -341,54 +370,64 @@ def __init__(self, self.container_switch_count = {} def load_config(self): - """ Load the config - """ + """Load the config""" logger.info("Loading endpoint local config") working_dir = self.working_dir if self.working_dir is None: working_dir = os.path.join(self.logdir, "worker_logs") - logger.info("Setting working_dir: {}".format(working_dir)) + logger.info(f"Setting working_dir: {working_dir}") self.provider.script_dir = working_dir - if hasattr(self.provider, 'channel'): - self.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts') - self.provider.channel.makedirs(self.provider.channel.script_dir, exist_ok=True) + if hasattr(self.provider, "channel"): + self.provider.channel.script_dir = os.path.join( + working_dir, "submit_scripts" + ) + self.provider.channel.makedirs( + self.provider.channel.script_dir, exist_ok=True + ) os.makedirs(self.provider.script_dir, exist_ok=True) debug_opts = "--debug" if self.worker_debug else "" - max_workers = "" if self.max_workers_per_node == float('inf') \ - else "--max_workers={}".format(self.max_workers_per_node) + max_workers = ( + "" + if self.max_workers_per_node == float("inf") + else f"--max_workers={self.max_workers_per_node}" + ) worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}" - worker_result_url = f"tcp://{self.interchange_address}:{self.worker_result_port}" - - l_cmd = self.launch_cmd.format(debug=debug_opts, - max_workers=max_workers, - cores_per_worker=self.cores_per_worker, - # mem_per_worker=self.mem_per_worker, - prefetch_capacity=self.prefetch_capacity, - task_url=worker_task_url, - result_url=worker_result_url, - nodes_per_block=self.provider.nodes_per_block, - heartbeat_period=self.heartbeat_period, - heartbeat_threshold=self.heartbeat_threshold, - poll_period=self.poll_period, - worker_mode=self.worker_mode, - container_cmd_options=self.container_cmd_options, - scheduler_mode=self.scheduler_mode, - logdir=working_dir, - log_max_bytes=self.log_max_bytes, - log_backup_count=self.log_backup_count) + worker_result_url = ( + f"tcp://{self.interchange_address}:{self.worker_result_port}" + ) + + l_cmd = self.launch_cmd.format( + debug=debug_opts, + max_workers=max_workers, + cores_per_worker=self.cores_per_worker, + # mem_per_worker=self.mem_per_worker, + prefetch_capacity=self.prefetch_capacity, + task_url=worker_task_url, + result_url=worker_result_url, + nodes_per_block=self.provider.nodes_per_block, + heartbeat_period=self.heartbeat_period, + heartbeat_threshold=self.heartbeat_threshold, + poll_period=self.poll_period, + worker_mode=self.worker_mode, + container_cmd_options=self.container_cmd_options, + scheduler_mode=self.scheduler_mode, + logdir=working_dir, + log_max_bytes=self.log_max_bytes, + log_backup_count=self.log_backup_count, + ) self.launch_cmd = l_cmd - logger.info("Launch command: {}".format(self.launch_cmd)) + logger.info(f"Launch command: {self.launch_cmd}") if self.scaling_enabled: logger.info("Scaling ...") self.scale_out(self.provider.init_blocks) def get_tasks(self, count): - """ Obtains a batch of tasks from the internal pending_task_queue + """Obtains a batch of tasks from the internal pending_task_queue Parameters ---------- @@ -432,79 +471,104 @@ def migrate_tasks_to_internal(self, kill_event, status_request): self.last_heartbeat = time.time() except zmq.Again: # We just timed out while attempting to receive - logger.debug("[TASK_PULL_THREAD] {} tasks in internal 
queue".format(self.total_pending_task_count)) + logger.debug( + "[TASK_PULL_THREAD] {} tasks in internal queue".format( + self.total_pending_task_count + ) + ) continue try: msg = Message.unpack(raw_msg) - logger.debug("[TASK_PULL_THREAD] received Message/Heartbeat? on task queue") + logger.debug( + "[TASK_PULL_THREAD] received Message/Heartbeat? on task queue" + ) except Exception: logger.exception("Failed to unpack message") pass - if msg == 'STOP': + if msg == "STOP": # TODO: Yadu. This should be replaced by a proper MessageType kill_event.set() break elif isinstance(msg, Heartbeat): logger.debug("Got heartbeat") else: - logger.info("[TASK_PULL_THREAD] Received task:{}".format(msg)) + logger.info(f"[TASK_PULL_THREAD] Received task:{msg}") local_container = self.get_container(msg.container_id) msg.set_local_container(local_container) if local_container not in self.pending_task_queue: - self.pending_task_queue[local_container] = queue.Queue(maxsize=10 ** 6) + self.pending_task_queue[local_container] = queue.Queue( + maxsize=10 ** 6 + ) # We pass the raw message along - self.pending_task_queue[local_container].put({'task_id': msg.task_id, - 'container_id': msg.container_id, - 'local_container': local_container, - 'raw_buffer': raw_msg}) + self.pending_task_queue[local_container].put( + { + "task_id": msg.task_id, + "container_id": msg.container_id, + "local_container": local_container, + "raw_buffer": raw_msg, + } + ) self.total_pending_task_count += 1 self.task_status_deltas[msg.task_id] = TaskStatusCode.WAITING_FOR_NODES - logger.debug(f"[TASK_PULL_THREAD] task {msg.task_id} is now WAITING_FOR_NODES") - logger.debug("[TASK_PULL_THREAD] pending task count: {}".format(self.total_pending_task_count)) + logger.debug( + f"[TASK_PULL_THREAD] task {msg.task_id} is now WAITING_FOR_NODES" + ) + logger.debug( + "[TASK_PULL_THREAD] pending task count: {}".format( + self.total_pending_task_count + ) + ) task_counter += 1 - logger.debug("[TASK_PULL_THREAD] Fetched task:{}".format(task_counter)) + logger.debug(f"[TASK_PULL_THREAD] Fetched task:{task_counter}") def get_container(self, container_uuid): - """ Get the container image location if it is not known to the interchange""" + """Get the container image location if it is not known to the interchange""" if container_uuid not in self.containers: - if container_uuid == 'RAW' or not container_uuid: - self.containers[container_uuid] = 'RAW' + if container_uuid == "RAW" or not container_uuid: + self.containers[container_uuid] = "RAW" else: try: - container = self.fxs.get_container(container_uuid, self.container_type) + container = self.fxs.get_container( + container_uuid, self.container_type + ) except Exception: - logger.exception("[FETCH_CONTAINER] Unable to resolve container location") - self.containers[container_uuid] = 'RAW' + logger.exception( + "[FETCH_CONTAINER] Unable to resolve container location" + ) + self.containers[container_uuid] = "RAW" else: - logger.info("[FETCH_CONTAINER] Got container info: {}".format(container)) - self.containers[container_uuid] = container.get('location', 'RAW') + logger.info(f"[FETCH_CONTAINER] Got container info: {container}") + self.containers[container_uuid] = container.get("location", "RAW") return self.containers[container_uuid] def get_total_tasks_outstanding(self): - """ Get the outstanding tasks in total - """ + """Get the outstanding tasks in total""" outstanding = {} for task_type in self.pending_task_queue: - outstanding[task_type] = outstanding.get(task_type, 0) + 
self.pending_task_queue[task_type].qsize() + outstanding[task_type] = ( + outstanding.get(task_type, 0) + + self.pending_task_queue[task_type].qsize() + ) for manager in self._ready_manager_queue: - for task_type in self._ready_manager_queue[manager]['tasks']: - outstanding[task_type] = outstanding.get(task_type, 0) + len(self._ready_manager_queue[manager]['tasks'][task_type]) + for task_type in self._ready_manager_queue[manager]["tasks"]: + outstanding[task_type] = outstanding.get(task_type, 0) + len( + self._ready_manager_queue[manager]["tasks"][task_type] + ) return outstanding def get_total_live_workers(self): - """ Get the total active workers - """ + """Get the total active workers""" active = 0 for manager in self._ready_manager_queue: - if self._ready_manager_queue[manager]['active']: - active += self._ready_manager_queue[manager]['max_worker_count'] + if self._ready_manager_queue[manager]["active"]: + active += self._ready_manager_queue[manager]["max_worker_count"] return active def get_outstanding_breakdown(self): - """ Get outstanding breakdown per manager and in the interchange queues + """Get outstanding breakdown per manager and in the interchange queues Returns ------- @@ -514,16 +578,23 @@ def get_outstanding_breakdown(self): pending_on_interchange = self.total_pending_task_count # Reporting pending on interchange is a deviation from Parsl - reply = [('interchange', pending_on_interchange, True)] + reply = [("interchange", pending_on_interchange, True)] for manager in self._ready_manager_queue: - resp = (manager.decode('utf-8'), - sum([len(tids) for tids in self._ready_manager_queue[manager]['tasks'].values()]), - self._ready_manager_queue[manager]['active']) + resp = ( + manager.decode("utf-8"), + sum( + [ + len(tids) + for tids in self._ready_manager_queue[manager]["tasks"].values() + ] + ), + self._ready_manager_queue[manager]["active"], + ) reply.append(resp) return reply def _hold_block(self, block_id): - """ Sends hold command to all managers which are in a specific block + """Sends hold command to all managers which are in a specific block Parameters ---------- @@ -531,13 +602,15 @@ def _hold_block(self, block_id): Block identifier of the block to be put on hold """ for manager in self._ready_manager_queue: - if self._ready_manager_queue[manager]['active'] and \ - self._ready_manager_queue[manager]['block_id'] == block_id: - logger.debug("[HOLD_BLOCK]: Sending hold to manager: {}".format(manager)) + if ( + self._ready_manager_queue[manager]["active"] + and self._ready_manager_queue[manager]["block_id"] == block_id + ): + logger.debug(f"[HOLD_BLOCK]: Sending hold to manager: {manager}") self.hold_manager(manager) def hold_manager(self, manager): - """ Put manager on hold + """Put manager on hold Parameters ---------- @@ -545,7 +618,7 @@ def hold_manager(self, manager): Manager id to be put on hold while being killed """ if manager in self._ready_manager_queue: - self._ready_manager_queue[manager]['active'] = False + self._ready_manager_queue[manager]["active"] = False def _status_report_loop(self, kill_event, status_report_queue: queue.Queue): logger.debug("[STATUS] Status reporting loop starting") @@ -553,19 +626,20 @@ def _status_report_loop(self, kill_event, status_report_queue: queue.Queue): while not kill_event.is_set(): logger.debug(f"Endpoint id : {self.endpoint_id}, {type(self.endpoint_id)}") msg = EPStatusReport( - self.endpoint_id, - self.get_status_report(), - self.task_status_deltas + self.endpoint_id, self.get_status_report(), self.task_status_deltas + ) 
+ logger.debug( + "[STATUS] Sending status report to executor, and clearing task deltas." ) - logger.debug("[STATUS] Sending status report to executor, and clearing task deltas.") status_report_queue.put(msg.pack()) self.task_status_deltas.clear() time.sleep(self.heartbeat_period) def _command_server(self, kill_event): - """ Command server to run async command to the interchange + """Command server to run async command to the interchange - We want to be able to receive the following not yet implemented/updated commands: + We want to be able to receive the following not yet implemented/updated + commands: - OutstandingCount - ListManagers (get outstanding broken down by manager) - HoldWorker @@ -585,10 +659,12 @@ def _command_server(self, kill_event): if command.type is MessageType.HEARTBEAT_REQ: logger.info("[COMMAND] Received synchonous HEARTBEAT_REQ from hub") - logger.info(f"[COMMAND] Replying with Heartbeat({self.endpoint_id})") + logger.info( + f"[COMMAND] Replying with Heartbeat({self.endpoint_id})" + ) reply = Heartbeat(self.endpoint_id) - logger.debug("[COMMAND] Reply: {}".format(reply)) + logger.debug(f"[COMMAND] Reply: {reply}") self.command_channel.send(reply.pack()) except zmq.Again: @@ -603,7 +679,7 @@ def stop(self): self._command_thread.join() def start(self, poll_period=None): - """ Start the Interchange + """Start the Interchange Parameters: ---------- @@ -620,17 +696,25 @@ def start(self, poll_period=None): self._kill_event = threading.Event() self._status_request = threading.Event() - self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal, - args=(self._kill_event, self._status_request, )) + self._task_puller_thread = threading.Thread( + target=self.migrate_tasks_to_internal, + args=( + self._kill_event, + self._status_request, + ), + ) self._task_puller_thread.start() - self._command_thread = threading.Thread(target=self._command_server, - args=(self._kill_event, )) + self._command_thread = threading.Thread( + target=self._command_server, args=(self._kill_event,) + ) self._command_thread.start() status_report_queue = queue.Queue() - self._status_report_thread = threading.Thread(target=self._status_report_loop, - args=(self._kill_event, status_report_queue)) + self._status_report_thread = threading.Thread( + target=self._status_report_loop, + args=(self._kill_event, status_report_queue), + ) self._status_report_thread.start() try: @@ -652,16 +736,21 @@ def start(self, poll_period=None): interesting_managers = set() # This value records when the last cold routing in soft mode happens - # When the cold routing in soft mode happens, it may cause worker containers to switch - # Cold routing is to reduce the number idle workers of specific task types on the managers - # when there are not enough tasks of those types in the task queues on interchange + # When the cold routing in soft mode happens, it may cause worker containers to + # switch + # Cold routing is to reduce the number idle workers of specific task types on + # the managers when there are not enough tasks of those types in the task queues + # on interchange last_cold_routing_time = time.time() while not self._kill_event.is_set(): self.socks = dict(poller.poll(timeout=poll_period)) # Listen for requests for work - if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN: + if ( + self.task_outgoing in self.socks + and self.socks[self.task_outgoing] == zmq.POLLIN + ): logger.debug("[MAIN] starting task_outgoing section") message = 
self.task_outgoing.recv_multipart() manager = message[0] @@ -670,42 +759,65 @@ def start(self, poll_period=None): reg_flag = False try: - msg = json.loads(message[1].decode('utf-8')) + msg = json.loads(message[1].decode("utf-8")) reg_flag = True except Exception: - logger.warning("[MAIN] Got a non-json registration message from manager:{}".format( - manager)) - logger.debug("[MAIN] Message :\n{}\n".format(message)) + logger.warning( + "[MAIN] Got a non-json registration message from " + "manager:%s", + manager, + ) + logger.debug(f"[MAIN] Message :\n{message}\n") # By default we set up to ignore bad nodes/registration messages. - self._ready_manager_queue[manager] = {'last': time.time(), - 'reg_time': time.time(), - 'free_capacity': {'total_workers': 0}, - 'max_worker_count': 0, - 'active': True, - 'tasks': collections.defaultdict(set), - 'total_tasks': 0} + self._ready_manager_queue[manager] = { + "last": time.time(), + "reg_time": time.time(), + "free_capacity": {"total_workers": 0}, + "max_worker_count": 0, + "active": True, + "tasks": collections.defaultdict(set), + "total_tasks": 0, + } if reg_flag is True: interesting_managers.add(manager) - logger.info("[MAIN] Adding manager: {} to ready queue".format(manager)) + logger.info(f"[MAIN] Adding manager: {manager} to ready queue") self._ready_manager_queue[manager].update(msg) - logger.info("[MAIN] Registration info for manager {}: {}".format(manager, msg)) - - if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or - msg['parsl_v'] != self.current_platform['parsl_v']): - logger.warn("[MAIN] Manager {} has incompatible version info with the interchange".format(manager)) + logger.info( + "[MAIN] Registration info for manager {}: {}".format( + manager, msg + ) + ) + + if ( + msg["python_v"].rsplit(".", 1)[0] + != self.current_platform["python_v"].rsplit(".", 1)[0] + or msg["parsl_v"] != self.current_platform["parsl_v"] + ): + logger.warn( + "[MAIN] Manager %s has incompatible version info with " + "the interchange", + manager, + ) if self.suppress_failure is False: logger.debug("Setting kill event") self._kill_event.set() e = ManagerLost(manager) - result_package = {'task_id': -1, - 'exception': self.serializer.serialize(e)} + result_package = { + "task_id": -1, + "exception": self.serializer.serialize(e), + } pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pickle.dumps([pkl_package])) - logger.warning("[MAIN] Sent failure reports, unregistering manager") + logger.warning( + "[MAIN] Sent failure reports, unregistering manager" + ) else: - logger.debug("[MAIN] Suppressing shutdown due to version incompatibility") + logger.debug( + "[MAIN] Suppressing shutdown due to version " + "incompatibility" + ) else: # Registration has failed. 
@@ -713,101 +825,161 @@ def start(self, poll_period=None): logger.debug("Setting kill event for bad manager") self._kill_event.set() e = BadRegistration(manager, critical=True) - result_package = {'task_id': -1, - 'exception': self.serializer.serialize(e)} + result_package = { + "task_id": -1, + "exception": self.serializer.serialize(e), + } pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pickle.dumps([pkl_package])) else: - logger.debug("[MAIN] Suppressing bad registration from manager:{}".format( - manager)) + logger.debug( + "[MAIN] Suppressing bad registration from manager: %s", + manager, + ) else: - self._ready_manager_queue[manager]['last'] = time.time() - if message[1] == b'HEARTBEAT': - logger.debug("[MAIN] Manager {} sends heartbeat".format(manager)) - self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE]) + self._ready_manager_queue[manager]["last"] = time.time() + if message[1] == b"HEARTBEAT": + logger.debug(f"[MAIN] Manager {manager} sends heartbeat") + self.task_outgoing.send_multipart( + [manager, b"", PKL_HEARTBEAT_CODE] + ) else: manager_adv = pickle.loads(message[1]) - logger.debug("[MAIN] Manager {} requested {}".format(manager, manager_adv)) - self._ready_manager_queue[manager]['free_capacity'].update(manager_adv) - self._ready_manager_queue[manager]['free_capacity']['total_workers'] = sum(manager_adv['free'].values()) + logger.debug( + "[MAIN] Manager {} requested {}".format( + manager, manager_adv + ) + ) + self._ready_manager_queue[manager]["free_capacity"].update( + manager_adv + ) + self._ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ] = sum(manager_adv["free"].values()) interesting_managers.add(manager) - # If we had received any requests, check if there are tasks that could be passed + # If we had received any requests, check if there are tasks that could be + # passed - logger.debug("[MAIN] Managers count (total/interesting): {}/{}".format( - len(self._ready_manager_queue), - len(interesting_managers))) + logger.debug( + "[MAIN] Managers count (total/interesting): {}/{}".format( + len(self._ready_manager_queue), len(interesting_managers) + ) + ) if time.time() - last_cold_routing_time > self.cold_routing_interval: - task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers, - self.pending_task_queue, - self._ready_manager_queue, - scheduler_mode=self.scheduler_mode, - cold_routing=True) + task_dispatch, dispatched_task = naive_interchange_task_dispatch( + interesting_managers, + self.pending_task_queue, + self._ready_manager_queue, + scheduler_mode=self.scheduler_mode, + cold_routing=True, + ) last_cold_routing_time = time.time() else: - task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers, - self.pending_task_queue, - self._ready_manager_queue, - scheduler_mode=self.scheduler_mode, - cold_routing=False) + task_dispatch, dispatched_task = naive_interchange_task_dispatch( + interesting_managers, + self.pending_task_queue, + self._ready_manager_queue, + scheduler_mode=self.scheduler_mode, + cold_routing=False, + ) self.total_pending_task_count -= dispatched_task for manager in task_dispatch: tasks = task_dispatch[manager] if tasks: - logger.info("[MAIN] Sending task message {} to manager {}".format(tasks, manager)) + logger.info( + "[MAIN] Sending task message {} to manager {}".format( + tasks, manager + ) + ) serializd_raw_tasks_buffer = pickle.dumps(tasks) - # self.task_outgoing.send_multipart([manager, b'', pickle.dumps(tasks)]) - 
self.task_outgoing.send_multipart([manager, b'', serializd_raw_tasks_buffer]) + self.task_outgoing.send_multipart( + [manager, b"", serializd_raw_tasks_buffer] + ) for task in tasks: task_id = task["task_id"] logger.debug(f"[MAIN] Task {task_id} is now WAITING_FOR_LAUNCH") - self.task_status_deltas[task_id] = TaskStatusCode.WAITING_FOR_LAUNCH + self.task_status_deltas[ + task_id + ] = TaskStatusCode.WAITING_FOR_LAUNCH # Receive any results and forward to client - if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN: + if ( + self.results_incoming in self.socks + and self.socks[self.results_incoming] == zmq.POLLIN + ): logger.debug("[MAIN] entering results_incoming section") manager, *b_messages = self.results_incoming.recv_multipart() if manager not in self._ready_manager_queue: - logger.warning("[MAIN] Received a result from a un-registered manager: {}".format(manager)) + logger.warning( + "[MAIN] Received a result from a un-registered manager: %s", + manager, + ) else: - # We expect the batch of messages to be (optionally) a task status update message - # followed by 0 or more task results + # We expect the batch of messages to be (optionally) a task status + # update message followed by 0 or more task results try: logger.debug("[MAIN] Trying to unpack ") manager_report = Message.unpack(b_messages[0]) if manager_report.task_statuses: - logger.info(f"[MAIN] Got manager status report: {manager_report.task_statuses}") + logger.info( + "[MAIN] Got manager status report: %s", + manager_report.task_statuses, + ) self.task_status_deltas.update(manager_report.task_statuses) - self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE]) + self.task_outgoing.send_multipart( + [manager, b"", PKL_HEARTBEAT_CODE] + ) b_messages = b_messages[1:] - self._ready_manager_queue[manager]['last'] = time.time() - self.container_switch_count[manager] = manager_report.container_switch_count - logger.info(f"[MAIN] Got container switch count: {self.container_switch_count}") + self._ready_manager_queue[manager]["last"] = time.time() + self.container_switch_count[ + manager + ] = manager_report.container_switch_count + logger.info( + "[MAIN] Got container switch count: %s", + self.container_switch_count, + ) except Exception: pass if len(b_messages): - logger.info("[MAIN] Got {} result items in batch".format(len(b_messages))) + logger.info( + "[MAIN] Got {} result items in batch".format( + len(b_messages) + ) + ) for b_message in b_messages: r = pickle.loads(b_message) - logger.debug("[MAIN] Received result for task {} from {}".format(r, manager)) - task_type = self.containers[r['container_id']] - if r['task_id'] in self.task_status_deltas: - del self.task_status_deltas[r['task_id']] - self._ready_manager_queue[manager]['tasks'][task_type].remove(r['task_id']) - self._ready_manager_queue[manager]['total_tasks'] -= len(b_messages) + logger.debug( + "[MAIN] Received result for task {} from {}".format( + r, manager + ) + ) + task_type = self.containers[r["container_id"]] + if r["task_id"] in self.task_status_deltas: + del self.task_status_deltas[r["task_id"]] + self._ready_manager_queue[manager]["tasks"][task_type].remove( + r["task_id"] + ) + self._ready_manager_queue[manager]["total_tasks"] -= len(b_messages) # TODO: handle this with a Task message or something? 
- # previously used this; switched to mono-message, self.results_outgoing.send_multipart(b_messages) + # previously used this; switched to mono-message, + # self.results_outgoing.send_multipart(b_messages) self.results_outgoing.send(pickle.dumps(b_messages)) - logger.debug("[MAIN] Current tasks: {}".format(self._ready_manager_queue[manager]['tasks'])) + logger.debug( + "[MAIN] Current tasks: {}".format( + self._ready_manager_queue[manager]["tasks"] + ) + ) logger.debug("[MAIN] leaving results_incoming section") - # Send status reports from this main thread to avoid thread-safety on zmq sockets + # Send status reports from this main thread to avoid thread-safety on zmq + # sockets try: packed_status_report = status_report_queue.get(block=False) logger.debug(f"[MAIN] forwarding status report: {packed_status_report}") @@ -816,23 +988,42 @@ def start(self, poll_period=None): pass # logger.debug("[MAIN] entering bad_managers section") - bad_managers = [manager for manager in self._ready_manager_queue if - time.time() - self._ready_manager_queue[manager]['last'] > self.heartbeat_threshold] + bad_managers = [ + manager + for manager in self._ready_manager_queue + if time.time() - self._ready_manager_queue[manager]["last"] + > self.heartbeat_threshold + ] bad_manager_msgs = [] for manager in bad_managers: - logger.debug("[MAIN] Last: {} Current: {}".format(self._ready_manager_queue[manager]['last'], time.time())) - logger.warning("[MAIN] Too many heartbeats missed for manager {}".format(manager)) + logger.debug( + "[MAIN] Last: {} Current: {}".format( + self._ready_manager_queue[manager]["last"], time.time() + ) + ) + logger.warning( + f"[MAIN] Too many heartbeats missed for manager {manager}" + ) e = ManagerLost(manager) - for task_type in self._ready_manager_queue[manager]['tasks']: - for tid in self._ready_manager_queue[manager]['tasks'][task_type]: + for task_type in self._ready_manager_queue[manager]["tasks"]: + for tid in self._ready_manager_queue[manager]["tasks"][task_type]: try: raise ManagerLost(manager) except Exception: - result_package = {'task_id': tid, 'exception': self.serializer.serialize(RemoteExceptionWrapper(*sys.exc_info()))} + result_package = { + "task_id": tid, + "exception": self.serializer.serialize( + RemoteExceptionWrapper(*sys.exc_info()) + ), + } pkl_package = pickle.dumps(result_package) bad_manager_msgs.append(pkl_package) - logger.warning("[MAIN] Sent failure reports, unregistering manager {}".format(manager)) - self._ready_manager_queue.pop(manager, 'None') + logger.warning( + "[MAIN] Sent failure reports, unregistering manager {}".format( + manager + ) + ) + self._ready_manager_queue.pop(manager, "None") if manager in interesting_managers: interesting_managers.remove(manager) if bad_manager_msgs: @@ -848,12 +1039,11 @@ def start(self, poll_period=None): self._status_request.clear() delta = time.time() - start - logger.info("Processed {} tasks in {} seconds".format(count, delta)) + logger.info(f"Processed {count} tasks in {delta} seconds") logger.warning("Exiting") def get_status_report(self): - """ Get utilization numbers - """ + """Get utilization numbers""" total_cores = 0 total_mem = 0 core_hrs = 0 @@ -865,36 +1055,43 @@ def get_status_report(self): live_workers = self.get_total_live_workers() for manager in self._ready_manager_queue: - total_cores += self._ready_manager_queue[manager]['cores'] - total_mem += self._ready_manager_queue[manager]['mem'] - active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time']) + total_cores += 
self._ready_manager_queue[manager]["cores"] + total_mem += self._ready_manager_queue[manager]["mem"] + active_dur = abs( + time.time() - self._ready_manager_queue[manager]["reg_time"] + ) core_hrs += (active_dur * total_cores) / 3600 - if self._ready_manager_queue[manager]['active']: + if self._ready_manager_queue[manager]["active"]: active_managers += 1 - free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers'] - - result_package = {'task_id': -2, - 'info': {'total_cores': total_cores, - 'total_mem': total_mem, - 'new_core_hrs': core_hrs - self.last_core_hr_counter, - 'total_core_hrs': round(core_hrs, 2), - 'managers': num_managers, - 'active_managers': active_managers, - 'total_workers': live_workers, - 'idle_workers': free_capacity, - 'pending_tasks': pending_tasks, - 'outstanding_tasks': outstanding_tasks, - 'worker_mode': self.worker_mode, - 'scheduler_mode': self.scheduler_mode, - 'scaling_enabled': self.scaling_enabled, - 'mem_per_worker': self.mem_per_worker, - 'cores_per_worker': self.cores_per_worker, - 'prefetch_capacity': self.prefetch_capacity, - 'max_blocks': self.provider.max_blocks, - 'min_blocks': self.provider.min_blocks, - 'max_workers_per_node': self.max_workers_per_node, - 'nodes_per_block': self.provider.nodes_per_block - }} + free_capacity += self._ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ] + + result_package = { + "task_id": -2, + "info": { + "total_cores": total_cores, + "total_mem": total_mem, + "new_core_hrs": core_hrs - self.last_core_hr_counter, + "total_core_hrs": round(core_hrs, 2), + "managers": num_managers, + "active_managers": active_managers, + "total_workers": live_workers, + "idle_workers": free_capacity, + "pending_tasks": pending_tasks, + "outstanding_tasks": outstanding_tasks, + "worker_mode": self.worker_mode, + "scheduler_mode": self.scheduler_mode, + "scaling_enabled": self.scaling_enabled, + "mem_per_worker": self.mem_per_worker, + "cores_per_worker": self.cores_per_worker, + "prefetch_capacity": self.prefetch_capacity, + "max_blocks": self.provider.max_blocks, + "min_blocks": self.provider.min_blocks, + "max_workers_per_node": self.max_workers_per_node, + "nodes_per_block": self.provider.nodes_per_block, + }, + } self.last_core_hr_counter = core_hrs return result_package @@ -910,18 +1107,26 @@ def scale_out(self, blocks=1, task_type=None): if self.provider: self._block_counter += 1 external_block_id = str(self._block_counter) - if not task_type and self.scheduler_mode == 'hard': - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW') + if not task_type and self.scheduler_mode == "hard": + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type="RAW" + ) else: - launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type) + launch_cmd = self.launch_cmd.format( + block_id=external_block_id, worker_type=task_type + ) if not task_type: internal_block = self.provider.submit(launch_cmd, 1) else: internal_block = self.provider.submit(launch_cmd, 1, task_type) - logger.debug("Launched block {}->{}".format(external_block_id, internal_block)) + logger.debug(f"Launched block {external_block_id}->{internal_block}") if not internal_block: - raise(ScalingFailed(self.config.provider.label, - "Attempts to provision nodes via provider has failed")) + raise ( + ScalingFailed( + self.config.provider.label, + "Attempts to provision nodes via provider has failed", + ) + ) self.blocks[external_block_id] = internal_block 
self.block_id_map[internal_block] = external_block_id else: @@ -943,14 +1148,22 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): if block_ids is None: block_ids = [] if task_type: - logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type)) + logger.info( + "Scaling in blocks of specific task type %s. Let the provider decide " + "which to kill", + task_type, + ) if self.scaling_enabled and self.provider: to_kill, r = self.provider.cancel(blocks, task_type) - logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r)) + logger.info(f"Get the killed blocks: {to_kill}, and status: {r}") for job in to_kill: - logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job)) + logger.info( + "[scale_in] Getting the block_id map {} for job {}".format( + self.block_id_map, job + ) + ) block_id = self.block_id_map[job] - logger.info("[scale_in] Holding block {}".format(block_id)) + logger.info(f"[scale_in] Holding block {block_id}") self._hold_block(block_id) self.blocks.pop(block_id) return r @@ -974,13 +1187,16 @@ def scale_in(self, blocks=None, block_ids=None, task_type=None): return r def provider_status(self): - """ Get status of all blocks from the provider - """ + """Get status of all blocks from the provider""" status = [] if self.provider: - logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values()))) + logger.debug( + "[MAIN] Getting the status of {} blocks.".format( + list(self.blocks.values()) + ) + ) status = self.provider.status(list(self.blocks.values())) - logger.debug("[MAIN] The status is {}".format(status)) + logger.debug(f"[MAIN] The status is {status}") return status @@ -988,58 +1204,79 @@ def provider_status(self): def starter(comm_q, *args, **kwargs): """Start the interchange process - The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__ + The executor is expected to call this function. 
The args, kwargs match that of the + Interchange.__init__ """ # logger = multiprocessing.get_logger() ic = Interchange(*args, **kwargs) - comm_q.put((ic.worker_task_port, - ic.worker_result_port)) + comm_q.put((ic.worker_task_port, ic.worker_result_port)) ic.start() def cli_run(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--client_address", required=True, - help="Client address") - parser.add_argument("--client_ports", required=True, - help="client ports as a triple of outgoing,incoming,command") - parser.add_argument("--worker_port_range", - help="Worker port range as a tuple") - parser.add_argument("-l", "--logdir", default="./parsl_worker_logs", - help="Parsl worker log directory") - parser.add_argument("-p", "--poll_period", - help="REQUIRED: poll period used for main thread") - parser.add_argument("--worker_ports", default=None, - help="OPTIONAL, pair of workers ports to listen on, eg --worker_ports=50001,50005") - parser.add_argument("--suppress_failure", action='store_true', - help="Enables suppression of failures") - parser.add_argument("--endpoint_id", default=None, - help="Endpoint ID, used to identify the endpoint to the remote broker") - parser.add_argument("--hb_threshold", - help="Heartbeat threshold in seconds") - parser.add_argument("--config", default=None, - help="Configuration object that describes provisioning") - parser.add_argument("-d", "--debug", action='store_true', - help="Enables debug logging") + parser.add_argument("-c", "--client_address", required=True, help="Client address") + parser.add_argument( + "--client_ports", + required=True, + help="client ports as a triple of outgoing,incoming,command", + ) + parser.add_argument("--worker_port_range", help="Worker port range as a tuple") + parser.add_argument( + "-l", + "--logdir", + default="./parsl_worker_logs", + help="Parsl worker log directory", + ) + parser.add_argument( + "-p", "--poll_period", help="REQUIRED: poll period used for main thread" + ) + parser.add_argument( + "--worker_ports", + default=None, + help="OPTIONAL, pair of workers ports to listen on, " + "eg --worker_ports=50001,50005", + ) + parser.add_argument( + "--suppress_failure", + action="store_true", + help="Enables suppression of failures", + ) + parser.add_argument( + "--endpoint_id", + default=None, + help="Endpoint ID, used to identify the endpoint to the remote broker", + ) + parser.add_argument("--hb_threshold", help="Heartbeat threshold in seconds") + parser.add_argument( + "--config", + default=None, + help="Configuration object that describes provisioning", + ) + parser.add_argument( + "-d", "--debug", action="store_true", help="Enables debug logging" + ) print("Starting HTEX Intechange") args = parser.parse_args() optionals = {} - optionals['suppress_failure'] = args.suppress_failure - optionals['logdir'] = os.path.abspath(args.logdir) - optionals['client_address'] = args.client_address - optionals['client_ports'] = [int(i) for i in args.client_ports.split(',')] - optionals['endpoint_id'] = args.endpoint_id - optionals['config'] = args.config + optionals["suppress_failure"] = args.suppress_failure + optionals["logdir"] = os.path.abspath(args.logdir) + optionals["client_address"] = args.client_address + optionals["client_ports"] = [int(i) for i in args.client_ports.split(",")] + optionals["endpoint_id"] = args.endpoint_id + optionals["config"] = args.config if args.debug: - optionals['logging_level'] = logging.DEBUG + optionals["logging_level"] = logging.DEBUG if args.worker_ports: - optionals['worker_ports'] = 
[int(i) for i in args.worker_ports.split(',')] + optionals["worker_ports"] = [int(i) for i in args.worker_ports.split(",")] if args.worker_port_range: - optionals['worker_port_range'] = [int(i) for i in args.worker_port_range.split(',')] + optionals["worker_port_range"] = [ + int(i) for i in args.worker_port_range.split(",") + ] with daemon.DaemonContext(): ic = Interchange(**optionals) diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py index b6017f802..5d4650dde 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/interchange_task_dispatch.py @@ -1,50 +1,57 @@ -import math -import random -import queue -import logging import collections +import logging +import queue +import random logger = logging.getLogger("interchange.task_dispatch") logger.info("Interchange task dispatch started") -def naive_interchange_task_dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='hard', - cold_routing=False): +def naive_interchange_task_dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="hard", + cold_routing=False, +): """ This is an initial task dispatching algorithm for interchange. - It returns a dictionary, whose key is manager, and the value is the list of tasks to be sent to manager, - and the total number of dispatched tasks. + It returns a dictionary, whose key is manager, and the value is the list of tasks + to be sent to manager, and the total number of dispatched tasks. """ - if scheduler_mode == 'hard': - return dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='hard') + if scheduler_mode == "hard": + return dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="hard", + ) - elif scheduler_mode == 'soft': + elif scheduler_mode == "soft": task_dispatch, dispatched_tasks = {}, 0 - loops = ['warm'] if not cold_routing else ['warm', 'cold'] + loops = ["warm"] if not cold_routing else ["warm", "cold"] for loop in loops: - task_dispatch, dispatched_tasks = dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='soft', - loop=loop, - task_dispatch=task_dispatch, - dispatched_tasks=dispatched_tasks) + task_dispatch, dispatched_tasks = dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="soft", + loop=loop, + task_dispatch=task_dispatch, + dispatched_tasks=dispatched_tasks, + ) return task_dispatch, dispatched_tasks -def dispatch(interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode='hard', - loop='warm', - task_dispatch=None, - dispatched_tasks=0): +def dispatch( + interesting_managers, + pending_task_queue, + ready_manager_queue, + scheduler_mode="hard", + loop="warm", + task_dispatch=None, + dispatched_tasks=0, +): """ This is the core task dispatching algorithm for interchange. The algorithm depends on the scheduler mode and which loop. 
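Before the next hunk, the capacity rule that dispatch() applies per manager is worth spelling out: in one pass it hands out at most min(advertised free workers, max_worker_count minus tasks already in flight) tasks. A small worked sketch with invented numbers, not part of the patch:

# Illustrative sketch: the per-manager capacity bound used by dispatch().
def real_capacity(free_total_workers, max_worker_count, tasks_inflight):
    # A manager can advertise free slots while still holding dispatched tasks,
    # so the in-flight count caps what the interchange will actually send.
    return min(free_total_workers, max_worker_count - tasks_inflight)

# e.g. 4 advertised free workers, 8 workers max, 6 tasks already in flight -> 2
assert real_capacity(4, 8, 6) == 2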
@@ -55,99 +62,129 @@ def dispatch(interesting_managers, shuffled_managers = list(interesting_managers) random.shuffle(shuffled_managers) for manager in shuffled_managers: - tasks_inflight = ready_manager_queue[manager]['total_tasks'] - real_capacity = min(ready_manager_queue[manager]['free_capacity']['total_workers'], - ready_manager_queue[manager]['max_worker_count'] - tasks_inflight) - if (real_capacity and ready_manager_queue[manager]['active']): - if scheduler_mode == 'hard': - tasks, tids = get_tasks_hard(pending_task_queue, - ready_manager_queue[manager], - real_capacity) + tasks_inflight = ready_manager_queue[manager]["total_tasks"] + real_capacity = min( + ready_manager_queue[manager]["free_capacity"]["total_workers"], + ready_manager_queue[manager]["max_worker_count"] - tasks_inflight, + ) + if real_capacity and ready_manager_queue[manager]["active"]: + if scheduler_mode == "hard": + tasks, tids = get_tasks_hard( + pending_task_queue, ready_manager_queue[manager], real_capacity + ) else: - tasks, tids = get_tasks_soft(pending_task_queue, - ready_manager_queue[manager], - real_capacity, - loop=loop) - logger.debug("[MAIN] Get tasks {} from queue".format(tasks)) + tasks, tids = get_tasks_soft( + pending_task_queue, + ready_manager_queue[manager], + real_capacity, + loop=loop, + ) + logger.debug(f"[MAIN] Get tasks {tasks} from queue") if tasks: for task_type in tids: # This line is a set update, not dict update - ready_manager_queue[manager]['tasks'][task_type].update(tids[task_type]) - logger.debug("[MAIN] The tasks on manager {} is {}".format(manager, ready_manager_queue[manager]['tasks'])) - ready_manager_queue[manager]['total_tasks'] += len(tasks) + ready_manager_queue[manager]["tasks"][task_type].update( + tids[task_type] + ) + logger.debug( + "[MAIN] The tasks on manager {} is {}".format( + manager, ready_manager_queue[manager]["tasks"] + ) + ) + ready_manager_queue[manager]["total_tasks"] += len(tasks) if manager not in task_dispatch: task_dispatch[manager] = [] task_dispatch[manager] += tasks dispatched_tasks += len(tasks) - logger.debug("[MAIN] Assigned tasks {} to manager {}".format(tids, manager)) - if ready_manager_queue[manager]['free_capacity']['total_workers'] > 0: - logger.debug("[MAIN] Manager {} still has free_capacity {}".format(manager, ready_manager_queue[manager]['free_capacity']['total_workers'])) + logger.debug(f"[MAIN] Assigned tasks {tids} to manager {manager}") + if ready_manager_queue[manager]["free_capacity"]["total_workers"] > 0: + logger.debug( + "[MAIN] Manager {} still has free_capacity {}".format( + manager, + ready_manager_queue[manager]["free_capacity"][ + "total_workers" + ], + ) + ) else: - logger.debug("[MAIN] Manager {} is now saturated".format(manager)) + logger.debug(f"[MAIN] Manager {manager} is now saturated") interesting_managers.remove(manager) else: interesting_managers.remove(manager) - logger.debug("The task dispatch of {} loop is {}, in total {} tasks".format(loop, task_dispatch, dispatched_tasks)) + logger.debug( + "The task dispatch of {} loop is {}, in total {} tasks".format( + loop, task_dispatch, dispatched_tasks + ) + ) return task_dispatch, dispatched_tasks def get_tasks_hard(pending_task_queue, manager_ads, real_capacity): tasks = [] tids = collections.defaultdict(set) - task_type = manager_ads['worker_type'] + task_type = manager_ads["worker_type"] if not task_type: - logger.warning("Using hard scheduler mode but with manager worker type unset. Use soft scheduler mode. 
Set this in the config.") + logger.warning( + "Using hard scheduler mode but with manager worker type unset. " + "Use soft scheduler mode. Set this in the config." + ) return tasks, tids if task_type not in pending_task_queue: - logger.debug("No task of type {}. Exiting task fetching.".format(task_type)) + logger.debug(f"No task of type {task_type}. Exiting task fetching.") return tasks, tids # dispatch tasks of available types on manager - if task_type in manager_ads['free_capacity']['free']: - while manager_ads['free_capacity']['free'][task_type] > 0 and real_capacity > 0: + if task_type in manager_ads["free_capacity"]["free"]: + while manager_ads["free_capacity"]["free"][task_type] > 0 and real_capacity > 0: try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + logger.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free'][task_type] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"][task_type] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 # dispatch tasks to unused slots based on the manager type logger.debug("Second round of task fetching in hard mode") - while manager_ads['free_capacity']['free']["unused"] > 0 and real_capacity > 0: + while manager_ads["free_capacity"]["free"]["unused"] > 0 and real_capacity > 0: try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + logger.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free']['unused'] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"]["unused"] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 return tasks, tids -def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): +def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop="warm"): tasks = [] tids = collections.defaultdict(set) # Warm routing to dispatch tasks - if loop == 'warm': - for task_type in manager_ads['free_capacity']['free']: + if loop == "warm": + for task_type in manager_ads["free_capacity"]["free"]: # Dispatch tasks that are of the available container types on the manager - if task_type != 'unused': - type_inflight = len(manager_ads['tasks'].get(task_type, set())) - type_capacity = min(manager_ads['free_capacity']['free'][task_type], - manager_ads['free_capacity']['total'][task_type] - type_inflight) - while manager_ads['free_capacity']['free'][task_type] > 0 and real_capacity > 0 and type_capacity > 0: + if task_type != "unused": + type_inflight = len(manager_ads["tasks"].get(task_type, set())) + type_capacity = min( + manager_ads["free_capacity"]["free"][task_type], + manager_ads["free_capacity"]["total"][task_type] - type_inflight, + ) + while ( + manager_ads["free_capacity"]["free"][task_type] > 0 + and real_capacity > 0 + and type_capacity > 0 + ): try: if task_type not in pending_task_queue: break @@ -155,11 +192,11 @@ def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + logger.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free'][task_type] -= 1 - 
manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"][task_type] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 type_capacity -= 1 # Dispatch tasks to unused container slots on the manager @@ -167,18 +204,21 @@ def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): task_types = list(pending_task_queue.keys()) random.shuffle(task_types) for task_type in task_types: - while (manager_ads['free_capacity']['free']['unused'] > 0 and - manager_ads['free_capacity']['total_workers'] > 0 and real_capacity > 0): + while ( + manager_ads["free_capacity"]["free"]["unused"] > 0 + and manager_ads["free_capacity"]["total_workers"] > 0 + and real_capacity > 0 + ): try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + logger.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['free']['unused'] -= 1 - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["free"]["unused"] -= 1 + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 return tasks, tids @@ -192,15 +232,15 @@ def get_tasks_soft(pending_task_queue, manager_ads, real_capacity, loop='warm'): task_types = list(pending_task_queue.keys()) random.shuffle(task_types) for task_type in task_types: - while manager_ads['free_capacity']['total_workers'] > 0 and real_capacity > 0: + while manager_ads["free_capacity"]["total_workers"] > 0 and real_capacity > 0: try: x = pending_task_queue[task_type].get(block=False) except queue.Empty: break else: - logger.debug("Get task {}".format(x)) + logger.debug(f"Get task {x}") tasks.append(x) - tids[task_type].add(x['task_id']) - manager_ads['free_capacity']['total_workers'] -= 1 + tids[task_type].add(x["task_id"]) + manager_ads["free_capacity"]["total_workers"] -= 1 real_capacity -= 1 return tasks, tids diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py index 1ce44dd48..357f8e32c 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/mac_safe_queue.py @@ -1,5 +1,8 @@ import platform -if platform.system() == 'Darwin': + +if platform.system() == "Darwin": from parsl.executors.high_throughput.mac_safe_queue import MacSafeQueue as mpQueue else: from multiprocessing import Queue as mpQueue + +__all__ = ("mpQueue",) diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py index f3131f757..cbcfb9ddf 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/messages.py @@ -3,8 +3,8 @@ from abc import ABC, abstractmethod from enum import Enum, auto from struct import Struct -from typing import Tuple -MESSAGE_TYPE_FORMATTER = Struct('b') + +MESSAGE_TYPE_FORMATTER = Struct("b") class MessageType(Enum): @@ -20,8 +20,8 @@ def pack(self): @classmethod def unpack(cls, buffer): - mtype, = MESSAGE_TYPE_FORMATTER.unpack_from(buffer, offset=0) - return MessageType(mtype), buffer[MESSAGE_TYPE_FORMATTER.size:] + (mtype,) = MESSAGE_TYPE_FORMATTER.unpack_from(buffer, offset=0) + return MessageType(mtype), 
buffer[MESSAGE_TYPE_FORMATTER.size :] class TaskStatusCode(int, Enum): @@ -32,9 +32,7 @@ class TaskStatusCode(int, Enum): FAILED = auto() -COMMAND_TYPES = { - MessageType.HEARTBEAT_REQ -} +COMMAND_TYPES = {MessageType.HEARTBEAT_REQ} class Message(ABC): @@ -85,9 +83,12 @@ class Task(Message): """ Task message from the forwarder->interchange """ + type = MessageType.TASK - def __init__(self, task_id: str, container_id: str, task_buffer: str, raw_buffer=None): + def __init__( + self, task_id: str, container_id: str, task_buffer: str, raw_buffer=None + ): super().__init__() self.task_id = task_id self.container_id = container_id @@ -97,15 +98,17 @@ def __init__(self, task_id: str, container_id: str, task_buffer: str, raw_buffer def pack(self) -> bytes: if self.raw_buffer is None: - add_ons = f'TID={self.task_id};CID={self.container_id};{self.task_buffer}' - self.raw_buffer = add_ons.encode('utf-8') + add_ons = f"TID={self.task_id};CID={self.container_id};{self.task_buffer}" + self.raw_buffer = add_ons.encode("utf-8") return self.type.pack() + self.raw_buffer @classmethod def unpack(cls, raw_buffer: bytes): - b_tid, b_cid, task_buf = raw_buffer.decode('utf-8').split(';', 2) - return cls(b_tid[4:], b_cid[4:], task_buf.encode('utf-8'), raw_buffer=raw_buffer) + b_tid, b_cid, task_buf = raw_buffer.decode("utf-8").split(";", 2) + return cls( + b_tid[4:], b_cid[4:], task_buf.encode("utf-8"), raw_buffer=raw_buffer + ) def set_local_container(self, container_id): self.local_container = container_id @@ -113,9 +116,12 @@ def set_local_container(self, container_id): class HeartbeatReq(Message): """ - Synchronous request for a Heartbeat. This is sent from the Forwarder to the endpoint on start to get - an initial connection and ensure liveness. + Synchronous request for a Heartbeat. + + This is sent from the Forwarder to the endpoint on start to get an initial + connection and ensure liveness. """ + type = MessageType.HEARTBEAT_REQ @property @@ -136,8 +142,10 @@ def pack(self): class Heartbeat(Message): """ - Generic Heartbeat message, sent in both directions between Forwarder and Interchange. + Generic Heartbeat message, sent in both directions between Forwarder and + Interchange. """ + type = MessageType.HEARTBEAT def __init__(self, endpoint_id): @@ -154,9 +162,11 @@ def pack(self): class EPStatusReport(Message): """ - Status report for an endpoint, sent from Interchange to Forwarder. Includes EP-wide info such as utilization, - as well as per-task status information. + Status report for an endpoint, sent from Interchange to Forwarder. + + Includes EP-wide info such as utilization, as well as per-task status information. """ + type = MessageType.EP_STATUS_REPORT def __init__(self, endpoint_id, ep_status_report, task_statuses): @@ -182,9 +192,10 @@ def pack(self): class ManagerStatusReport(Message): """ - Status report sent from the Manager to the Interchange, which mostly just amounts to saying which tasks are now - RUNNING. + Status report sent from the Manager to the Interchange, which mostly just amounts + to saying which tasks are now RUNNING. 
""" + type = MessageType.MANAGER_STATUS_REPORT def __init__(self, task_statuses, container_switch_count): @@ -194,7 +205,7 @@ def __init__(self, task_statuses, container_switch_count): @classmethod def unpack(cls, msg): - container_switch_count = int.from_bytes(msg[:10], 'little') + container_switch_count = int.from_bytes(msg[:10], "little") msg = msg[10:] jsonified = msg.decode("ascii") task_statuses = json.loads(jsonified) @@ -203,7 +214,11 @@ def unpack(cls, msg): def pack(self): # TODO: do better than JSON? jsonified = json.dumps(self.task_statuses) - return self.type.pack() + self.container_switch_count.to_bytes(10, 'little') + jsonified.encode("ascii") + return ( + self.type.pack() + + self.container_switch_count.to_bytes(10, "little") + + jsonified.encode("ascii") + ) class ResultsAck(Message): @@ -211,6 +226,7 @@ class ResultsAck(Message): Results acknowledgement to acknowledge a task result was received by the forwarder. Sent from forwarder->interchange """ + type = MessageType.RESULTS_ACK def __init__(self, task_id): diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py index 5f4e95c2d..fc1d75c8b 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/worker_map.py @@ -1,24 +1,27 @@ -from queue import Queue import logging +import os import random import subprocess import time -import os +from queue import Queue logger = logging.getLogger("funcx_manager.worker_map") -class WorkerMap(object): - """ WorkerMap keeps track of workers - """ +class WorkerMap: + """WorkerMap keeps track of workers""" def __init__(self, max_worker_count): self.max_worker_count = max_worker_count - self.total_worker_type_counts = {'unused': self.max_worker_count} - self.ready_worker_type_counts = {'unused': self.max_worker_count} + self.total_worker_type_counts = {"unused": self.max_worker_count} + self.ready_worker_type_counts = {"unused": self.max_worker_count} self.pending_worker_type_counts = {} - self.worker_queues = {} # a dict to keep track of all the worker_queues with the key of work_type - self.worker_types = {} # a dict to keep track of all the worker_types with the key of worker_id + self.worker_queues = ( + {} + ) # a dict to keep track of all the worker_queues with the key of work_type + self.worker_types = ( + {} + ) # a dict to keep track of all the worker_types with the key of worker_id self.worker_id_counter = 0 # used to create worker_ids # Only spin up containers if active_workers + pending_workers < max_workers. 
@@ -32,17 +35,22 @@ def __init__(self, max_worker_count): self.worker_idle_since = {} def register_worker(self, worker_id, worker_type): - """ Add a new worker - """ - logger.debug("In register worker worker_id: {} type:{}".format(worker_id, worker_type)) + """Add a new worker""" + logger.debug(f"In register worker worker_id: {worker_id} type:{worker_type}") self.worker_types[worker_id] = worker_type if worker_type not in self.worker_queues: self.worker_queues[worker_type] = Queue() - self.total_worker_type_counts[worker_type] = self.total_worker_type_counts.get(worker_type, 0) + 1 - self.ready_worker_type_counts[worker_type] = self.ready_worker_type_counts.get(worker_type, 0) + 1 - self.pending_worker_type_counts[worker_type] = self.pending_worker_type_counts.get(worker_type, 0) - 1 + self.total_worker_type_counts[worker_type] = ( + self.total_worker_type_counts.get(worker_type, 0) + 1 + ) + self.ready_worker_type_counts[worker_type] = ( + self.ready_worker_type_counts.get(worker_type, 0) + 1 + ) + self.pending_worker_type_counts[worker_type] = ( + self.pending_worker_type_counts.get(worker_type, 0) - 1 + ) self.pending_workers -= 1 self.active_workers += 1 self.worker_queues[worker_type].put(worker_id) @@ -52,9 +60,9 @@ def register_worker(self, worker_id, worker_type): self.to_die_count[worker_type] = 0 def remove_worker(self, worker_id): - """ Remove the worker from the WorkerMap + """Remove the worker from the WorkerMap - Should already be KILLed by this point. + Should already be KILLed by this point. """ worker_type = self.worker_types[worker_id] @@ -62,19 +70,21 @@ def remove_worker(self, worker_id): self.active_workers -= 1 self.total_worker_type_counts[worker_type] -= 1 self.to_die_count[worker_type] -= 1 - self.total_worker_type_counts['unused'] += 1 - self.ready_worker_type_counts['unused'] += 1 - - def spin_up_workers(self, - next_worker_q, - mode='no_container', - container_cmd_options='', - address=None, - debug=None, - uid=None, - logdir=None, - worker_port=None): - """ Helper function to call 'remove' for appropriate workers in 'new_worker_map'. + self.total_worker_type_counts["unused"] += 1 + self.ready_worker_type_counts["unused"] += 1 + + def spin_up_workers( + self, + next_worker_q, + mode="no_container", + container_cmd_options="", + address=None, + debug=None, + uid=None, + logdir=None, + worker_port=None, + ): + """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. 
Parameters ---------- @@ -99,28 +109,43 @@ def spin_up_workers(self, """ spin_ups = {} - logger.debug("[SPIN UP] Next Worker Qsize: {}".format(len(next_worker_q))) - logger.debug("[SPIN UP] Active Workers: {}".format(self.active_workers)) - logger.debug("[SPIN UP] Pending Workers: {}".format(self.pending_workers)) - logger.debug("[SPIN UP] Max Worker Count: {}".format(self.max_worker_count)) + logger.debug(f"[SPIN UP] Next Worker Qsize: {len(next_worker_q)}") + logger.debug(f"[SPIN UP] Active Workers: {self.active_workers}") + logger.debug(f"[SPIN UP] Pending Workers: {self.pending_workers}") + logger.debug(f"[SPIN UP] Max Worker Count: {self.max_worker_count}") - if len(next_worker_q) > 0 and self.active_workers + self.pending_workers < self.max_worker_count: + if ( + len(next_worker_q) > 0 + and self.active_workers + self.pending_workers < self.max_worker_count + ): logger.debug("[SPIN UP] Spinning up new workers!") - logger.debug(f"[SPIN up] Empty slots: {self.max_worker_count - self.active_workers - self.pending_workers}") + logger.debug( + "[SPIN up] Empty slots: %s", + self.max_worker_count - self.active_workers - self.pending_workers, + ) logger.debug(f"[SPIN up] New workers: {len(next_worker_q)}") - logger.debug(f"[SPIN up] Unused slots: {self.total_worker_type_counts['unused']}") - num_slots = min(self.max_worker_count - self.active_workers - self.pending_workers, len(next_worker_q), self.total_worker_type_counts['unused']) + logger.debug( + f"[SPIN up] Unused slots: {self.total_worker_type_counts['unused']}" + ) + num_slots = min( + self.max_worker_count - self.active_workers - self.pending_workers, + len(next_worker_q), + self.total_worker_type_counts["unused"], + ) for _ in range(num_slots): try: - proc = self.add_worker(worker_id=str(self.worker_id_counter), - worker_type=next_worker_q.pop(0), - container_cmd_options=container_cmd_options, - mode=mode, - address=address, debug=debug, - uid=uid, - logdir=logdir, - worker_port=worker_port) + proc = self.add_worker( + worker_id=str(self.worker_id_counter), + worker_type=next_worker_q.pop(0), + container_cmd_options=container_cmd_options, + mode=mode, + address=address, + debug=debug, + uid=uid, + logdir=logdir, + worker_port=worker_port, + ) except Exception: logger.exception("Error spinning up worker! Skipping...") continue @@ -128,8 +153,14 @@ def spin_up_workers(self, spin_ups.update(proc) return spin_ups - def spin_down_workers(self, new_worker_map, worker_max_idletime=60, need_more=False, scheduler_mode='hard'): - """ Helper function to call 'remove' for appropriate workers in 'new_worker_map'. + def spin_down_workers( + self, + new_worker_map, + worker_max_idletime=60, + need_more=False, + scheduler_mode="hard", + ): + """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. Parameters ---------- @@ -142,12 +173,28 @@ def spin_down_workers(self, new_worker_map, worker_max_idletime=60, need_more=Fa List of removed worker types. 
""" if need_more: - return self._spin_down(new_worker_map, worker_max_idletime=worker_max_idletime, scheduler_mode=scheduler_mode, check_idle=False) + return self._spin_down( + new_worker_map, + worker_max_idletime=worker_max_idletime, + scheduler_mode=scheduler_mode, + check_idle=False, + ) else: - return self._spin_down(new_worker_map, worker_max_idletime=worker_max_idletime, scheduler_mode=scheduler_mode, check_idle=True) - - def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='hard', check_idle=True): - """ Helper function to call 'remove' for appropriate workers in 'new_worker_map'. + return self._spin_down( + new_worker_map, + worker_max_idletime=worker_max_idletime, + scheduler_mode=scheduler_mode, + check_idle=True, + ) + + def _spin_down( + self, + new_worker_map, + worker_max_idletime=60, + scheduler_mode="hard", + check_idle=True, + ): + """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. Parameters ---------- @@ -155,10 +202,12 @@ def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='har {worker_type: total_number_of_containers,...}. check_idle : boolean A boolean to indicate whether to check the idle time of containers or not - If checked, that means the workloads are not so busy, - and we can leave the container workers alive until the worker_max_idletime is reached. - Otherwise, that means the workloads are busy and we need to turn of some containers to acommodate - the workers, regardless of if it reaches the worker_max_idletime. + + If checked, that means the workloads are not so busy, and we can leave the + container workers alive until the worker_max_idletime is reached. Otherwise, + that means the workloads are busy and we need to turn of some containers to + acommodate the workers, regardless of if it reaches the worker_max_idletime. + Returns --------- List of removed worker types. 
@@ -166,21 +215,41 @@ def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='har spin_downs = [] container_switch_count = 0 for worker_type in self.total_worker_type_counts: - if worker_type == 'unused': + if worker_type == "unused": continue - if check_idle and time.time() - self.worker_idle_since[worker_type] < worker_max_idletime: + if ( + check_idle + and time.time() - self.worker_idle_since[worker_type] + < worker_max_idletime + ): logger.debug(f"[SPIN DOWN] Current time: {time.time()}") - logger.debug(f"[SPIN DOWN] Idle since: {self.worker_idle_since[worker_type]}") - logger.debug(f"[SPIN DOWN] Worker type {worker_type} has not exceeded maximum idle time {worker_max_idletime}, continuing") + logger.debug( + f"[SPIN DOWN] Idle since: {self.worker_idle_since[worker_type]}" + ) + logger.debug( + "[SPIN DOWN] Worker type %s has not exceeded maximum idle " + "time %s, continuing", + worker_type, + worker_max_idletime, + ) continue - num_remove = max(0, self.total_worker_type_counts[worker_type] - self.to_die_count.get(worker_type, 0) - new_worker_map.get(worker_type, 0)) - if scheduler_mode == 'hard': + num_remove = max( + 0, + self.total_worker_type_counts[worker_type] + - self.to_die_count.get(worker_type, 0) + - new_worker_map.get(worker_type, 0), + ) + if scheduler_mode == "hard": # Leave at least one worker alive in hard mode max_remove = max(0, self.total_worker_type_counts[worker_type] - 1) num_remove = min(num_remove, max_remove) if num_remove > 0: - logger.debug("[SPIN DOWN] Removing {} workers of type {}".format(num_remove, worker_type)) + logger.debug( + "[SPIN DOWN] Removing {} workers of type {}".format( + num_remove, worker_type + ) + ) for _i in range(num_remove): spin_downs.append(worker_type) # A container switching is defined as a warm container must be @@ -194,17 +263,17 @@ def _spin_down(self, new_worker_map, worker_max_idletime=60, scheduler_mode='har def add_worker( self, worker_id=None, - mode='no_container', - worker_type='RAW', + mode="no_container", + worker_type="RAW", container_cmd_options="", walltime=1, address=None, debug=None, worker_port=None, logdir=None, - uid=None + uid=None, ): - """ Launch the appropriate worker + """Launch the appropriate worker Parameters ---------- @@ -219,64 +288,76 @@ def add_worker( if worker_id is None: str(random.random()) - debug = ' --debug' if debug else '' + debug = " --debug" if debug else "" - worker_id = ' --worker_id {}'.format(worker_id) + worker_id = f" --worker_id {worker_id}" self.worker_id_counter += 1 - cmd = (f'funcx-worker {debug}{worker_id} ' - f'-a {address} ' - f'-p {worker_port} ' - f'-t {worker_type} ' - f'--logdir={os.path.join(logdir, uid)} ') + cmd = ( + f"funcx-worker {debug}{worker_id} " + f"-a {address} " + f"-p {worker_port} " + f"-t {worker_type} " + f"--logdir={os.path.join(logdir, uid)} " + ) container_uri = None - if worker_type != 'RAW': + if worker_type != "RAW": container_uri = worker_type - logger.info("Command string :\n {}".format(cmd)) - logger.info("Mode: {}".format(mode)) - logger.info("Container uri: {}".format(container_uri)) - logger.info("Container cmd options: {}".format(container_cmd_options)) - logger.info("Worker type: {}".format(worker_type)) + logger.info(f"Command string :\n {cmd}") + logger.info(f"Mode: {mode}") + logger.info(f"Container uri: {container_uri}") + logger.info(f"Container cmd options: {container_cmd_options}") + logger.info(f"Worker type: {worker_type}") - if mode == 'no_container': + if mode == "no_container": modded_cmd = cmd - elif mode 
== 'singularity_reuse': + elif mode == "singularity_reuse": if container_uri is None: - logger.warning("No container is specified for singularity mode. " - "Spawning a worker in a raw process instead.") + logger.warning( + "No container is specified for singularity mode. " + "Spawning a worker in a raw process instead." + ) modded_cmd = cmd elif not os.path.exists(container_uri): - logger.warning(f"Container uri {container_uri} is not found. " - "Spawning a worker in a raw process instead.") + logger.warning( + f"Container uri {container_uri} is not found. " + "Spawning a worker in a raw process instead." + ) modded_cmd = cmd else: - modded_cmd = f'singularity exec {container_cmd_options} {container_uri} {cmd}' - logger.info("Command string with singularity:\n {}".format(modded_cmd)) + modded_cmd = ( + f"singularity exec {container_cmd_options} {container_uri} {cmd}" + ) + logger.info(f"Command string with singularity:\n {modded_cmd}") else: raise NameError("Invalid container launch mode.") try: - proc = subprocess.Popen(modded_cmd.split(), - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - shell=False) + proc = subprocess.Popen( + modded_cmd.split(), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + shell=False, + ) except Exception: logger.exception("Got an error in worker launch") raise - self.total_worker_type_counts['unused'] -= 1 - self.ready_worker_type_counts['unused'] -= 1 - self.pending_worker_type_counts[worker_type] = self.pending_worker_type_counts.get(worker_type, 0) + 1 + self.total_worker_type_counts["unused"] -= 1 + self.ready_worker_type_counts["unused"] -= 1 + self.pending_worker_type_counts[worker_type] = ( + self.pending_worker_type_counts.get(worker_type, 0) + 1 + ) self.pending_workers += 1 return {str(self.worker_id_counter - 1): proc} def get_next_worker_q(self, new_worker_map): - """ Helper function to generate a queue of next workers to spin up . + """Helper function to generate a queue of next workers to spin up . 
From a mapping generated by the scheduler Parameters @@ -291,20 +372,30 @@ def get_next_worker_q(self, new_worker_map): # next_worker_q = [] new_worker_list = [] - logger.debug(f"[GET_NEXT_WORKER] total_worker_type_counts: {self.total_worker_type_counts}") - logger.debug(f"[GET_NEXT_WORKER] pending_worker_type_counts: {self.pending_worker_type_counts}") + logger.debug( + "[GET_NEXT_WORKER] total_worker_type_counts: %s", + self.total_worker_type_counts, + ) + logger.debug( + "[GET_NEXT_WORKER] pending_worker_type_counts: %s", + self.pending_worker_type_counts, + ) for worker_type in new_worker_map: - cur_workers = self.total_worker_type_counts.get(worker_type, 0) + self.pending_worker_type_counts.get(worker_type, 0) + cur_workers = self.total_worker_type_counts.get( + worker_type, 0 + ) + self.pending_worker_type_counts.get(worker_type, 0) if new_worker_map[worker_type] > cur_workers: for _i in range(new_worker_map[worker_type] - cur_workers): # Add worker new_worker_list.append(worker_type) - # need_more is to reflect if a manager needs more workers than the current unused slots - # If yes, that means the manager needs to turn off some warm workers to serve the requests + # need_more is to reflect if a manager needs more workers than the current + # unused slots + # If yes, that means the manager needs to turn off some warm workers to serve + # the requests need_more = False - if len(new_worker_list) > self.total_worker_type_counts['unused']: + if len(new_worker_list) > self.total_worker_type_counts["unused"]: need_more = True # Randomly assign order of newly needed containers... add to spin-up queue. if len(new_worker_list) > 0: @@ -313,14 +404,14 @@ def get_next_worker_q(self, new_worker_map): return new_worker_list, need_more def update_worker_idle(self, worker_type): - """ Update the workers' last idle time by worker type - """ - logger.debug(f"[UPDATE_WORKER_IDLE] Worker idle since: {self.worker_idle_since}") + """Update the workers' last idle time by worker type""" + logger.debug( + f"[UPDATE_WORKER_IDLE] Worker idle since: {self.worker_idle_since}" + ) self.worker_idle_since[worker_type] = time.time() def put_worker(self, worker): - """ Adds worker to the list of waiting workers - """ + """Adds worker to the list of waiting workers""" worker_type = self.worker_types[worker] if worker_type not in self.worker_queues: @@ -330,7 +421,7 @@ def put_worker(self, worker): self.worker_queues[worker_type].put(worker) def get_worker(self, worker_type): - """ Get a task and reduce the # of worker for that type by 1. + """Get a task and reduce the # of worker for that type by 1. Raises queue.Empty if empty """ worker = self.worker_queues[worker_type].get_nowait() @@ -338,28 +429,36 @@ def get_worker(self, worker_type): return worker def get_worker_counts(self): - """ Returns just the dict of worker_type and counts - """ + """Returns just the dict of worker_type and counts""" return self.total_worker_type_counts def ready_worker_count(self): return sum(self.ready_worker_type_counts.values()) def advertisement(self): - """ Manager capacity advertisement to interchange - The advertisement includes two parts. One is the read_worker_type_counts, - which reflects the capacity of different types of containers on the manager. - The other is the total number of workers of each type. - This include all the pending workers and to_die workers when advertising. - We need this "total" advertisement because we use killer tasks mechanisms to kill a worker. 
- When a manager is advertising, there may be some killer tssks in queue, - we want to ensure that the manager does not over-advertise its actualy capacity, - and let interchange decide if it is sending too many tasks to the manager. """ - ads = {'total': {}, 'free': {}} + Manager capacity advertisement to interchange. + + The advertisement includes two parts: + + One is the read_worker_type_counts, which reflects the capacity of different + types of containers on the manager. + + The other is the total number of workers of each type. This includes all the + pending workers and to_die workers when advertising. We need this "total" + advertisement because we use killer task mechanisms to kill a worker. When a + manager is advertising, there may be some killer tasks in queue, and we want to + ensure that the manager does not over-advertise its actual capacity. Instead, + let the interchange decide if it is sending too many tasks to the manager. + """ + ads = {"total": {}, "free": {}} total = dict(self.total_worker_type_counts) for worker_type in self.pending_worker_type_counts: - total[worker_type] = total.get(worker_type, 0) + self.pending_worker_type_counts[worker_type] - self.to_die_count.get(worker_type, 0) - ads['total'].update(total) - ads['free'].update(self.ready_worker_type_counts) + total[worker_type] = ( + total.get(worker_type, 0) + + self.pending_worker_type_counts[worker_type] + - self.to_die_count.get(worker_type, 0) + ) + ads["total"].update(total) + ads["free"].update(self.ready_worker_type_counts) return ads diff --git a/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py b/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py index 4409ae694..6fb21aba2 100644 --- a/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py +++ b/funcx_endpoint/funcx_endpoint/executors/high_throughput/zmq_pipes.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 -import zmq -import time -import pickle import logging +import pickle +import time + +import zmq -from funcx import set_file_logger from funcx_endpoint.executors.high_throughput.messages import Message logger = logging.getLogger(__name__) -class CommandClient(object): - """ CommandClient - """ +class CommandClient: + """CommandClient""" def __init__(self, ip_address, port_range): """ @@ -29,12 +28,14 @@ def __init__(self, ip_address, port_range): self.context = zmq.Context() self.zmq_socket = self.context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) - self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), - min_port=port_range[0], - max_port=port_range[1]) + self.port = self.zmq_socket.bind_to_random_port( + f"tcp://{ip_address}", + min_port=port_range[0], + max_port=port_range[1], + ) def run(self, message): - """ This function needs to be fast at the same time aware of the possibility of + """This function needs to be fast at the same time aware of the possibility of ZMQ pipes overflowing. The timeout increases slowly if contention is detected on ZMQ pipes. 
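The reflowed advertisement() code above computes the same payload as before: "total" is the per-type worker count plus pending launches minus workers already marked to die, and "free" is simply the ready counts. A rough dict-based recreation, outside the WorkerMap class, for sanity-checking the arithmetic:

def build_advertisement(total_counts, pending_counts, to_die_counts, ready_counts):
    # "total" folds in pending launches and subtracts workers queued to be killed,
    # so the interchange is not scheduled against capacity that is going away
    total = dict(total_counts)
    for worker_type, pending in pending_counts.items():
        to_die = to_die_counts.get(worker_type, 0)
        total[worker_type] = total.get(worker_type, 0) + pending - to_die
    return {"total": total, "free": dict(ready_counts)}

ads = build_advertisement(
    total_counts={"unused": 2, "RAW": 4},
    pending_counts={"RAW": 1},
    to_die_counts={"RAW": 2},
    ready_counts={"unused": 2, "RAW": 1},
)
assert ads["total"]["RAW"] == 3  # 4 registered + 1 pending - 2 marked to die
assert ads["free"] == {"unused": 2, "RAW": 1}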
@@ -51,9 +52,8 @@ def close(self): self.context.term() -class TasksOutgoing(object): - """ Outgoing task queue from the executor to the Interchange - """ +class TasksOutgoing: + """Outgoing task queue from the executor to the Interchange""" def __init__(self, ip_address, port_range): """ @@ -69,14 +69,16 @@ def __init__(self, ip_address, port_range): self.context = zmq.Context() self.zmq_socket = self.context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) - self.port = self.zmq_socket.bind_to_random_port("tcp://{}".format(ip_address), - min_port=port_range[0], - max_port=port_range[1]) + self.port = self.zmq_socket.bind_to_random_port( + f"tcp://{ip_address}", + min_port=port_range[0], + max_port=port_range[1], + ) self.poller = zmq.Poller() self.poller.register(self.zmq_socket, zmq.POLLOUT) def put(self, message, max_timeout=1000): - """ This function needs to be fast at the same time aware of the possibility of + """This function needs to be fast at the same time aware of the possibility of ZMQ pipes overflowing. The timeout increases slowly if contention is detected on ZMQ pipes. @@ -90,7 +92,8 @@ def put(self, message, max_timeout=1000): message : py object Python object to send max_timeout : int - Max timeout in milliseconds that we will wait for before raising an exception + Max timeout in milliseconds that we will wait for before raising an + exception Raises ------ @@ -108,11 +111,15 @@ def put(self, message, max_timeout=1000): return else: timeout_ms += 1 - logger.debug("Not sending due to full zmq pipe, timeout: {} ms".format(timeout_ms)) + logger.debug( + "Not sending due to full zmq pipe, timeout: {} ms".format( + timeout_ms + ) + ) current_wait += timeout_ms # Send has failed. - logger.debug("Remote side has been unresponsive for {}".format(current_wait)) + logger.debug(f"Remote side has been unresponsive for {current_wait}") raise zmq.error.Again def close(self): @@ -120,9 +127,8 @@ def close(self): self.context.term() -class ResultsIncoming(object): - """ Incoming results queue from the Interchange to the executor - """ +class ResultsIncoming: + """Incoming results queue from the Interchange to the executor""" def __init__(self, ip_address, port_range): """ @@ -138,9 +144,11 @@ def __init__(self, ip_address, port_range): self.context = zmq.Context() self.results_receiver = self.context.socket(zmq.DEALER) self.results_receiver.set_hwm(0) - self.port = self.results_receiver.bind_to_random_port("tcp://{}".format(ip_address), - min_port=port_range[0], - max_port=port_range[1]) + self.port = self.results_receiver.bind_to_random_port( + f"tcp://{ip_address}", + min_port=port_range[0], + max_port=port_range[1], + ) def get(self, block=True, timeout=None): block_messages = self.results_receiver.recv() @@ -150,7 +158,10 @@ def get(self, block=True, timeout=None): try: res = Message.unpack(block_messages) except Exception: - logger.exception(f"Message in results queue is not pickle/Message formatted:{block_messages}") + logger.exception( + "Message in results queue is not pickle/Message formatted: %s", + block_messages, + ) return res def request_close(self): diff --git a/funcx_endpoint/funcx_endpoint/providers/__init__.py b/funcx_endpoint/funcx_endpoint/providers/__init__.py index cc5002e00..72019166b 100644 --- a/funcx_endpoint/funcx_endpoint/providers/__init__.py +++ b/funcx_endpoint/funcx_endpoint/providers/__init__.py @@ -1,3 +1,3 @@ from funcx_endpoint.providers.kubernetes.kube import KubernetesProvider -__all__ = ['KubernetesProvider'] +__all__ = ["KubernetesProvider"] diff 
--git a/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py b/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py index a282c8ea3..040ebc54b 100644 --- a/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py +++ b/funcx_endpoint/funcx_endpoint/providers/kubernetes/kube.py @@ -1,18 +1,15 @@ import logging import queue import time - -from funcx_endpoint.providers.kubernetes.template import template_string - -logger = logging.getLogger("interchange.kube_provider") - -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import typeguard from parsl.errors import OptionalModuleMissing from parsl.providers.provider_base import ExecutionProvider from parsl.utils import RepresentationMixin +from funcx_endpoint.providers.kubernetes.template import template_string + try: from kubernetes import client, config @@ -20,6 +17,8 @@ except (ImportError, NameError, FileNotFoundError): _kubernetes_enabled = False +logger = logging.getLogger("interchange.kube_provider") + class KubernetesProvider(ExecutionProvider, RepresentationMixin): """Kubernetes execution provider @@ -54,9 +53,10 @@ class KubernetesProvider(ExecutionProvider, RepresentationMixin): This is the memory "requests" option for resource specification on kubernetes. Check kubernetes docs for more details. Default is 250Mi. parallelism : float - Ratio of provisioned task slots to active tasks. A parallelism value of 1 represents aggressive - scaling where as many resources as possible are used; parallelism close to 0 represents - the opposite situation in which as few resources as possible (i.e., min_blocks) are used. + Ratio of provisioned task slots to active tasks. A parallelism value of 1 + represents aggressive scaling where as many resources as possible are used; + parallelism close to 0 represents the opposite situation in which as few + resources as possible (i.e., min_blocks) are used. worker_init : str Command to be run first for the workers, such as `python start.py`. secret : str @@ -156,23 +156,23 @@ def submit(self, cmd_string, tasks_per_node, task_type, job_name="funcx"): """ cur_timestamp = str(time.time() * 1000).split(".")[0] - job_name = "{0}-{1}".format(job_name, cur_timestamp) + job_name = f"{job_name}-{cur_timestamp}" # Use default image image = self.image if task_type == "RAW" else task_type # Set the pod name if not self.pod_name: - pod_name = "{}".format(job_name) + pod_name = f"{job_name}" else: - pod_name = "{}-{}".format(self.pod_name, cur_timestamp) + pod_name = f"{self.pod_name}-{cur_timestamp}" - logger.debug("cmd_string is {}".format(cmd_string)) + logger.debug(f"cmd_string is {cmd_string}") formatted_cmd = template_string.format( command=cmd_string, worker_init=self.worker_init ) - logger.info("[KUBERNETES] Scaling out a pod with name :{}".format(pod_name)) + logger.info(f"[KUBERNETES] Scaling out a pod with name :{pod_name}") self._create_pod( image=image, pod_name=pod_name, @@ -224,7 +224,7 @@ def cancel(self, num_pods, task_type=None): break else: num_pods -= 1 - logger.info("[KUBERNETES] The to_kill pods are {}".format(to_kill)) + logger.info(f"[KUBERNETES] The to_kill pods are {to_kill}") rets = self._cancel(to_kill) return to_kill, rets @@ -236,7 +236,7 @@ def _cancel(self, job_ids): [True/False...] : If the cancel operation fails the entire list will be False. 
""" for job in job_ids: - logger.debug("Terminating job/proc_id: {0}".format(job)) + logger.debug(f"Terminating job/proc_id: {job}") # Here we are assuming that for local, the job_ids are the process id's self._delete_pod(job) @@ -291,7 +291,7 @@ def _create_pod( # Create the enviornment variables and command to initiate IPP environment_vars = client.V1EnvVar(name="TEST", value="SOME DATA") - launch_args = ["-c", "{0}".format(cmd_string)] + launch_args = ["-c", f"{cmd_string}"] volume_mounts = [] # Create mount paths for the volumes @@ -342,7 +342,7 @@ def _create_pod( api_response = self.kube_client.create_namespaced_pod( namespace=self.namespace, body=pod ) - logger.debug("Pod created. status='{0}'".format(str(api_response.status))) + logger.debug(f"Pod created. status='{str(api_response.status)}'") def _delete_pod(self, pod_name): """Delete a pod""" @@ -350,7 +350,7 @@ def _delete_pod(self, pod_name): api_response = self.kube_client.delete_namespaced_pod( name=pod_name, namespace=self.namespace, body=client.V1DeleteOptions() ) - logger.debug("Pod deleted. status='{0}'".format(str(api_response.status))) + logger.debug(f"Pod deleted. status='{str(api_response.status)}'") @property def label(self): diff --git a/funcx_endpoint/funcx_endpoint/strategies/__init__.py b/funcx_endpoint/funcx_endpoint/strategies/__init__.py index a4c22e78d..5e6eb0ef7 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/__init__.py +++ b/funcx_endpoint/funcx_endpoint/strategies/__init__.py @@ -1,8 +1,5 @@ from funcx_endpoint.strategies.base import BaseStrategy -from funcx_endpoint.strategies.simple import SimpleStrategy from funcx_endpoint.strategies.kube_simple import KubeSimpleStrategy +from funcx_endpoint.strategies.simple import SimpleStrategy - -__all__ = ['BaseStrategy', - 'SimpleStrategy', - 'KubeSimpleStrategy'] +__all__ = ["BaseStrategy", "SimpleStrategy", "KubeSimpleStrategy"] diff --git a/funcx_endpoint/funcx_endpoint/strategies/base.py b/funcx_endpoint/funcx_endpoint/strategies/base.py index 719fd8b0f..dabd1561e 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/base.py +++ b/funcx_endpoint/funcx_endpoint/strategies/base.py @@ -1,12 +1,11 @@ -import sys -import threading import logging +import threading import time logger = logging.getLogger("interchange.strategy.base") -class BaseStrategy(object): +class BaseStrategy: """Implements threshold-interval based flow control. 
The overall goal is to trap the flow of apps from the @@ -61,7 +60,9 @@ def __init__(self, *args, threshold=20, interval=5): self._event_buffer = [] self._wake_up_time = time.time() + 1 self._kill_event = threading.Event() - self._thread = threading.Thread(target=self._wake_up_timer, args=(self._kill_event,)) + self._thread = threading.Thread( + target=self._wake_up_timer, args=(self._kill_event,) + ) self._thread.daemon = True def start(self, interchange): @@ -72,21 +73,25 @@ def start(self, interchange): Interchange to bind the strategy to """ self.interchange = interchange - if hasattr(interchange, 'provider'): - logger.debug("Strategy bounds-> init:{}, min:{}, max:{}".format( - interchange.provider.init_blocks, - interchange.provider.min_blocks, - interchange.provider.max_blocks)) + if hasattr(interchange, "provider"): + logger.debug( + "Strategy bounds-> init:{}, min:{}, max:{}".format( + interchange.provider.init_blocks, + interchange.provider.min_blocks, + interchange.provider.max_blocks, + ) + ) self._thread.start() def strategize(self, *args, **kwargs): - """ Strategize is called everytime the threshold or the interval is hit - """ - logger.debug("Strategize called with {} {}".format(args, kwargs)) + """Strategize is called everytime the threshold or the interval is hit""" + logger.debug(f"Strategize called with {args} {kwargs}") def _wake_up_timer(self, kill_event): - """Internal. This is the function that the thread will execute. - waits on an event so that the thread can make a quick exit when close() is called + """ + Internal. This is the function that the thread will execute. + waits on an event so that the thread can make a quick exit when close() is + called Args: - kill_event (threading.Event) : Event to wait on @@ -103,7 +108,7 @@ def _wake_up_timer(self, kill_event): return if prev == self._wake_up_time: - self.make_callback(kind='timer') + self.make_callback(kind="timer") else: print("Sleeping a bit more") @@ -135,7 +140,7 @@ def close(self): self._thread.join() -class Timer(object): +class Timer: """This timer is a simplified version of the FlowControl timer. This timer does not employ notify events. @@ -173,13 +178,16 @@ def __init__(self, callback, *args, interval=5): self._wake_up_time = time.time() + 1 self._kill_event = threading.Event() - self._thread = threading.Thread(target=self._wake_up_timer, args=(self._kill_event,)) + self._thread = threading.Thread( + target=self._wake_up_timer, args=(self._kill_event,) + ) self._thread.daemon = True self._thread.start() def _wake_up_timer(self, kill_event): """Internal. This is the function that the thread will execute. - waits on an event so that the thread can make a quick exit when close() is called + waits on an event so that the thread can make a quick exit when close() is + called Args: - kill_event (threading.Event) : Event to wait on @@ -197,18 +205,16 @@ def _wake_up_timer(self, kill_event): return if prev == self._wake_up_time: - self.make_callback(kind='timer') + self.make_callback(kind="timer") else: print("Sleeping a bit more") def make_callback(self, kind=None): - """Makes the callback and resets the timer. - """ + """Makes the callback and resets the timer.""" self._wake_up_time = time.time() + self.interval self.callback(*self.cb_args) def close(self): - """Merge the threads and terminate. 
- """ + """Merge the threads and terminate.""" self._kill_event.set() self._thread.join() diff --git a/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py b/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py index 88eb8f418..b73a4f66a 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py +++ b/funcx_endpoint/funcx_endpoint/strategies/kube_simple.py @@ -1,19 +1,16 @@ -from funcx_endpoint.strategies.base import BaseStrategy -import math import logging +import math import time +from funcx_endpoint.strategies.base import BaseStrategy + logger = logging.getLogger("interchange.strategy.KubeSimple") class KubeSimpleStrategy(BaseStrategy): - """ Implements the simple strategy for Kubernetes - """ + """Implements the simple strategy for Kubernetes""" - def __init__(self, *args, - threshold=20, - interval=1, - max_idletime=60): + def __init__(self, *args, threshold=20, interval=1, max_idletime=60): """Initialize the flowcontrol object. We start the timer thread here @@ -27,7 +24,8 @@ def __init__(self, *args, seconds after which timer expires max_idletime: (int) - maximum idle time(seconds) allowed for resources after which strategy will try to kill them. + maximum idle time(seconds) allowed for resources after which strategy will + try to kill them. default: 60s """ @@ -40,7 +38,7 @@ def strategize(self, *args, **kwargs): try: self._strategize(*args, **kwargs) except Exception as e: - logger.exception("Caught error in strategize : {}".format(e)) + logger.exception(f"Caught error in strategize : {e}") pass def _strategize(self, *args, **kwargs): @@ -51,7 +49,7 @@ def _strategize(self, *args, **kwargs): managers_per_pod = 1 workers_per_pod = self.interchange.max_workers_per_node - if workers_per_pod == float('inf'): + if workers_per_pod == float("inf"): workers_per_pod = 1 parallelism = self.interchange.provider.parallelism @@ -68,9 +66,14 @@ def _strategize(self, *args, **kwargs): active_tasks_per_type = active_tasks[task_type] logger.debug( - 'Endpoint has {} active tasks of {}, {} active blocks, {} connected workers for {}'.format( - active_tasks_per_type, task_type, active_pods, - self.interchange.get_total_live_workers(), task_type)) + "Endpoint has %s active tasks of %s, %s active blocks, " + "%s connected workers for %s", + active_tasks_per_type, + task_type, + active_pods, + self.interchange.get_total_live_workers(), + task_type, + ) # Reset the idle time if we are currently running tasks if active_tasks_per_type > 0: @@ -79,27 +82,45 @@ def _strategize(self, *args, **kwargs): # Scale down only if there are no active tasks to avoid having to find which # workers are unoccupied if active_tasks_per_type == 0 and active_pods > min_pods: - # We want to make sure that max_idletime is reached before killing off resources + # We want to make sure that max_idletime is reached before killing off + # resources if not self.executors_idle_since[task_type]: logger.debug( - "Endpoint has 0 active tasks of task type {}; starting kill timer (if idle time exceeds {}s, resources will be removed)". - format(task_type, self.max_idletime)) + "Endpoint has 0 active tasks of task type %s; " + "starting kill timer (if idle time exceeds %s seconds, " + "resources will be removed)", + task_type, + self.max_idletime, + ) self.executors_idle_since[task_type] = time.time() - # If we have resources idle for the max duration we have to scale_in now. 
- if (time.time() - self.executors_idle_since[task_type]) > self.max_idletime: + # If we have resources idle for the max duration we have to scale_in now + if ( + time.time() - self.executors_idle_since[task_type] + ) > self.max_idletime: logger.info( - "Idle time has reached {}s; removing resources of task type {}".format( - self.max_idletime, task_type) + "Idle time has reached %s seconds; " + "removing resources of task type %s", + self.max_idletime, + task_type, + ) + self.interchange.scale_in( + active_pods - min_pods, task_type=task_type ) - self.interchange.scale_in(active_pods - min_pods, task_type=task_type) # More tasks than the available slots. - elif active_tasks_per_type > 0 and (float(active_slots) / active_tasks_per_type) < parallelism: + elif ( + active_tasks_per_type > 0 + and (float(active_slots) / active_tasks_per_type) < parallelism + ): if active_pods < max_pods: - excess = math.ceil((active_tasks_per_type * parallelism) - active_slots) - excess_blocks = math.ceil(float(excess) / (workers_per_pod * managers_per_pod)) + excess = math.ceil( + (active_tasks_per_type * parallelism) - active_slots + ) + excess_blocks = math.ceil( + float(excess) / (workers_per_pod * managers_per_pod) + ) excess_blocks = min(excess_blocks, max_pods - active_pods) - logger.info("Requesting {} more blocks".format(excess_blocks)) + logger.info(f"Requesting {excess_blocks} more blocks") self.interchange.scale_out(excess_blocks, task_type=task_type) # Immediatly scale if we are stuck with zero pods and work to do elif active_slots == 0 and active_tasks_per_type > 0: diff --git a/funcx_endpoint/funcx_endpoint/strategies/simple.py b/funcx_endpoint/funcx_endpoint/strategies/simple.py index 0e6fb3bb0..c62725c3f 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/simple.py +++ b/funcx_endpoint/funcx_endpoint/strategies/simple.py @@ -1,6 +1,7 @@ -import math import logging +import math import time + from parsl.providers.provider_base import JobState from funcx_endpoint.strategies.base import BaseStrategy @@ -9,13 +10,9 @@ class SimpleStrategy(BaseStrategy): - """ Implements the simple strategy - """ + """Implements the simple strategy""" - def __init__(self, *args, - threshold=20, - interval=1, - max_idletime=60): + def __init__(self, *args, threshold=20, interval=1, max_idletime=60): """Initialize the flowcontrol object. We start the timer thread here @@ -29,20 +26,21 @@ def __init__(self, *args, seconds after which timer expires max_idletime: (int) - maximum idle time(seconds) allowed for resources after which strategy will try to kill them. + maximum idle time(seconds) allowed for resources after which strategy will + try to kill them. 
default: 60s """ logger.info("SimpleStrategy Initialized") super().__init__(*args, threshold=threshold, interval=interval) self.max_idletime = max_idletime - self.executors = {'idle_since': None} + self.executors = {"idle_since": None} def strategize(self, *args, **kwargs): try: self._strategize(*args, **kwargs) except Exception as e: - logger.exception("Caught error in strategize : {}".format(e)) + logger.exception(f"Caught error in strategize : {e}") pass def _strategize(self, *args, **kwargs): @@ -55,7 +53,7 @@ def _strategize(self, *args, **kwargs): # Here we assume that each node has atleast 4 workers tasks_per_node = self.interchange.max_workers_per_node - if self.interchange.max_workers_per_node == float('inf'): + if self.interchange.max_workers_per_node == float("inf"): tasks_per_node = 1 nodes_per_block = self.interchange.provider.nodes_per_block @@ -70,12 +68,18 @@ def _strategize(self, *args, **kwargs): active_blocks = running + pending active_slots = active_blocks * tasks_per_node * nodes_per_block - logger.debug('Endpoint has {} active tasks, {}/{} running/pending blocks, and {} connected workers'.format( - active_tasks, running, pending, self.interchange.get_total_live_workers())) + logger.debug( + "Endpoint has %s active tasks, %s/%s running/pending blocks, " + "and %s connected workers", + active_tasks, + running, + pending, + self.interchange.get_total_live_workers(), + ) # reset kill timer if executor has active tasks - if active_tasks > 0 and self.executors['idle_since']: - self.executors['idle_since'] = None + if active_tasks > 0 and self.executors["idle_since"]: + self.executors["idle_since"] = None # Case 1 # No tasks. @@ -92,24 +96,30 @@ def _strategize(self, *args, **kwargs): else: # We want to make sure that max_idletime is reached # before killing off resources - if not self.executors['idle_since']: - logger.debug("Endpoint has 0 active tasks; starting kill timer (if idle time exceeds {}s, resources will be removed)".format( - self.max_idletime) + if not self.executors["idle_since"]: + logger.debug( + "Endpoint has 0 active tasks; starting kill timer " + "(if idle time exceeds %s seconds, resources will be removed)", + self.max_idletime, ) - self.executors['idle_since'] = time.time() + self.executors["idle_since"] = time.time() - idle_since = self.executors['idle_since'] + idle_since = self.executors["idle_since"] if (time.time() - idle_since) > self.max_idletime: # We have resources idle for the max duration, # we have to scale_in now. - logger.debug("Idle time has reached {}s; removing resources".format( - self.max_idletime) + logger.debug( + "Idle time has reached {}s; removing resources".format( + self.max_idletime + ) ) self.interchange.scale_in(active_blocks - min_blocks) else: pass - # logger.debug("Strategy: Case.1b. Waiting for timer : {0}".format(idle_since)) + # logger.debug( + # "Strategy: Case.1b. Waiting for timer : %s", idle_since + # ) # Case 2 # More tasks than the available slots. 
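Both the KubeSimpleStrategy hunk above and the SimpleStrategy scale-out branch in the hunk that follows re-wrap the same arithmetic: compute the slot shortfall implied by the parallelism setting, convert it to whole blocks, and clamp to the provider's remaining max_blocks. A small standalone sketch with a worked example (a plain function under those assumptions, not the strategy classes themselves):

import math

def blocks_to_request(active_tasks, active_slots, active_blocks,
                      parallelism, tasks_per_node, nodes_per_block, max_blocks):
    # slot shortfall implied by the parallelism setting
    excess = math.ceil((active_tasks * parallelism) - active_slots)
    # convert the shortfall to whole blocks, then respect the provider's cap
    excess_blocks = math.ceil(excess / (tasks_per_node * nodes_per_block))
    return min(excess_blocks, max_blocks - active_blocks)

# 40 tasks, one running block exposing 8 slots, parallelism 1.0, max_blocks 4:
# a shortfall of 32 slots rounds up to 4 blocks, clamped to the 3 still allowed.
assert blocks_to_request(40, 8, 1, 1.0, 8, 1, 4) == 3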
@@ -125,9 +135,11 @@ def _strategize(self, *args, **kwargs): else: # logger.debug("Strategy: Case.2b") excess = math.ceil((active_tasks * parallelism) - active_slots) - excess_blocks = math.ceil(float(excess) / (tasks_per_node * nodes_per_block)) + excess_blocks = math.ceil( + float(excess) / (tasks_per_node * nodes_per_block) + ) excess_blocks = min(excess_blocks, max_blocks - active_blocks) - logger.debug("Requesting {} more blocks".format(excess_blocks)) + logger.debug(f"Requesting {excess_blocks} more blocks") self.interchange.scale_out(excess_blocks) elif active_slots == 0 and active_tasks > 0: diff --git a/funcx_endpoint/funcx_endpoint/strategies/test.py b/funcx_endpoint/funcx_endpoint/strategies/test.py index e3745231f..b9543776d 100644 --- a/funcx_endpoint/funcx_endpoint/strategies/test.py +++ b/funcx_endpoint/funcx_endpoint/strategies/test.py @@ -4,8 +4,7 @@ from funcx_endpoint.strategies import SimpleStrategy -class MockInterchange(object): - +class MockInterchange: def __init__(self, max_blocks=1, tasks=10): self.tasks_pending = tasks self.max_blocks = max_blocks @@ -22,7 +21,7 @@ def get_outstanding_breakdown(self): this_round = self.tasks_pending self.tasks_pending = 0 - current = [('interchange', this_round, this_round)] + current = [("interchange", this_round, this_round)] for i in range(self.managers): current.extend((f"manager_{i}", 1, 1)) self.status.put(current) @@ -35,11 +34,11 @@ def scale_out(self): def create_data(self): q = queue.Queue() items = [ - [('interchange', 0, 0)], - [('interchange', 0, 0)], - [('interchange', 0, 0)], - [('interchange', self.tasks_pending, self.tasks_pending)], - [('interchange', self.tasks_pending, self.tasks_pending)] + [("interchange", 0, 0)], + [("interchange", 0, 0)], + [("interchange", 0, 0)], + [("interchange", self.tasks_pending, self.tasks_pending)], + [("interchange", self.tasks_pending, self.tasks_pending)], ] [q.put(i) for i in items] diff --git a/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py b/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py index d3863b01a..c5ae7a44d 100644 --- a/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py +++ b/funcx_endpoint/funcx_endpoint/tests/strategies/test_kube_simple.py @@ -1,5 +1,6 @@ import threading import time + from pytest import fixture from funcx_endpoint.executors.high_throughput.interchange import Interchange @@ -17,7 +18,6 @@ def no_op_worker(): class TestKubeSimple: - @fixture def mock_interchange(self, mocker): mock_interchange = mocker.MagicMock(Interchange) @@ -29,8 +29,10 @@ def mock_interchange(self, mocker): mock_interchange.config.provider.max_blocks = 4 mock_interchange.config.provider.nodes_per_block = 1 mock_interchange.config.provider.parallelism = 1.0 - mock_interchange.get_total_tasks_outstanding = mocker.Mock(return_value={'RAW': 0}) - mock_interchange.provider_status = mocker.Mock(return_value={'RAW': 16}) + mock_interchange.get_total_tasks_outstanding = mocker.Mock( + return_value={"RAW": 0} + ) + mock_interchange.provider_status = mocker.Mock(return_value={"RAW": 16}) mock_interchange.get_total_live_workers = mocker.Mock(return_value=0) mock_interchange.scale_in = mocker.Mock() mock_interchange.scale_out = mocker.Mock() @@ -48,7 +50,8 @@ def kube_strategy(self): def test_no_tasks_no_pods(self, mock_interchange, kube_strategy): mock_interchange.get_outstanding_breakdown.return_value = [ - ('interchange', 0, True)] + ("interchange", 0, True) + ] mock_interchange.get_total_tasks_outstanding.return_value = [] 
kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") @@ -57,8 +60,8 @@ def test_no_tasks_no_pods(self, mock_interchange, kube_strategy): def test_scale_in_with_no_tasks(self, mock_interchange, kube_strategy): # First there is work to do and pods are scaled up - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 16} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 16} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -66,7 +69,7 @@ def test_scale_in_with_no_tasks(self, mock_interchange, kube_strategy): # Now tasks are all done, but pods are still running. Idle time has not yet # been reached, so the pods will still be running. - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 0} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() @@ -81,8 +84,8 @@ def test_scale_in_with_no_tasks(self, mock_interchange, kube_strategy): def test_task_arrives_during_idle_time(self, mock_interchange, kube_strategy): # First there is work to do and pods are scaled up - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 16} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 16} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -90,20 +93,20 @@ def test_task_arrives_during_idle_time(self, mock_interchange, kube_strategy): # Now tasks are all done, but pods are still running. Idle time has not yet # been reached, so the pods will still be running. 
- mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 0} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() # Now add a new task - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 1} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() # Verify that idle time is reset time.sleep(5) - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 0} kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() mock_interchange.scale_out.assert_not_called() @@ -112,8 +115,8 @@ def test_task_backlog_within_parallelism(self, mock_interchange, kube_strategy): # Aggressive scaling so new tasks will create new pods mock_interchange.config.provider.parallelism = 1.0 mock_interchange.config.provider.max_blocks = 16 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -123,8 +126,8 @@ def test_task_backlog_gated_by_parallelism(self, mock_interchange, kube_strategy # Lazy scaling, so just a single new task won't spawn a new pod mock_interchange.config.provider.parallelism = 0.5 mock_interchange.config.provider.max_blocks = 16 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -133,8 +136,8 @@ def test_task_backlog_gated_by_parallelism(self, mock_interchange, kube_strategy def test_task_backlog_gated_by_max_blocks(self, mock_interchange, kube_strategy): mock_interchange.config.provider.parallelism = 1.0 mock_interchange.config.provider.max_blocks = 8 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 1} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() @@ -143,8 +146,8 @@ def test_task_backlog_gated_by_max_blocks(self, mock_interchange, kube_strategy) def test_task_backlog_already_max_blocks(self, mock_interchange, kube_strategy): mock_interchange.config.provider.parallelism = 1.0 mock_interchange.config.provider.max_blocks = 8 - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 16} - mock_interchange.provider_status.return_value = {'RAW': 16} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 16} + mock_interchange.provider_status.return_value = {"RAW": 16} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") 
mock_interchange.scale_in.assert_not_called() @@ -152,8 +155,8 @@ def test_task_backlog_already_max_blocks(self, mock_interchange, kube_strategy): def test_scale_when_no_pods(self, mock_interchange, kube_strategy): mock_interchange.config.provider.parallelism = 0.01 # Very lazy scaling - mock_interchange.provider_status.return_value = {'RAW': 0} - mock_interchange.get_total_tasks_outstanding.return_value = {'RAW': 1} + mock_interchange.provider_status.return_value = {"RAW": 0} + mock_interchange.get_total_tasks_outstanding.return_value = {"RAW": 1} kube_strategy.start(mock_interchange) kube_strategy.make_callback(kind="timer") mock_interchange.scale_in.assert_not_called() diff --git a/funcx_endpoint/funcx_endpoint/version.py b/funcx_endpoint/funcx_endpoint/version.py index 395c3fe21..3dbef89d5 100644 --- a/funcx_endpoint/funcx_endpoint/version.py +++ b/funcx_endpoint/funcx_endpoint/version.py @@ -4,4 +4,4 @@ VERSION = __version__ # app name to send as part of requests -app_name = "funcX Endpoint v{}".format(__version__) +app_name = f"funcX Endpoint v{__version__}" diff --git a/funcx_endpoint/setup.py b/funcx_endpoint/setup.py index 6008f513f..a6581328d 100644 --- a/funcx_endpoint/setup.py +++ b/funcx_endpoint/setup.py @@ -79,9 +79,12 @@ entry_points={ "console_scripts": [ "funcx-endpoint=funcx_endpoint.endpoint.endpoint:cli_run", - "funcx-interchange=funcx_endpoint.executors.high_throughput.interchange:cli_run", - "funcx-manager=funcx_endpoint.executors.high_throughput.funcx_manager:cli_run", - "funcx-worker=funcx_endpoint.executors.high_throughput.funcx_worker:cli_run", + "funcx-interchange" + "=funcx_endpoint.executors.high_throughput.interchange:cli_run", + "funcx-manager" + "=funcx_endpoint.executors.high_throughput.funcx_manager:cli_run", + "funcx-worker" + "=funcx_endpoint.executors.high_throughput.funcx_worker:cli_run", ] }, include_package_data=True, diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py index 79e089719..1bde31486 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint.py @@ -1,12 +1,14 @@ import os + import pytest -from funcx_endpoint.endpoint.endpoint import app from typer.testing import CliRunner +from funcx_endpoint.endpoint.endpoint import app + runner = CliRunner() -config_string = ''' +config_string = """ from funcx_endpoint.endpoint.utils.config import Config from parsl.providers import LocalProvider @@ -18,11 +20,10 @@ max_blocks=1, ), funcx_service_address='https://api.funcx.org/v1' -)''' +)""" class TestEndpoint: - @pytest.fixture(autouse=True) def test_setup_teardown(self, mocker): mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") @@ -30,12 +31,12 @@ def test_setup_teardown(self, mocker): def test_non_configured_endpoint(self, mocker): result = runner.invoke(app, ["start", "newendpoint"]) - assert 'newendpoint' in result.stdout - assert 'not configured' in result.stdout + assert "newendpoint" in result.stdout + assert "not configured" in result.stdout def test_using_outofdate_config(self, mocker): - mock_loader = mocker.patch('funcx_endpoint.endpoint.endpoint.os.path.join') - mock_loader.return_value = './config.py' + mock_loader = mocker.patch("funcx_endpoint.endpoint.endpoint.os.path.join") + mock_loader.return_value = "./config.py" config_file = open("./config.py", "w") config_file.write(config_string) config_file.close() diff --git 
a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py index f63cb2dd4..d89c27ac5 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_endpoint_manager.py @@ -1,26 +1,25 @@ -from funcx_endpoint.endpoint.endpoint_manager import EndpointManager -from importlib.machinery import SourceFileLoader -import os +import json import logging -import sys +import os import shutil -import pytest -import json -from pytest import fixture +from importlib.machinery import SourceFileLoader from unittest.mock import ANY -from globus_sdk import GlobusHTTPResponse, GlobusAPIError + +import pytest +from globus_sdk import GlobusAPIError, GlobusHTTPResponse from requests import Response -logger = logging.getLogger('mock_funcx') +from funcx_endpoint.endpoint.endpoint_manager import EndpointManager + +logger = logging.getLogger("mock_funcx") class TestStart: - @pytest.fixture(autouse=True) def test_setup_teardown(self): # Code that will run before your test, for example: - funcx_dir = f'{os.getcwd()}' + funcx_dir = f"{os.getcwd()}" config_dir = os.path.join(funcx_dir, "mock_endpoint") assert not os.path.exists(config_dir) # A test function will be run at this point @@ -40,20 +39,27 @@ def test_double_configure(self): manager.configure_endpoint("mock_endpoint", None) assert os.path.exists(config_dir) - with pytest.raises(Exception, match='ConfigExists'): + with pytest.raises(Exception, match="ConfigExists"): manager.configure_endpoint("mock_endpoint", None) def test_start(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - reg_info = {'endpoint_id': 'abcde12345', - 'address': 'localhost', - 'client_ports': '8080'} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + reg_info = { + "endpoint_id": "abcde12345", + "address": "localhost", + "client_ports": "8080", + } mock_client.return_value.register_endpoint.return_value = reg_info - mock_zmq_create = mocker.patch("zmq.auth.create_certificates", - return_value=("public/key/file", None)) - mock_zmq_load = mocker.patch("zmq.auth.load_certificate", - return_value=("12345abcde".encode(), "12345abcde".encode())) + mock_zmq_create = mocker.patch( + "zmq.auth.create_certificates", return_value=("public/key/file", None) + ) + mock_zmq_load = mocker.patch( + "zmq.auth.load_certificate", + return_value=(b"12345abcde", b"12345abcde"), + ) mock_context = mocker.patch("daemon.DaemonContext") @@ -61,28 +67,36 @@ def test_start(self, mocker): mock_context.return_value.__enter__.return_value = None mock_context.return_value.__exit__.return_value = None - mock_context.return_value.pidfile.path = '' + mock_context.return_value.pidfile.path = "" - mock_daemon = mocker.patch.object(EndpointManager, 'daemon_launch', - return_value=None) + mock_daemon = mocker.patch.object( + EndpointManager, "daemon_launch", return_value=None + ) - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 - mock_pidfile = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile') + mock_pidfile = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile" + ) mock_pidfile.return_value = None - mock_results_ack_handler = 
mocker.patch('funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler') + mock_results_ack_handler = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler" + ) manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() manager.start_endpoint("mock_endpoint", None, endpoint_config) - mock_zmq_create.assert_called_with(os.path.join(config_dir, "certificates"), "endpoint") + mock_zmq_create.assert_called_with( + os.path.join(config_dir, "certificates"), "endpoint" + ) mock_zmq_load.assert_called_with("public/key/file") funcx_client_options = { @@ -90,20 +104,24 @@ def test_start(self, mocker): "check_endpoint_version": True, } - mock_daemon.assert_called_with('123456', - config_dir, - os.path.join(config_dir, "certificates"), - endpoint_config, - reg_info, - funcx_client_options, - mock_results_ack_handler.return_value) - - mock_context.assert_called_with(working_directory=config_dir, - umask=0o002, - pidfile=None, - stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), - stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), - detach_process=True) + mock_daemon.assert_called_with( + "123456", + config_dir, + os.path.join(config_dir, "certificates"), + endpoint_config, + reg_info, + funcx_client_options, + mock_results_ack_handler.return_value, + ) + + mock_context.assert_called_with( + working_directory=config_dir, + umask=0o002, + pidfile=None, + stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), + stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), + detach_process=True, + ) def test_start_registration_error(self, mocker): """This tests what happens if a 400 error response comes back from the @@ -115,70 +133,84 @@ def test_start_registration_error(self, mocker): mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") base_r = Response() - base_r.headers = { - "Content-Type": "json" - } + base_r.headers = {"Content-Type": "json"} base_r.status_code = 400 r = GlobusHTTPResponse(base_r) r.status_code = base_r.status_code r.headers = base_r.headers - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.register_endpoint') + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.register_endpoint" + ) mock_register_endpoint.side_effect = GlobusAPIError(r) - mock_zmq_create = mocker.patch("zmq.auth.create_certificates", - return_value=("public/key/file", None)) - mock_zmq_load = mocker.patch("zmq.auth.load_certificate", - return_value=("12345abcde".encode(), "12345abcde".encode())) + mock_zmq_create = mocker.patch( + "zmq.auth.create_certificates", return_value=("public/key/file", None) + ) + mock_zmq_load = mocker.patch( + "zmq.auth.load_certificate", + return_value=(b"12345abcde", b"12345abcde"), + ) - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 - mock_pidfile = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile') + mock_pidfile = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile" + ) 
mock_pidfile.return_value = None - mocker.patch('funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler') + mocker.patch("funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler") manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() with pytest.raises(GlobusAPIError): manager.start_endpoint("mock_endpoint", None, endpoint_config) - mock_zmq_create.assert_called_with(os.path.join(config_dir, "certificates"), "endpoint") + mock_zmq_create.assert_called_with( + os.path.join(config_dir, "certificates"), "endpoint" + ) mock_zmq_load.assert_called_with("public/key/file") def test_start_registration_5xx_error(self, mocker): - """This tests what happens if a 500 error response comes back from the - initial endpoint registration. It is expected that this exception should - NOT be raised and that the interchange should be started without any registration - info being passed in. The registration should then be retried in the interchange - daemon, because a 5xx error suggests that there is a temporary service issue - that will resolve on its own. mock_zmq_create and mock_zmq_load are being - asserted against because this zmq setup happens before registration occurs. + """ + This tests what happens if a 500 error response comes back from the initial + endpoint registration. + + It is expected that this exception should NOT be raised and that the interchange + should be started without any registration info being passed in. The + registration should then be retried in the interchange daemon, because a 5xx + error suggests that there is a temporary service issue that will resolve on its + own. mock_zmq_create and mock_zmq_load are being asserted against because this + zmq setup happens before registration occurs. 
""" mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") base_r = Response() - base_r.headers = { - "Content-Type": "json" - } + base_r.headers = {"Content-Type": "json"} base_r.status_code = 500 r = GlobusHTTPResponse(base_r) r.status_code = base_r.status_code r.headers = base_r.headers - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.register_endpoint') + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.register_endpoint" + ) mock_register_endpoint.side_effect = GlobusAPIError(r) - mock_zmq_create = mocker.patch("zmq.auth.create_certificates", - return_value=("public/key/file", None)) - mock_zmq_load = mocker.patch("zmq.auth.load_certificate", - return_value=("12345abcde".encode(), "12345abcde".encode())) + mock_zmq_create = mocker.patch( + "zmq.auth.create_certificates", return_value=("public/key/file", None) + ) + mock_zmq_load = mocker.patch( + "zmq.auth.load_certificate", + return_value=(b"12345abcde", b"12345abcde"), + ) mock_context = mocker.patch("daemon.DaemonContext") @@ -186,29 +218,37 @@ def test_start_registration_5xx_error(self, mocker): mock_context.return_value.__enter__.return_value = None mock_context.return_value.__exit__.return_value = None - mock_context.return_value.pidfile.path = '' + mock_context.return_value.pidfile.path = "" - mock_daemon = mocker.patch.object(EndpointManager, 'daemon_launch', - return_value=None) + mock_daemon = mocker.patch.object( + EndpointManager, "daemon_launch", return_value=None + ) - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 - mock_pidfile = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile') + mock_pidfile = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.daemon.pidfile.PIDLockFile" + ) mock_pidfile.return_value = None - mock_results_ack_handler = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler') + mock_results_ack_handler = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.ResultsAckHandler" + ) manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() manager.start_endpoint("mock_endpoint", None, endpoint_config) - mock_zmq_create.assert_called_with(os.path.join(config_dir, "certificates"), "endpoint") + mock_zmq_create.assert_called_with( + os.path.join(config_dir, "certificates"), "endpoint" + ) mock_zmq_load.assert_called_with("public/key/file") funcx_client_options = { @@ -216,29 +256,37 @@ def test_start_registration_5xx_error(self, mocker): "check_endpoint_version": True, } - # We should expect reg_info in this test to be None when passed into daemon_launch - # because a 5xx GlobusAPIError was raised during registration + # We should expect reg_info in this test to be None when passed into + # daemon_launch because a 5xx GlobusAPIError was raised during registration reg_info = None - mock_daemon.assert_called_with('123456', - config_dir, - os.path.join(config_dir, "certificates"), - endpoint_config, - reg_info, - funcx_client_options, - mock_results_ack_handler.return_value) - - 
mock_context.assert_called_with(working_directory=config_dir, - umask=0o002, - pidfile=None, - stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), - stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), - detach_process=True) + mock_daemon.assert_called_with( + "123456", + config_dir, + os.path.join(config_dir, "certificates"), + endpoint_config, + reg_info, + funcx_client_options, + mock_results_ack_handler.return_value, + ) + + mock_context.assert_called_with( + working_directory=config_dir, + umask=0o002, + pidfile=None, + stdout=ANY, # open(os.path.join(config_dir, './interchange.stdout'), 'w+'), + stderr=ANY, # open(os.path.join(config_dir, './interchange.stderr'), 'w+'), + detach_process=True, + ) def test_start_without_executors(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - mock_client.return_value.register_endpoint.return_value = {'endpoint_id': 'abcde12345', - 'address': 'localhost', - 'client_ports': '8080'} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + mock_client.return_value.register_endpoint.return_value = { + "endpoint_id": "abcde12345", + "address": "localhost", + "client_ports": "8080", + } mock_context = mocker.patch("daemon.DaemonContext") @@ -246,101 +294,136 @@ def test_start_without_executors(self, mocker): mock_context.return_value.__enter__.return_value = None mock_context.return_value.__exit__.return_value = None - mock_context.return_value.pidfile.path = '' + mock_context.return_value.pidfile.path = "" - class mock_load(): - class mock_executors(): + class mock_load: + class mock_executors: executors = None + config = mock_executors() manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - with pytest.raises(Exception, match=f'Endpoint config file at {config_dir} is missing executor definitions'): + with pytest.raises( + Exception, + match=f"Endpoint config file at {config_dir} is " + "missing executor definitions", + ): manager.start_endpoint("mock_endpoint", None, mock_load()) def test_daemon_launch(self, mocker): - mock_interchange = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange') + mock_interchange = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange" + ) mock_interchange.return_value.start.return_value = None mock_interchange.return_value.stop.return_value = None manager = EndpointManager(funcx_dir=os.getcwd()) - manager.name = 'test' + manager.name = "test" config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") mock_optionals = {} - mock_optionals['logdir'] = config_dir + mock_optionals["logdir"] = config_dir manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() funcx_client_options = {} - manager.daemon_launch('mock_endpoint_uuid', config_dir, 'mock_keys_dir', endpoint_config, None, funcx_client_options, None) - - mock_interchange.assert_called_with(endpoint_config.config, - endpoint_id='mock_endpoint_uuid', - keys_dir='mock_keys_dir', - endpoint_dir=config_dir, - endpoint_name=manager.name, - reg_info=None, - funcx_client_options=funcx_client_options, - results_ack_handler=None, - **mock_optionals) + manager.daemon_launch( + 
"mock_endpoint_uuid", + config_dir, + "mock_keys_dir", + endpoint_config, + None, + funcx_client_options, + None, + ) + + mock_interchange.assert_called_with( + endpoint_config.config, + endpoint_id="mock_endpoint_uuid", + keys_dir="mock_keys_dir", + endpoint_dir=config_dir, + endpoint_name=manager.name, + reg_info=None, + funcx_client_options=funcx_client_options, + results_ack_handler=None, + **mock_optionals, + ) def test_with_funcx_config(self, mocker): - mock_interchange = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange') + mock_interchange = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.EndpointInterchange" + ) mock_interchange.return_value.start.return_value = None mock_interchange.return_value.stop.return_value = None mock_optionals = {} - mock_optionals['interchange_address'] = '127.0.0.1' + mock_optionals["interchange_address"] = "127.0.0.1" mock_funcx_config = {} - mock_funcx_config['endpoint_address'] = '127.0.0.1' + mock_funcx_config["endpoint_address"] = "127.0.0.1" manager = EndpointManager(funcx_dir=os.getcwd()) - manager.name = 'test' + manager.name = "test" config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - mock_optionals['logdir'] = config_dir + mock_optionals["logdir"] = config_dir manager.funcx_config = mock_funcx_config manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() funcx_client_options = {} - manager.daemon_launch('mock_endpoint_uuid', config_dir, 'mock_keys_dir', endpoint_config, None, funcx_client_options, None) - - mock_interchange.assert_called_with(endpoint_config.config, - endpoint_id='mock_endpoint_uuid', - keys_dir='mock_keys_dir', - endpoint_dir=config_dir, - endpoint_name=manager.name, - reg_info=None, - funcx_client_options=funcx_client_options, - results_ack_handler=None, - **mock_optionals) + manager.daemon_launch( + "mock_endpoint_uuid", + config_dir, + "mock_keys_dir", + endpoint_config, + None, + funcx_client_options, + None, + ) + + mock_interchange.assert_called_with( + endpoint_config.config, + endpoint_id="mock_endpoint_uuid", + keys_dir="mock_keys_dir", + endpoint_dir=config_dir, + endpoint_name=manager.name, + reg_info=None, + funcx_client_options=funcx_client_options, + results_ack_handler=None, + **mock_optionals, + ) def test_check_endpoint_json_no_json_no_uuid(self, mocker): - mock_uuid = mocker.patch('funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4') + mock_uuid = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.uuid.uuid4") mock_uuid.return_value = 123456 manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - assert '123456' == manager.check_endpoint_json(os.path.join(config_dir, 'endpoint.json'), None) + assert "123456" == manager.check_endpoint_json( + os.path.join(config_dir, "endpoint.json"), None + ) def test_check_endpoint_json_no_json_given_uuid(self, mocker): manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") manager.configure_endpoint("mock_endpoint", None) - assert '234567' == manager.check_endpoint_json(os.path.join(config_dir, 'endpoint.json'), '234567') + assert "234567" == manager.check_endpoint_json( + os.path.join(config_dir, "endpoint.json"), "234567" + ) def 
test_check_endpoint_json_given_json(self, mocker): manager = EndpointManager(funcx_dir=os.getcwd()) @@ -348,8 +431,10 @@ def test_check_endpoint_json_given_json(self, mocker): manager.configure_endpoint("mock_endpoint", None) - mock_dict = {'endpoint_id': 'abcde12345'} - with open(os.path.join(config_dir, 'endpoint.json'), "w") as fd: + mock_dict = {"endpoint_id": "abcde12345"} + with open(os.path.join(config_dir, "endpoint.json"), "w") as fd: json.dump(mock_dict, fd) - assert 'abcde12345' == manager.check_endpoint_json(os.path.join(config_dir, 'endpoint.json'), '234567') + assert "abcde12345" == manager.check_endpoint_json( + os.path.join(config_dir, "endpoint.json"), "234567" + ) diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py index 0f78c8ab2..2d8f41311 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_interchange.py @@ -1,26 +1,22 @@ -from funcx_endpoint.endpoint.endpoint_manager import EndpointManager -from funcx_endpoint.endpoint.interchange import EndpointInterchange -from funcx_endpoint.endpoint.register_endpoint import register_endpoint -from importlib.machinery import SourceFileLoader -import os import logging -import sys +import os import shutil +from importlib.machinery import SourceFileLoader + import pytest -import json -from pytest import fixture -from unittest.mock import ANY -logger = logging.getLogger('mock_funcx') +from funcx_endpoint.endpoint.endpoint_manager import EndpointManager +from funcx_endpoint.endpoint.interchange import EndpointInterchange + +logger = logging.getLogger("mock_funcx") class TestStart: - @pytest.fixture(autouse=True) def test_setup_teardown(self): # Code that will run before your test, for example: - funcx_dir = f'{os.getcwd()}' + funcx_dir = f"{os.getcwd()}" config_dir = os.path.join(funcx_dir, "mock_endpoint") assert not os.path.exists(config_dir) # A test function will be run at this point @@ -34,62 +30,72 @@ def test_endpoint_id(self, mocker): manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - keys_dir = os.path.join(config_dir, 'certificates') + keys_dir = os.path.join(config_dir, "certificates") optionals = {} - optionals['client_address'] = '127.0.0.1' - optionals['client_ports'] = (8080, 8081, 8082) - optionals['logdir'] = './mock_endpoint' + optionals["client_address"] = "127.0.0.1" + optionals["client_ports"] = (8080, 8081, 8082) + optionals["logdir"] = "./mock_endpoint" manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() for executor in endpoint_config.config.executors: executor.passthrough = False - ic = EndpointInterchange(endpoint_config.config, - endpoint_id='mock_endpoint_id', - keys_dir=keys_dir, - **optionals) + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id="mock_endpoint_id", + keys_dir=keys_dir, + **optionals, + ) for executor in ic.executors.values(): - assert executor.endpoint_id == 'mock_endpoint_id' + assert executor.endpoint_id == "mock_endpoint_id" def test_register_endpoint(self, mocker): mock_client = mocker.patch("funcx_endpoint.endpoint.interchange.FuncXClient") mock_client.return_value = None - mock_register_endpoint = 
mocker.patch('funcx_endpoint.endpoint.interchange.register_endpoint') - mock_register_endpoint.return_value = {'endpoint_id': 'abcde12345', - 'public_ip': '127.0.0.1', - 'tasks_port': 8080, - 'results_port': 8081, - 'commands_port': 8082, } + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.interchange.register_endpoint" + ) + mock_register_endpoint.return_value = { + "endpoint_id": "abcde12345", + "public_ip": "127.0.0.1", + "tasks_port": 8080, + "results_port": 8081, + "commands_port": 8082, + } manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - keys_dir = os.path.join(config_dir, 'certificates') + keys_dir = os.path.join(config_dir, "certificates") optionals = {} - optionals['client_address'] = '127.0.0.1' - optionals['client_ports'] = (8080, 8081, 8082) - optionals['logdir'] = './mock_endpoint' + optionals["client_address"] = "127.0.0.1" + optionals["client_ports"] = (8080, 8081, 8082) + optionals["logdir"] = "./mock_endpoint" manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() for executor in endpoint_config.config.executors: executor.passthrough = False - ic = EndpointInterchange(endpoint_config.config, - endpoint_id='mock_endpoint_id', - keys_dir=keys_dir, - **optionals) + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id="mock_endpoint_id", + keys_dir=keys_dir, + **optionals, + ) ic.register_endpoint() - assert ic.client_address == '127.0.0.1' + assert ic.client_address == "127.0.0.1" assert ic.client_ports == (8080, 8081, 8082) def test_start_no_reg_info(self, mocker): @@ -100,38 +106,47 @@ def test_start_no_reg_info(self, mocker): mock_client = mocker.patch("funcx_endpoint.endpoint.interchange.FuncXClient") mock_client.return_value = None - mock_register_endpoint = mocker.patch('funcx_endpoint.endpoint.interchange.register_endpoint') - mock_register_endpoint.return_value = {'endpoint_id': 'abcde12345', - 'public_ip': '127.0.0.1', - 'tasks_port': 8080, - 'results_port': 8081, - 'commands_port': 8082, } + mock_register_endpoint = mocker.patch( + "funcx_endpoint.endpoint.interchange.register_endpoint" + ) + mock_register_endpoint.return_value = { + "endpoint_id": "abcde12345", + "public_ip": "127.0.0.1", + "tasks_port": 8080, + "results_port": 8081, + "commands_port": 8082, + } manager = EndpointManager(funcx_dir=os.getcwd()) config_dir = os.path.join(manager.funcx_dir, "mock_endpoint") - keys_dir = os.path.join(config_dir, 'certificates') + keys_dir = os.path.join(config_dir, "certificates") optionals = {} - optionals['client_address'] = '127.0.0.1' - optionals['client_ports'] = (8080, 8081, 8082) - optionals['logdir'] = './mock_endpoint' + optionals["client_address"] = "127.0.0.1" + optionals["client_ports"] = (8080, 8081, 8082) + optionals["logdir"] = "./mock_endpoint" manager.configure_endpoint("mock_endpoint", None) - endpoint_config = SourceFileLoader('config', - os.path.join(config_dir, 'config.py')).load_module() + endpoint_config = SourceFileLoader( + "config", os.path.join(config_dir, "config.py") + ).load_module() for executor in endpoint_config.config.executors: executor.passthrough = False - mock_quiesce = mocker.patch.object(EndpointInterchange, 'quiesce', - return_value=None) - mock_main_loop = mocker.patch.object(EndpointInterchange, '_main_loop', - return_value=None) - - 
ic = EndpointInterchange(endpoint_config.config, - endpoint_id='mock_endpoint_id', - keys_dir=keys_dir, - **optionals) + mock_quiesce = mocker.patch.object( + EndpointInterchange, "quiesce", return_value=None + ) + mock_main_loop = mocker.patch.object( + EndpointInterchange, "_main_loop", return_value=None + ) + + ic = EndpointInterchange( + endpoint_config.config, + endpoint_id="mock_endpoint_id", + keys_dir=keys_dir, + **optionals, + ) ic.results_outgoing = mocker.Mock() diff --git a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py index 39cedb217..84386a2eb 100644 --- a/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py +++ b/funcx_endpoint/tests/funcx_endpoint/endpoint/test_register_endpoint.py @@ -1,23 +1,20 @@ -from funcx_endpoint.endpoint.register_endpoint import register_endpoint -import os import logging -import sys +import os import shutil + import pytest -import json -from pytest import fixture -from unittest.mock import ANY -logger = logging.getLogger('mock_funcx') +from funcx_endpoint.endpoint.register_endpoint import register_endpoint +logger = logging.getLogger("mock_funcx") -class TestRegisterEndpoint: +class TestRegisterEndpoint: @pytest.fixture(autouse=True) def test_setup_teardown(self): # Code that will run before your test, for example: - funcx_dir = f'{os.getcwd()}' + funcx_dir = f"{os.getcwd()}" config_dir = os.path.join(funcx_dir, "mock_endpoint") assert not os.path.exists(config_dir) # A test function will be run at this point @@ -27,22 +24,34 @@ def test_setup_teardown(self): shutil.rmtree(config_dir) def test_register_endpoint_no_endpoint_id(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - mock_client.return_value.register_endpoint.return_value = {'status': 'okay'} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + mock_client.return_value.register_endpoint.return_value = {"status": "okay"} funcx_dir = os.getcwd() config_dir = os.path.join(funcx_dir, "mock_endpoint") - with pytest.raises(Exception, match='Endpoint ID was not included in the service\'s registration response.'): - register_endpoint(mock_client(), 'mock_endpoint_uuid', config_dir, 'test') + with pytest.raises( + Exception, + match="Endpoint ID was not included in the service's " + "registration response.", + ): + register_endpoint(mock_client(), "mock_endpoint_uuid", config_dir, "test") def test_register_endpoint_int_endpoint_id(self, mocker): - mock_client = mocker.patch("funcx_endpoint.endpoint.endpoint_manager.FuncXClient") - mock_client.return_value.register_endpoint.return_value = {'status': 'okay', - 'endpoint_id': 123456} + mock_client = mocker.patch( + "funcx_endpoint.endpoint.endpoint_manager.FuncXClient" + ) + mock_client.return_value.register_endpoint.return_value = { + "status": "okay", + "endpoint_id": 123456, + } funcx_dir = os.getcwd() config_dir = os.path.join(funcx_dir, "mock_endpoint") - with pytest.raises(Exception, match='Endpoint ID sent by the service was not a string.'): - register_endpoint(mock_client(), 'mock_endpoint_uuid', config_dir, 'test') + with pytest.raises( + Exception, match="Endpoint ID sent by the service was not a string." 
+ ): + register_endpoint(mock_client(), "mock_endpoint_uuid", config_dir, "test") diff --git a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py index 7b59a77cd..7e42a2062 100644 --- a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py +++ b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_manager.py @@ -1,27 +1,28 @@ -from funcx_endpoint.executors.high_throughput.funcx_manager import Manager -from funcx_endpoint.executors.high_throughput.messages import Task -import queue -import logging -import pickle -import zmq import os +import pickle +import queue import shutil + import pytest +from funcx_endpoint.executors.high_throughput.funcx_manager import Manager +from funcx_endpoint.executors.high_throughput.messages import Task -class TestManager: +class TestManager: @pytest.fixture(autouse=True) def test_setup_teardown(self): - os.makedirs(os.path.join(os.getcwd(), 'mock_uid')) + os.makedirs(os.path.join(os.getcwd(), "mock_uid")) yield - shutil.rmtree(os.path.join(os.getcwd(), 'mock_uid')) + shutil.rmtree(os.path.join(os.getcwd(), "mock_uid")) def test_remove_worker_init(self, mocker): # zmq is being mocked here because it was making tests hang - mocker.patch('funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context') + mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context" + ) - manager = Manager(logdir='./', uid="mock_uid") + manager = Manager(logdir="./", uid="mock_uid") manager.worker_map.to_die_count["RAW"] = 0 manager.task_queues["RAW"] = queue.Queue() @@ -33,21 +34,33 @@ def test_remove_worker_init(self, mocker): def test_poll_funcx_task_socket(self, mocker): # zmq is being mocked here because it was making tests hang - mocker.patch('funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context') + mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_manager.zmq.Context" + ) - mock_worker_map = mocker.patch('funcx_endpoint.executors.high_throughput.funcx_manager.WorkerMap') + mock_worker_map = mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_manager.WorkerMap" + ) - manager = Manager(logdir='./', uid="mock_uid") + manager = Manager(logdir="./", uid="mock_uid") manager.task_queues["RAW"] = queue.Queue() manager.logdir = "./" - manager.worker_type = 'RAW' - manager.worker_procs['0'] = 'proc' + manager.worker_type = "RAW" + manager.worker_procs["0"] = "proc" - manager.funcx_task_socket.recv_multipart.return_value = b'0', b'REGISTER', pickle.dumps({'worker_type': 'RAW'}) + manager.funcx_task_socket.recv_multipart.return_value = ( + b"0", + b"REGISTER", + pickle.dumps({"worker_type": "RAW"}), + ) manager.poll_funcx_task_socket(test=True) - mock_worker_map.return_value.register_worker.assert_called_with(b'0', 'RAW') + mock_worker_map.return_value.register_worker.assert_called_with(b"0", "RAW") - manager.funcx_task_socket.recv_multipart.return_value = b'0', b'WRKR_DIE', pickle.dumps(None) + manager.funcx_task_socket.recv_multipart.return_value = ( + b"0", + b"WRKR_DIE", + pickle.dumps(None), + ) manager.poll_funcx_task_socket(test=True) - mock_worker_map.return_value.remove_worker.assert_called_with(b'0') + mock_worker_map.return_value.remove_worker.assert_called_with(b"0") assert len(manager.worker_procs) == 0 diff --git a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py 
b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py index 87e6f17e0..2f562b872 100644 --- a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py +++ b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_funcx_worker.py @@ -1,32 +1,37 @@ -from funcx_endpoint.executors.high_throughput.funcx_worker import FuncXWorker -from funcx_endpoint.executors.high_throughput.messages import Task import os import pickle +from funcx_endpoint.executors.high_throughput.funcx_worker import FuncXWorker +from funcx_endpoint.executors.high_throughput.messages import Task + class TestWorker: def test_register_and_kill(self, mocker): # we need to mock sys.exit here so that the worker while loop # can exit without the test being killed - mocker.patch('funcx_endpoint.executors.high_throughput.funcx_worker.sys.exit') + mocker.patch("funcx_endpoint.executors.high_throughput.funcx_worker.sys.exit") - mock_context = mocker.patch('funcx_endpoint.executors.high_throughput.funcx_worker.zmq.Context') + mock_context = mocker.patch( + "funcx_endpoint.executors.high_throughput.funcx_worker.zmq.Context" + ) # the worker will receive tasks and send messages on this mock socket mock_socket = mocker.Mock() mock_context.return_value.socket.return_value = mock_socket # send a kill message on the mock socket - task = Task(task_id='KILL', - container_id='RAW', - task_buffer='KILL') - mock_socket.recv_multipart.return_value = (pickle.dumps("KILL"), pickle.dumps("abc"), task.pack()) + task = Task(task_id="KILL", container_id="RAW", task_buffer="KILL") + mock_socket.recv_multipart.return_value = ( + pickle.dumps("KILL"), + pickle.dumps("abc"), + task.pack(), + ) # calling worker.start begins a while loop, where first a REGISTER # message is sent out, then the worker receives the KILL task, which # triggers a WRKR_DIE message to be sent before the while loop exits - worker = FuncXWorker('0', '127.0.0.1', 50001, os.getcwd()) + worker = FuncXWorker("0", "127.0.0.1", 50001, os.getcwd()) worker.start() # these 2 calls to send_multipart happen in a sequence - call1 = mocker.call([b'REGISTER', pickle.dumps(worker.registration_message())]) - call2 = mocker.call([b'WRKR_DIE', pickle.dumps(None)]) + call1 = mocker.call([b"REGISTER", pickle.dumps(worker.registration_message())]) + call2 = mocker.call([b"WRKR_DIE", pickle.dumps(None)]) mock_socket.send_multipart.assert_has_calls([call1, call2]) diff --git a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py index 8712d3f49..53e9c2d2a 100644 --- a/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py +++ b/funcx_endpoint/tests/funcx_endpoint/executors/high_throughput/test_worker_map.py @@ -1,21 +1,26 @@ -from funcx_endpoint.executors.high_throughput.worker_map import WorkerMap import logging import os +from funcx_endpoint.executors.high_throughput.worker_map import WorkerMap + class TestWorkerMap: def test_add_worker(self, mocker): - mock_popen = mocker.patch('funcx_endpoint.executors.high_throughput.worker_map.subprocess.Popen') - mock_popen.return_value = 'proc' + mock_popen = mocker.patch( + "funcx_endpoint.executors.high_throughput.worker_map.subprocess.Popen" + ) + mock_popen.return_value = "proc" worker_map = WorkerMap(1) - worker = worker_map.add_worker(worker_id='0', - address='127.0.0.1', - debug=logging.DEBUG, - uid='test1', - logdir=os.getcwd(), - worker_port=50001) + worker 
= worker_map.add_worker( + worker_id="0", + address="127.0.0.1", + debug=logging.DEBUG, + uid="test1", + logdir=os.getcwd(), + worker_port=50001, + ) - assert list(worker.keys()) == ['0'] - assert worker['0'] == 'proc' + assert list(worker.keys()) == ["0"] + assert worker["0"] == "proc" assert worker_map.worker_id_counter == 1 diff --git a/funcx_endpoint/tests/integration/test_batch_submit.py b/funcx_endpoint/tests/integration/test_batch_submit.py index 6098a4986..c49002991 100644 --- a/funcx_endpoint/tests/integration/test_batch_submit.py +++ b/funcx_endpoint/tests/integration/test_batch_submit.py @@ -1,10 +1,10 @@ -import json -import sys import argparse import time + import funcx from funcx.sdk.client import FuncXClient from funcx.serialize import FuncXSerializer + fxs = FuncXSerializer() # funcx.set_stream_logger() @@ -16,35 +16,38 @@ def double(x): def test(fxc, ep_id, task_count=10): - fn_uuid = fxc.register_function(double, - description="Yadu double") + fn_uuid = fxc.register_function(double, description="Yadu double") print("FN_UUID : ", fn_uuid) start = time.time() - task_ids = fxc.map_run(list(range(task_count)), endpoint_id=ep_id, function_id=fn_uuid) + task_ids = fxc.map_run( + list(range(task_count)), endpoint_id=ep_id, function_id=fn_uuid + ) delta = time.time() - start - print("Time to launch {} tasks: {:8.3f} s".format(task_count, delta)) - print("Got {} tasks_ids ".format(len(task_ids))) + print(f"Time to launch {task_count} tasks: {delta:8.3f} s") + print(f"Got {len(task_ids)} tasks_ids ") for _i in range(3): x = fxc.get_batch_result(task_ids) - complete_count = sum([1 for t in task_ids if t in x and x[t].get('pending', False)]) - print("Batch status : {}/{} complete".format(complete_count, len(task_ids))) + complete_count = sum( + [1 for t in task_ids if t in x and x[t].get("pending", False)] + ) + print(f"Batch status : {complete_count}/{len(task_ids)} complete") if complete_count == len(task_ids): break time.sleep(2) delta = time.time() - start - print("Time to complete {} tasks: {:8.3f} s".format(task_count, delta)) - print("Throughput : {:8.3f} Tasks/s".format(task_count / delta)) + print(f"Time to complete {task_count} tasks: {delta:8.3f} s") + print(f"Throughput : {task_count / delta:8.3f} Tasks/s") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) parser.add_argument("-c", "--count", default="10") args = parser.parse_args() print("FuncX version : ", funcx.__version__) - fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1') + fxc = FuncXClient(funcx_service_address="https://dev.funcx.org/api/v1") test(fxc, args.endpoint, task_count=int(args.count)) diff --git a/funcx_endpoint/tests/integration/test_config.py b/funcx_endpoint/tests/integration/test_config.py index 80f3c6a42..47ed5c71f 100644 --- a/funcx_endpoint/tests/integration/test_config.py +++ b/funcx_endpoint/tests/integration/test_config.py @@ -1,12 +1,13 @@ -from funcx_endpoint.endpoint.utils.config import Config +import logging import os + import funcx -import logging +from funcx_endpoint.endpoint.utils.config import Config config = Config() -if __name__ == '__main__': +if __name__ == "__main__": funcx.set_stream_logger() logger = logging.getLogger(__file__) @@ -18,43 +19,50 @@ print("Loading : ", config) # Set script dir config.provider.script_dir = working_dir - config.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts') + config.provider.channel.script_dir = 
os.path.join(working_dir, "submit_scripts") config.provider.channel.makedirs(config.provider.channel.script_dir, exist_ok=True) os.makedirs(config.provider.script_dir, exist_ok=True) debug_opts = "--debug" if config.worker_debug else "" - max_workers = "" if config.max_workers_per_node == float('inf') \ - else "--max_workers={}".format(config.max_workers_per_node) + max_workers = ( + "" + if config.max_workers_per_node == float("inf") + else f"--max_workers={config.max_workers_per_node}" + ) worker_task_url = "tcp://127.0.0.1:54400" worker_result_url = "tcp://127.0.0.1:54401" - launch_cmd = ("funcx-worker {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--mode={worker_mode} " - "--container_image={container_image} ") - - l_cmd = launch_cmd.format(debug=debug_opts, - max_workers=max_workers, - cores_per_worker=config.cores_per_worker, - prefetch_capacity=config.prefetch_capacity, - task_url=worker_task_url, - result_url=worker_result_url, - nodes_per_block=config.provider.nodes_per_block, - heartbeat_period=config.heartbeat_period, - heartbeat_threshold=config.heartbeat_threshold, - poll_period=config.poll_period, - worker_mode=config.worker_mode, - container_image=None, - logdir=working_dir) + launch_cmd = ( + "funcx-worker {debug} {max_workers} " + "-c {cores_per_worker} " + "--poll {poll_period} " + "--task_url={task_url} " + "--result_url={result_url} " + "--logdir={logdir} " + "--hb_period={heartbeat_period} " + "--hb_threshold={heartbeat_threshold} " + "--mode={worker_mode} " + "--container_image={container_image} " + ) + + l_cmd = launch_cmd.format( + debug=debug_opts, + max_workers=max_workers, + cores_per_worker=config.cores_per_worker, + prefetch_capacity=config.prefetch_capacity, + task_url=worker_task_url, + result_url=worker_result_url, + nodes_per_block=config.provider.nodes_per_block, + heartbeat_period=config.heartbeat_period, + heartbeat_threshold=config.heartbeat_threshold, + poll_period=config.poll_period, + worker_mode=config.worker_mode, + container_image=None, + logdir=working_dir, + ) config.launch_cmd = l_cmd - print("Launch command: {}".format(config.launch_cmd)) + print(f"Launch command: {config.launch_cmd}") if config.scaling_enabled: print("About to scale things") diff --git a/funcx_endpoint/tests/integration/test_containers.py b/funcx_endpoint/tests/integration/test_containers.py index a1bf16da9..0ee5cf4be 100644 --- a/funcx_endpoint/tests/integration/test_containers.py +++ b/funcx_endpoint/tests/integration/test_containers.py @@ -1,7 +1,5 @@ -import json -import sys import argparse -import time + import funcx from funcx.sdk.client import FuncXClient @@ -12,9 +10,11 @@ def container_sum(event): def test(fxc, ep_id): - fn_uuid = fxc.register_function(container_sum, - container_uuid='3861862b-152e-49a4-b15e-9a5da4205cad', - description="New sum function defined without string spec") + fn_uuid = fxc.register_function( + container_sum, + container_uuid="3861862b-152e-49a4-b15e-9a5da4205cad", + description="New sum function defined without string spec", + ) print("FN_UUID : ", fn_uuid) task_id = fxc.run([1, 2, 3, 9001], endpoint_id=ep_id, function_id=fn_uuid) @@ -22,7 +22,7 @@ def test(fxc, ep_id): print("Got from status :", r) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = 
parser.parse_args() diff --git a/funcx_endpoint/tests/integration/test_deserialization.py b/funcx_endpoint/tests/integration/test_deserialization.py index 36edb91f0..3e9c9b382 100644 --- a/funcx_endpoint/tests/integration/test_deserialization.py +++ b/funcx_endpoint/tests/integration/test_deserialization.py @@ -1,6 +1,7 @@ -from funcx.serialize import FuncXSerializer import numpy as np +from funcx.serialize import FuncXSerializer + def double(x, y=3): return x * y diff --git a/funcx_endpoint/tests/integration/test_executor.py b/funcx_endpoint/tests/integration/test_executor.py index 59686d2d0..0fe63d01e 100644 --- a/funcx_endpoint/tests/integration/test_executor.py +++ b/funcx_endpoint/tests/integration/test_executor.py @@ -1,6 +1,4 @@ from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor -import logging -from funcx import set_file_logger def double(x): diff --git a/funcx_endpoint/tests/integration/test_executor_passthrough.py b/funcx_endpoint/tests/integration/test_executor_passthrough.py index 842f5e51f..b7f0a1bc2 100644 --- a/funcx_endpoint/tests/integration/test_executor_passthrough.py +++ b/funcx_endpoint/tests/integration/test_executor_passthrough.py @@ -1,13 +1,12 @@ -from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor -import logging -from funcx import set_file_logger -import uuid -from funcx.serialize import FuncXSerializer -from funcx_endpoint.executors.high_throughput.messages import Message, Task -import time import pickle +import time +import uuid from multiprocessing import Queue +from funcx.serialize import FuncXSerializer +from funcx_endpoint.executors.high_throughput.executor import HighThroughputExecutor +from funcx_endpoint.executors.high_throughput.messages import Task + def double(x): return x * 2 @@ -17,8 +16,7 @@ def double(x): results_queue = Queue() # set_file_logger('executor.log', name='funcx_endpoint', level=logging.DEBUG) - htex = HighThroughputExecutor(interchange_local=True, - passthrough=True) + htex = HighThroughputExecutor(interchange_local=True, passthrough=True) htex.start(results_passthrough=results_queue) htex._start_remote_interchange_process() @@ -31,12 +29,11 @@ def double(x): fn_code = fx_serializer.serialize(double) ser_code = fx_serializer.pack_buffers([fn_code]) - ser_params = fx_serializer.pack_buffers([fx_serializer.serialize(args), - fx_serializer.serialize(kwargs)]) + ser_params = fx_serializer.pack_buffers( + [fx_serializer.serialize(args), fx_serializer.serialize(kwargs)] + ) - payload = Task(task_id, - 'RAW', - ser_code + ser_params) + payload = Task(task_id, "RAW", ser_code + ser_params) f = htex.submit_raw(payload.pack()) time.sleep(0.5) @@ -44,7 +41,7 @@ def double(x): result_package = results_queue.get() # print("Result package : ", result_package) r = pickle.loads(result_package) - result = fx_serializer.deserialize(r['result']) + result = fx_serializer.deserialize(r["result"]) print(f"Result:{i}: {result}") print("All done") diff --git a/funcx_endpoint/tests/integration/test_interchange.py b/funcx_endpoint/tests/integration/test_interchange.py index 6a697026d..967510b6e 100644 --- a/funcx_endpoint/tests/integration/test_interchange.py +++ b/funcx_endpoint/tests/integration/test_interchange.py @@ -1,22 +1,21 @@ import argparse -from funcx_endpoint.endpoint.utils.config import Config -from funcx_endpoint.executors.high_throughput.interchange import Interchange import funcx +from funcx_endpoint.endpoint.utils.config import Config +from 
funcx_endpoint.executors.high_throughput.interchange import Interchange funcx.set_stream_logger() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-a", "--address", required=True, - help="Address") - parser.add_argument("-c", "--client_ports", required=True, - help="ports") + parser.add_argument("-a", "--address", required=True, help="Address") + parser.add_argument("-c", "--client_ports", required=True, help="ports") args = parser.parse_args() config = Config() - ic = Interchange(client_address=args.address, - client_ports=[int(i) for i in args.client_ports.split(',')], - ) + ic = Interchange( + client_address=args.address, + client_ports=[int(i) for i in args.client_ports.split(",")], + ) ic.start() print("Interchange started") diff --git a/funcx_endpoint/tests/integration/test_per_func_batch.py b/funcx_endpoint/tests/integration/test_per_func_batch.py index 3103ba1e6..774696a8b 100644 --- a/funcx_endpoint/tests/integration/test_per_func_batch.py +++ b/funcx_endpoint/tests/integration/test_per_func_batch.py @@ -35,13 +35,13 @@ def test_batch3(a, b, c=2, d=2): task_ids = fx.batch_run(batch) delta = time.time() - start -print("Time to launch {} tasks: {:8.3f} s".format(task_count * len(func_ids), delta)) -print("Got {} tasks_ids ".format(len(task_ids))) +print(f"Time to launch {task_count * len(func_ids)} tasks: {delta:8.3f} s") +print(f"Got {len(task_ids)} tasks_ids ") for _i in range(10): x = fx.get_batch_result(task_ids) complete_count = sum([1 for t in task_ids if t in x and x[t].get("pending", False)]) - print("Batch status : {}/{} complete".format(complete_count, len(task_ids))) + print(f"Batch status : {complete_count}/{len(task_ids)} complete") if complete_count == len(task_ids): print(x) break diff --git a/funcx_endpoint/tests/integration/test_registration.py b/funcx_endpoint/tests/integration/test_registration.py index 1d69c4d72..e12f750e6 100644 --- a/funcx_endpoint/tests/integration/test_registration.py +++ b/funcx_endpoint/tests/integration/test_registration.py @@ -1,9 +1,8 @@ from funcx.sdk.client import FuncXClient - if __name__ == "__main__": fxc = FuncXClient() print(fxc) - fxc.register_endpoint('foobar', None) + fxc.register_endpoint("foobar", None) diff --git a/funcx_endpoint/tests/integration/test_serialization.py b/funcx_endpoint/tests/integration/test_serialization.py index b93dc30c9..bf2101161 100644 --- a/funcx_endpoint/tests/integration/test_serialization.py +++ b/funcx_endpoint/tests/integration/test_serialization.py @@ -8,7 +8,7 @@ def foo(x, y=3): def test_1(): jb = concretes.json_base64() - d = jb.serialize(([2], {'y': 10})) + d = jb.serialize(([2], {"y": 10})) args, kwargs = jb.deserialize(d) result = foo(*args, **kwargs) print(result) @@ -22,7 +22,7 @@ def test_2(): fn = jb.deserialize(f) print(fn) - assert fn(2) == 6, "Expected 6 got {}".format(fn(2)) + assert fn(2) == 6, f"Expected 6 got {fn(2)}" def test_code_1(): @@ -79,6 +79,7 @@ def bar(x, y=5): def test_overall(): from funcx.serialize.facade import FuncXSerializer + fxs = FuncXSerializer() print(fxs._list_methods()) @@ -87,7 +88,7 @@ def test_overall(): print(fxs.deserialize(x)) -if __name__ == '__main__': +if __name__ == "__main__": # test_1() # test_2() diff --git a/funcx_endpoint/tests/integration/test_status.py b/funcx_endpoint/tests/integration/test_status.py index 441665426..d406009d5 100644 --- a/funcx_endpoint/tests/integration/test_status.py +++ b/funcx_endpoint/tests/integration/test_status.py @@ -23,8 +23,9 @@ def 
sum_yadu_new01(event): def test(fxc, ep_id): - fn_uuid = fxc.register_function(sum_yadu_new01, - description="New sum function defined without string spec") + fn_uuid = fxc.register_function( + sum_yadu_new01, description="New sum function defined without string spec" + ) print("FN_UUID : ", fn_uuid) task_id = fxc.run([1, 2, 3, 9001], endpoint_id=ep_id, function_id=fn_uuid) @@ -34,6 +35,7 @@ def test(fxc, ep_id): def platinfo(): import platform + return platform.uname() @@ -42,23 +44,21 @@ def div_by_zero(x): def test2(fxc, ep_id): - fn_uuid = fxc.register_function(platinfo, - description="Get platform info") + fn_uuid = fxc.register_function(platinfo, description="Get platform info") print("FN_UUID : ", fn_uuid) task_id = fxc.run(endpoint_id=ep_id, function_id=fn_uuid) time.sleep(2) r = fxc.get_task_status(task_id) - if 'details' in r: - s_buf = r['details']['result'] + if "details" in r: + s_buf = r["details"]["result"] print("Result : ", fxs.deserialize(s_buf)) else: print("Got from status :", r) def test3(fxc, ep_id): - fn_uuid = fxc.register_function(div_by_zero, - description="Div by zero") + fn_uuid = fxc.register_function(div_by_zero, description="Div by zero") print("FN_UUID : ", fn_uuid) task_id = fxc.run(1099, endpoint_id=ep_id, function_id=fn_uuid) @@ -70,8 +70,7 @@ def test3(fxc, ep_id): def test4(fxc, ep_id): - fn_uuid = fxc.register_function(platinfo, - description="Get platform info") + fn_uuid = fxc.register_function(platinfo, description="Get platform info") print("FN_UUID : ", fn_uuid) task_id = fxc.run(endpoint_id=ep_id, function_id=fn_uuid) @@ -80,7 +79,7 @@ def test4(fxc, ep_id): print("Got result : ", r) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = parser.parse_args() diff --git a/funcx_endpoint/tests/integration/test_submits.py b/funcx_endpoint/tests/integration/test_submits.py index 9b7eb45d3..db1f67f5a 100644 --- a/funcx_endpoint/tests/integration/test_submits.py +++ b/funcx_endpoint/tests/integration/test_submits.py @@ -1,5 +1,3 @@ -import json -import sys import argparse from funcx.sdk.client import FuncXClient @@ -29,16 +27,18 @@ def sum_yadu_new01(event): def test(fxc, ep_id): - fn_uuid = fxc.register_function(sum_yadu_new01, - ep_id, # TODO: We do not need ep id here - description="New sum function defined without string spec") + fn_uuid = fxc.register_function( + sum_yadu_new01, + ep_id, # TODO: We do not need ep id here + description="New sum function defined without string spec", + ) print("FN_UUID : ", fn_uuid) res = fxc.run([1, 2, 3, 99], endpoint_id=ep_id, function_id=fn_uuid) print(res) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = parser.parse_args() diff --git a/funcx_endpoint/tests/integration/test_throttling.py b/funcx_endpoint/tests/integration/test_throttling.py index 4729a637c..96dc5d36a 100644 --- a/funcx_endpoint/tests/integration/test_throttling.py +++ b/funcx_endpoint/tests/integration/test_throttling.py @@ -6,44 +6,47 @@ pytest test_throttling.py """ -import pytest -import globus_sdk from unittest.mock import Mock -from funcx.sdk.utils.throttling import (ThrottledBaseClient, - MaxRequestSizeExceeded, - MaxRequestsExceeded) +import globus_sdk +import pytest + +from funcx.sdk.utils.throttling import ( + MaxRequestsExceeded, + MaxRequestSizeExceeded, + ThrottledBaseClient, +) @pytest.fixture def mock_globus_sdk(monkeypatch): - 
monkeypatch.setattr(globus_sdk.base.BaseClient, '__init__', Mock()) + monkeypatch.setattr(globus_sdk.base.BaseClient, "__init__", Mock()) def test_size_throttling_on_small_requests(mock_globus_sdk): cli = ThrottledBaseClient() # Should not raise - jb = {'not': 'big enough'} - cli.throttle_request_size('POST', '/my_rest_endpoint', json_body=jb) + jb = {"not": "big enough"} + cli.throttle_request_size("POST", "/my_rest_endpoint", json_body=jb) # Should not raise for these methods - cli.throttle_request_size('GET', '/my_rest_endpoint') - cli.throttle_request_size('PUT', '/my_rest_endpoint') - cli.throttle_request_size('DELETE', '/my_rest_endpoint') + cli.throttle_request_size("GET", "/my_rest_endpoint") + cli.throttle_request_size("PUT", "/my_rest_endpoint") + cli.throttle_request_size("DELETE", "/my_rest_endpoint") def test_size_throttle_on_large_request(mock_globus_sdk): cli = ThrottledBaseClient() # Test with ~2mb sized POST - jb = {'is': 'l' + 'o' * 2 * 2 ** 20 + 'ng'} + jb = {"is": "l" + "o" * 2 * 2 ** 20 + "ng"} with pytest.raises(MaxRequestSizeExceeded): - cli.throttle_request_size('POST', '/my_rest_endpoint', json_body=jb) + cli.throttle_request_size("POST", "/my_rest_endpoint", json_body=jb) # Test on text request - data = 'B' + 'i' * 2 * 2 ** 20 + 'gly' + data = "B" + "i" * 2 * 2 ** 20 + "gly" with pytest.raises(MaxRequestSizeExceeded): - cli.throttle_request_size('POST', '/my_rest_endpoint', text_body=data) + cli.throttle_request_size("POST", "/my_rest_endpoint", text_body=data) def test_low_threshold_requests_does_not_raise(mock_globus_sdk): diff --git a/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py b/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py index b92b9224c..b9ed0123d 100644 --- a/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py +++ b/funcx_endpoint/tests/tutorial_ep/test_tutotial_ep.py @@ -1,9 +1,11 @@ -import time -import logging import argparse -import sys import copy -from globus_sdk import ConfidentialAppAuthClient, AccessTokenAuthorizer +import logging +import sys +import time + +from globus_sdk import AccessTokenAuthorizer, ConfidentialAppAuthClient + from funcx.sdk.client import FuncXClient @@ -11,11 +13,20 @@ def identity(x): return x -class TestTutorial(): - - def __init__(self, fx_auth, search_auth, openid_auth, - endpoint_id, func, expected, - args=None, timeout=15, concurrency=1, tol=1e-5): +class TestTutorial: + def __init__( + self, + fx_auth, + search_auth, + openid_auth, + endpoint_id, + func, + expected, + args=None, + timeout=15, + concurrency=1, + tol=1e-5, + ): self.endpoint_id = endpoint_id self.func = func self.expected = expected @@ -23,16 +34,20 @@ def __init__(self, fx_auth, search_auth, openid_auth, self.timeout = timeout self.concurrency = concurrency self.tol = tol - self.fxc = FuncXClient(fx_authorizer=fx_auth, - search_authorizer=search_auth, - openid_authorizer=openid_auth) + self.fxc = FuncXClient( + fx_authorizer=fx_auth, + search_authorizer=search_auth, + openid_authorizer=openid_auth, + ) self.func_uuid = self.fxc.register_function(self.func) self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.DEBUG) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s") + formatter = logging.Formatter( + "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s" + ) handler.setFormatter(formatter) self.logger.addHandler(handler) @@ -40,14 +55,18 @@ def run(self): try: submissions = [] for _ in 
range(self.concurrency): - task = self.fxc.run(self.args, endpoint_id=self.endpoint_id, function_id=self.func_uuid) + task = self.fxc.run( + self.args, endpoint_id=self.endpoint_id, function_id=self.func_uuid + ) submissions.append(task) time.sleep(self.timeout) unfinished = copy.deepcopy(submissions) while True: - unfinished[:] = [task for task in unfinished if self.fxc.get_task(task)['pending']] + unfinished[:] = [ + task for task in unfinished if self.fxc.get_task(task)["pending"] + ] if not unfinished: break time.sleep(self.timeout) @@ -56,45 +75,53 @@ def run(self): for task in submissions: result = self.fxc.get_result(task) if abs(result - self.expected) > self.tol: - self.logger.exception(f'Difference for task {task}. ' - f'Returned: {result}, Expected: {self.expected}') + self.logger.exception( + f"Difference for task {task}. " + f"Returned: {result}, Expected: {self.expected}" + ) else: success += 1 - self.logger.info(f'{success}/{self.concurrency} tasks completed successfully') + self.logger.info( + f"{success}/{self.concurrency} tasks completed successfully" + ) except KeyboardInterrupt: - self.logger.info('Cancelled by keyboard interruption') + self.logger.info("Cancelled by keyboard interruption") except Exception as e: - self.logger.exception(f'Encountered exception: {e}') + self.logger.exception(f"Encountered exception: {e}") raise if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--tutorial", required=True, - help="Tutorial Endpoint ID") - parser.add_argument("-i", "--id", required=True, - help="API_CLIENT_ID for Globus") - parser.add_argument("-s", "--secret", required=True, - help="API_CLIENT_SECRET for Globus") + parser.add_argument("-t", "--tutorial", required=True, help="Tutorial Endpoint ID") + parser.add_argument("-i", "--id", required=True, help="API_CLIENT_ID for Globus") + parser.add_argument( + "-s", "--secret", required=True, help="API_CLIENT_SECRET for Globus" + ) args = parser.parse_args() client = ConfidentialAppAuthClient(args.id, args.secret) - scopes = ["https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", - "urn:globus:auth:scope:search.api.globus.org:all", - "openid"] + scopes = [ + "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", + "urn:globus:auth:scope:search.api.globus.org:all", + "openid", + ] token_response = client.oauth2_client_credentials_tokens(requested_scopes=scopes) - fx_token = token_response.by_resource_server['funcx_service']['access_token'] - search_token = token_response.by_resource_server['search.api.globus.org']['access_token'] - openid_token = token_response.by_resource_server['auth.globus.org']['access_token'] + fx_token = token_response.by_resource_server["funcx_service"]["access_token"] + search_token = token_response.by_resource_server["search.api.globus.org"][ + "access_token" + ] + openid_token = token_response.by_resource_server["auth.globus.org"]["access_token"] fx_auth = AccessTokenAuthorizer(fx_token) search_auth = AccessTokenAuthorizer(search_token) openid_auth = AccessTokenAuthorizer(openid_token) val = 1 - tt = TestTutorial(fx_auth, search_auth, openid_auth, - args.tutorial, identity, val, args=val) + tt = TestTutorial( + fx_auth, search_auth, openid_auth, args.tutorial, identity, val, args=val + ) tt.run() diff --git a/funcx_sdk/funcx/__init__.py b/funcx_sdk/funcx/__init__.py index d17d8b1f6..7358085d4 100644 --- a/funcx_sdk/funcx/__init__.py +++ b/funcx_sdk/funcx/__init__.py @@ -1,8 +1,6 @@ """ funcX : Fast function serving for 
clouds, clusters and supercomputers. """ -import logging - from funcx.sdk.version import VERSION __author__ = "The funcX team" @@ -10,3 +8,5 @@ from funcx.sdk.client import FuncXClient from funcx.utils.loggers import set_file_logger, set_stream_logger + +__all__ = ("FuncXClient", "set_file_logger", "set_stream_logger") diff --git a/funcx_sdk/funcx/sdk/__init__.py b/funcx_sdk/funcx/sdk/__init__.py index 7775be9a4..8435c7f08 100644 --- a/funcx_sdk/funcx/sdk/__init__.py +++ b/funcx_sdk/funcx/sdk/__init__.py @@ -1 +1,3 @@ from funcx.sdk.version import VERSION + +__all__ = ("VERSION",) diff --git a/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py b/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py index a60a34348..5237886ca 100644 --- a/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py +++ b/funcx_sdk/funcx/sdk/asynchronous/ws_polling_task.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from asyncio import AbstractEventLoop, QueueEmpty +from asyncio import AbstractEventLoop import dill import websockets @@ -72,8 +72,9 @@ def __init__( # the WebSocket server immediately self.running_task_group_ids.add(self.init_task_group_id) - # Set event loop explicitly since event loop can only be fetched automatically in main thread - # when batch submission is enabled, the task submission is in a new thread + # Set event loop explicitly since event loop can only be fetched automatically + # in main thread when batch submission is enabled, the task submission is in a + # new thread asyncio.set_event_loop(self.loop) self.task_group_ids_queue = asyncio.Queue() self.pending_tasks = {} @@ -91,12 +92,13 @@ async def init_ws(self, start_message_handlers=True): self.ws = await websockets.connect( self.results_ws_uri, extra_headers=headers ) - # initial Globus authentication happens during the HTTP portion of the handshake, - # so an invalid handshake means that the user was not authenticated + # initial Globus authentication happens during the HTTP portion of the + # handshake, so an invalid handshake means that the user was not authenticated except InvalidStatusCode as e: if e.status_code == 404: raise Exception( - "WebSocket service responsed with a 404. Please ensure you set the correct results_ws_uri" + "WebSocket service responsed with a 404. " + "Please ensure you set the correct results_ws_uri" ) else: raise e @@ -136,12 +138,14 @@ async def handle_incoming(self, pending_futures, auto_close=False): if await self.set_result(task_id, data, pending_futures): return else: - # This scenario occurs rarely using non-batching mode, - # but quite often in batching mode. - # When submitting tasks in batch with batch_run, - # some task results may be received by websocket before the response of batch_run, + # This scenario occurs rarely using non-batching mode, but quite + # often in batching mode. + # + # When submitting tasks in batch with batch_run, some task results + # may be received by websocket before the response of batch_run, # and pending_futures do not have the futures for the tasks yet. - # We store these in unknown_results and process when their futures are ready. + # We store these in unknown_results and process when their futures + # are ready. self.unknown_results[task_id] = data # Handle the results received but not processed before @@ -188,8 +192,8 @@ async def set_result(self, task_id, data, pending_futures): except Exception: logger.exception("Caught unexpected exception while setting results") - # When the counter hits 0 we always exit. 
This guarantees that that - # if the counter increments to 1 on the executor, this handler needs to be restarted. + # When the counter hits 0 we always exit. This guarantees that that if the + # counter increments to 1 on the executor, this handler needs to be restarted. if self.atomic_controller is not None: count = self.atomic_controller.decrement() # Only close when count == 0 and unknown_results are empty @@ -216,14 +220,19 @@ def add_task(self, task: FuncXTask): def get_auth_header(self): """ - Gets an Authorization header to be sent during the WebSocket handshake. Based on - header setting in the Globus SDK: https://github.com/globus/globus-sdk-python/blob/main/globus_sdk/base.py + Gets an Authorization header to be sent during the WebSocket handshake. Returns ------- Key-value tuple of the Authorization header (key, value) """ + # TODO: under SDK v3 this will be + # + # return ( + # "Authorization", + # self.funcx_client.authorizer.get_authorization_header()` + # ) headers = dict() self.funcx_client.authorizer.set_authorization_header(headers) header_name = "Authorization" diff --git a/funcx_sdk/funcx/sdk/client.py b/funcx_sdk/funcx/sdk/client.py index 94d7108f9..238ad75c3 100644 --- a/funcx_sdk/funcx/sdk/client.py +++ b/funcx_sdk/funcx/sdk/client.py @@ -28,6 +28,8 @@ logger = logging.getLogger(__name__) +_FUNCX_HOME = os.path.join("~", ".funcx") + class FuncXClient(FuncXErrorHandlingClient): """Main class for interacting with the funcX service @@ -48,7 +50,7 @@ class FuncXClient(FuncXErrorHandlingClient): def __init__( self, http_timeout=None, - funcx_home=os.path.join("~", ".funcx"), + funcx_home=_FUNCX_HOME, force_login=False, fx_authorizer=None, search_authorizer=None, diff --git a/funcx_sdk/funcx/sdk/error_handling_client.py b/funcx_sdk/funcx/sdk/error_handling_client.py index b1336ad6e..6faaad319 100644 --- a/funcx_sdk/funcx/sdk/error_handling_client.py +++ b/funcx_sdk/funcx/sdk/error_handling_client.py @@ -6,7 +6,9 @@ class FuncXErrorHandlingClient(ThrottledBaseClient): - """Class which handles errors from GET, POST, and DELETE requests before proceeding""" + """ + Class which handles errors from GET, POST, and DELETE requests before proceeding + """ def get(self, path, **kwargs): try: diff --git a/funcx_sdk/funcx/sdk/executor.py b/funcx_sdk/funcx/sdk/executor.py index 61b599688..2f1d39acd 100644 --- a/funcx_sdk/funcx/sdk/executor.py +++ b/funcx_sdk/funcx/sdk/executor.py @@ -231,7 +231,10 @@ def _submit_tasks(self, messages): self.poller_thread.atomic_controller.increment() def _get_tasks_in_batch(self): - """Get tasks from task_outgoing queue in batch, either by interval or by batch size""" + """ + Get tasks from task_outgoing queue in batch, + either by interval or by batch size + """ messages = [] start = time.time() while True: @@ -312,8 +315,8 @@ def event_loop_thread(self, eventloop): eventloop.run_until_complete(self.web_socket_poller()) async def web_socket_poller(self): - # TODO: if WebSocket connection fails, we should either retry connecting and back off - # or we should set an exception to all of the outstanding futures + # TODO: if WebSocket connection fails, we should either retry connecting and + # backoff or we should set an exception to all of the outstanding futures await self.ws_handler.init_ws(start_message_handlers=False) await self.ws_handler.handle_incoming( self._function_future_map, auto_close=True diff --git a/funcx_sdk/funcx/sdk/search.py b/funcx_sdk/funcx/sdk/search.py index ed644fa56..016c190e3 100644 --- a/funcx_sdk/funcx/sdk/search.py +++ 
b/funcx_sdk/funcx/sdk/search.py @@ -81,7 +81,8 @@ def search_function(self, q, offset=0, limit=DEFAULT_SEARCH_LIMIT, advanced=Fals # print(res) # Restructure results to look like the data dict in FuncXClient - # see the JSON structure of res.data: https://docs.globus.org/api/search/search/#gsearchresult + # see the JSON structure of res.data: + # https://docs.globus.org/api/search/search/#gsearchresult gmeta = response.data["gmeta"] results = [] for item in gmeta: @@ -130,7 +131,7 @@ def search_endpoint(self, q, scope="all", owner_id=None): } elif scope == "shared-with-me": # TODO: filter for public=False AND owner != self._owner_uuid - # but...need to build advanced query for that, because GFilters cannot do NOT + # need to build advanced query for that, because GFilters cannot do NOT # raise Exception('This scope has not been implemented') scope_filter = { "type": "match_all", diff --git a/funcx_sdk/funcx/sdk/utils/futures.py b/funcx_sdk/funcx/sdk/utils/futures.py index 117bd6ca8..8e570dba2 100644 --- a/funcx_sdk/funcx/sdk/utils/futures.py +++ b/funcx_sdk/funcx/sdk/utils/futures.py @@ -2,7 +2,6 @@ Credit: Logan Ward """ -import json from concurrent.futures import Future from threading import Thread from time import sleep diff --git a/funcx_sdk/funcx/utils/loggers.py b/funcx_sdk/funcx/utils/loggers.py index 6eb7b0f51..59a080c5e 100644 --- a/funcx_sdk/funcx/utils/loggers.py +++ b/funcx_sdk/funcx/utils/loggers.py @@ -25,7 +25,8 @@ def set_file_logger( - level (logging.LEVEL): Set the logging level. - format_string (string): Set the format string - maxBytes: The maximum bytes per logger file, default: 100MB - - backupCount: The number of backup (must be non-zero) per logger file, default: 1 + - backupCount: The number of backup (must be non-zero) per logger file, + default: 1 Returns: - None diff --git a/funcx_sdk/funcx/utils/response_errors.py b/funcx_sdk/funcx/utils/response_errors.py index 085e14c9d..6b9a9c33c 100644 --- a/funcx_sdk/funcx/utils/response_errors.py +++ b/funcx_sdk/funcx/utils/response_errors.py @@ -2,8 +2,10 @@ from enum import Enum -# IMPORTANT: new error codes can be added, but existing error codes must not be changed once published. -# changing existing error codes will cause problems with users that have older SDK versions +# IMPORTANT: new error codes can be added, but existing error codes must not be changed +# once published. +# changing existing error codes will cause problems with users that have older SDK +# versions class ResponseErrorCode(int, Enum): UNKNOWN_ERROR = 0 USER_UNAUTHENTICATED = 1 @@ -86,8 +88,9 @@ def unpack(cls, res_data): ): try: # if the response error code is not recognized here because the - # user is not using the latest SDK version, an exception will occur here - # which we will pass in order to give the user a generic exception below + # user is not using the latest SDK version, an exception will occur + # here, which we will pass in order to give the user a generic + # exception below res_error_code = ResponseErrorCode(res_data["code"]) error_class = None if res_error_code is ResponseErrorCode.USER_UNAUTHENTICATED: @@ -179,9 +182,10 @@ def __init__(self): class UserNotFound(FuncxResponseError): - """User not found exception. This error should only be used when the server must - look up a user in order to fulfill the user's request body. If the request only - fails because the user is unauthenticated, UserUnauthenticated should be used instead. + """ + User not found exception. 
This error should only be used when the server must look + up a user in order to fulfill the user's request body. If the request only fails + because the user is unauthenticated, UserUnauthenticated should be used instead. """ code = ResponseErrorCode.USER_NOT_FOUND @@ -399,7 +403,10 @@ class EndpointOutdated(FuncxResponseError): def __init__(self, min_ep_version): self.error_args = [min_ep_version] - self.reason = f"Endpoint is out of date. Minimum supported endpoint version is {min_ep_version}" + self.reason = ( + "Endpoint is out of date. " + f"Minimum supported endpoint version is {min_ep_version}" + ) class TaskGroupNotFound(FuncxResponseError):
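
Editor's note on the response_errors.py hunks above: the module's design hinges on a stable mapping from published integer codes to exception classes, which is why the comment insists that existing codes may never be renumbered once published. Below is a minimal, self-contained sketch of that pattern, offered only as an illustration under simplified assumptions: the registry dict, the generic fallback, and the example code values are hypothetical and do not reproduce the library's actual implementation (the real module resolves codes through an if/elif chain inside a classmethod).

    # Illustrative sketch only (not the library's actual implementation): it mirrors
    # the pattern in funcx/utils/response_errors.py, where published integer codes
    # map to exception classes and unknown codes degrade to a generic error.
    from enum import Enum


    class ResponseErrorCode(int, Enum):
        # Published values are frozen; new members may only be appended.
        UNKNOWN_ERROR = 0
        USER_UNAUTHENTICATED = 1
        USER_NOT_FOUND = 2


    class FuncxResponseError(Exception):
        code = ResponseErrorCode.UNKNOWN_ERROR

        def __init__(self, reason):
            self.reason = reason
            super().__init__(reason)


    class UserUnauthenticated(FuncxResponseError):
        code = ResponseErrorCode.USER_UNAUTHENTICATED


    class UserNotFound(FuncxResponseError):
        code = ResponseErrorCode.USER_NOT_FOUND


    # Hypothetical registry for this sketch; the real module uses explicit branches.
    _ERROR_CLASSES = {cls.code: cls for cls in (UserUnauthenticated, UserNotFound)}


    def unpack(res_data):
        """Turn a service response dict into the matching exception instance."""
        reason = res_data.get("reason", "unknown error")
        try:
            code = ResponseErrorCode(res_data["code"])
        except (KeyError, ValueError):
            # Older SDKs cannot resolve codes published after their release, so
            # they fall back to a generic error rather than crashing.  Renumbering
            # an existing code would silently change which exception those SDKs
            # raise, which is why published values must never change.
            return FuncxResponseError(reason)
        return _ERROR_CLASSES.get(code, FuncxResponseError)(reason)


    if __name__ == "__main__":
        err = unpack({"code": 2, "reason": "no such user"})
        print(type(err).__name__, "-", err.reason)    # UserNotFound - no such user
        err = unpack({"code": 999, "reason": "future code"})
        print(type(err).__name__, "-", err.reason)    # FuncxResponseError - future code

Because an older SDK resolves any unrecognized code to the generic fallback rather than failing outright, appending new codes is backward-compatible while renumbering published ones is not.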