diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index 6833e2da5d..fe5e154c15 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -173,22 +173,15 @@ def __init__(self, self.block_id = block_id self.enable_mpi_mode = enable_mpi_mode - self.nodes_q = None - self.inflight_q = None - logger.warning(f"YADU: enable_mpi_mode: {enable_mpi_mode}") - self.mpi_node_tracker = None + self.mpi_node_tracker: Optional[MpiTaskToNodesTracker] = None + if self.enable_mpi_mode: scheduler = identify_scheduler() - logger.warning(f"YADU: Scheduler: {scheduler}") - logger.warning(f"YADU: os.env['PBS_NODEFILE'] : {os.environ['PBS_NODEFILE']}") available_nodes = get_nodes_in_batchjob(scheduler=scheduler) - logging.warning(f"YADU: Got nodes : {available_nodes}") - self.nodes_q = multiprocessing.Queue() - self.inflight_q = multiprocessing.Queue() self.mpi_node_tracker = MpiTaskToNodesTracker( available_nodes=available_nodes, - nodes_q=self.nodes_q, - inflight_q=self.inflight_q) + nodes_q=multiprocessing.Queue(), + inflight_q=multiprocessing.Queue()) if os.environ.get('PARSL_CORES'): cores_on_node = int(os.environ['PARSL_CORES']) @@ -507,7 +500,7 @@ def start(self): return -def update_resource_spec_env_vars(resource_spec: Dict, node_info: Optional[List[str]] = None) -> None: +def update_resource_spec_env_vars(resource_spec: Dict, node_info: List[str]) -> None: scheduler = identify_scheduler() prefix_table = compose_all(scheduler, resource_spec=resource_spec, node_hostnames=node_info) @@ -533,10 +526,8 @@ def execute_task(bufs, mpi_node_tracker: Optional[MpiTaskToNodesTracker] = None) if resource_spec.get('NUM_NODES') and mpi_node_tracker: logger.warning(f"Provisioning {resource_spec['NUM_NODES']} nodes") worker_id = os.environ['PARSL_WORKER_RANK'] - logger.warning("YADU: Trying to get nodes from q") nodes_for_task = mpi_node_tracker.get_nodes(resource_spec['NUM_NODES'], owner_tag=worker_id) - logger.warning("YADU: Got nodes from q") update_resource_spec_env_vars(resource_spec=resource_spec, node_info=nodes_for_task) # We might need to look into callability of the function from itself @@ -558,7 +549,7 @@ def execute_task(bufs, mpi_node_tracker: Optional[MpiTaskToNodesTracker] = None) exec(code, user_ns, user_ns) finally: # Return the held nodes if any before raising exceptions are processed - if nodes_for_task: + if nodes_for_task and mpi_node_tracker: mpi_node_tracker.return_nodes(num_nodes=len(nodes_for_task), owner_tag=worker_id) result = user_ns.get(resultname) diff --git a/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py b/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py index 4daeda4ebc..44c56ed9e3 100644 --- a/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py +++ b/parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py @@ -3,7 +3,7 @@ import parsl from parsl import python_app from parsl.tests.configs.htex_local import fresh_config -from parsl.executors.high_throughput.process_worker_pool import MpiTaskToNodesTracker + EXECUTOR_LABEL = "MPI_TEST" @@ -15,9 +15,15 @@ def local_setup(): parsl.load(config) +def local_teardown(): + parsl.dfk().cleanup() + parsl.clear() + + @python_app def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict: import os + parsl_vars = {} for key in os.environ: if key.startswith("PARSL_"): @@ -29,9 +35,10 @@ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict: def test_only_resource_specs_set(): """Confirm that resource_spec env vars are set while launch prefixes are not when enable_mpi_mode = False""" - resource_spec = {"NUM_NODES": 4, - "RANKS_PER_NODE": 2, - } + resource_spec = { + "NUM_NODES": 4, + "RANKS_PER_NODE": 2, + } future = get_env_vars(parsl_resource_specification=resource_spec) diff --git a/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py b/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py index 1cf70ff915..a25b253d50 100644 --- a/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py +++ b/parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py @@ -4,7 +4,7 @@ import parsl from parsl import python_app, bash_app from parsl.tests.configs.htex_local import fresh_config -from parsl.executors.high_throughput.process_worker_pool import MpiTaskToNodesTracker + import os EXECUTOR_LABEL = "MPI_TEST" @@ -19,6 +19,11 @@ def local_setup(): parsl.load(config) +def local_teardown(): + parsl.dfk().cleanup() + parsl.clear() + + @python_app def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict: import os @@ -103,7 +108,9 @@ def test_bash_multiple_set(): @bash_app def bash_resource_spec(resource_specification: Dict, stdout=parsl.AUTO_LOGNAME): - total_ranks = resource_specification['RANKS_PER_NODE'] * resource_specification['NUM_NODES'] + total_ranks = ( + resource_specification["RANKS_PER_NODE"] * resource_specification["NUM_NODES"] + ) return f'echo "{total_ranks}"' @@ -117,4 +124,5 @@ def test_bash_app_using_resource_spec(): assert future.result() == 0 with open(future.stdout) as f: output = f.readlines() - assert int(output[0].strip()) == resource_spec["NUM_NODES"] * resource_spec["RANKS_PER_NODE"] + total_ranks = resource_spec["NUM_NODES"] * resource_spec["RANKS_PER_NODE"] + assert int(output[0].strip()) == total_ranks diff --git a/parsl/tests/test_mpi_apps/test_resource_spec.py b/parsl/tests/test_mpi_apps/test_resource_spec.py index 8db424a2d7..a5abf3eb1f 100644 --- a/parsl/tests/test_mpi_apps/test_resource_spec.py +++ b/parsl/tests/test_mpi_apps/test_resource_spec.py @@ -1,6 +1,8 @@ import contextlib import logging import multiprocessing +import os +import typing import pytest import unittest @@ -13,10 +15,9 @@ get_pbs_hosts_list, get_slurm_hosts_list, get_nodes_in_batchjob, - get_cobalt_hosts_list, identify_scheduler, + MpiTaskToNodesTracker, ) -from parsl.executors.high_throughput.process_worker_pool import MpiTaskToNodesTracker EXECUTOR_LABEL = "MPI_TEST" @@ -51,7 +52,6 @@ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict: @pytest.mark.local def test_resource_spec_env_vars(): - resource_spec = { "NUM_NODES": 4, "RANKS_PER_NODE": 2, @@ -70,20 +70,16 @@ def test_resource_spec_env_vars(): @pytest.mark.local @unittest.mock.patch("subprocess.check_output", return_value=b"c203-031\nc203-032\n") def test_slurm_mocked_mpi_fetch(subprocess_check): - nodeinfo = get_slurm_hosts_list() assert isinstance(nodeinfo, list) assert len(nodeinfo) == 2 -import os - - @contextlib.contextmanager -def add_to_path(path: os.PathLike) -> None: +def add_to_path(path: os.PathLike) -> typing.Generator[None, None, None]: old_path = os.environ["PATH"] try: - os.environ["PATH"] += path + os.environ["PATH"] += str(path) yield finally: os.environ["PATH"] = old_path @@ -100,7 +96,7 @@ def test_slurm_mpi_fetch(): @contextlib.contextmanager -def mock_pbs_nodefile() -> None: +def mock_pbs_nodefile() -> typing.Generator[None, None, None]: cwd = os.path.abspath(os.path.dirname(__file__)) filename = os.path.join(cwd, "mocks", "pbs_nodefile") try: