Skip to content

Commit

Permalink
Only code cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
yadudoc committed Oct 12, 2023
1 parent 4c4535b commit c27be97
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 32 deletions.
21 changes: 6 additions & 15 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,15 @@ def __init__(self,
self.block_id = block_id

self.enable_mpi_mode = enable_mpi_mode
self.nodes_q = None
self.inflight_q = None
logger.warning(f"YADU: enable_mpi_mode: {enable_mpi_mode}")
self.mpi_node_tracker = None
self.mpi_node_tracker: Optional[MpiTaskToNodesTracker] = None

if self.enable_mpi_mode:
scheduler = identify_scheduler()
logger.warning(f"YADU: Scheduler: {scheduler}")
logger.warning(f"YADU: os.env['PBS_NODEFILE'] : {os.environ['PBS_NODEFILE']}")
available_nodes = get_nodes_in_batchjob(scheduler=scheduler)
logging.warning(f"YADU: Got nodes : {available_nodes}")
self.nodes_q = multiprocessing.Queue()
self.inflight_q = multiprocessing.Queue()
self.mpi_node_tracker = MpiTaskToNodesTracker(
available_nodes=available_nodes,
nodes_q=self.nodes_q,
inflight_q=self.inflight_q)
nodes_q=multiprocessing.Queue(),
inflight_q=multiprocessing.Queue())

if os.environ.get('PARSL_CORES'):
cores_on_node = int(os.environ['PARSL_CORES'])
Expand Down Expand Up @@ -507,7 +500,7 @@ def start(self):
return


def update_resource_spec_env_vars(resource_spec: Dict, node_info: Optional[List[str]] = None) -> None:
def update_resource_spec_env_vars(resource_spec: Dict, node_info: List[str]) -> None:

scheduler = identify_scheduler()
prefix_table = compose_all(scheduler, resource_spec=resource_spec, node_hostnames=node_info)
Expand All @@ -533,10 +526,8 @@ def execute_task(bufs, mpi_node_tracker: Optional[MpiTaskToNodesTracker] = None)
if resource_spec.get('NUM_NODES') and mpi_node_tracker:
logger.warning(f"Provisioning {resource_spec['NUM_NODES']} nodes")
worker_id = os.environ['PARSL_WORKER_RANK']
logger.warning("YADU: Trying to get nodes from q")
nodes_for_task = mpi_node_tracker.get_nodes(resource_spec['NUM_NODES'],
owner_tag=worker_id)
logger.warning("YADU: Got nodes from q")
update_resource_spec_env_vars(resource_spec=resource_spec, node_info=nodes_for_task)

# We might need to look into callability of the function from itself
Expand All @@ -558,7 +549,7 @@ def execute_task(bufs, mpi_node_tracker: Optional[MpiTaskToNodesTracker] = None)
exec(code, user_ns, user_ns)
finally:
# Return the held nodes if any before raising exceptions are processed
if nodes_for_task:
if nodes_for_task and mpi_node_tracker:
mpi_node_tracker.return_nodes(num_nodes=len(nodes_for_task), owner_tag=worker_id)

result = user_ns.get(resultname)
Expand Down
15 changes: 11 additions & 4 deletions parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import parsl
from parsl import python_app
from parsl.tests.configs.htex_local import fresh_config
from parsl.executors.high_throughput.process_worker_pool import MpiTaskToNodesTracker

EXECUTOR_LABEL = "MPI_TEST"


Expand All @@ -15,9 +15,15 @@ def local_setup():
parsl.load(config)


def local_teardown():
parsl.dfk().cleanup()
parsl.clear()


@python_app
def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict:
import os

parsl_vars = {}
for key in os.environ:
if key.startswith("PARSL_"):
Expand All @@ -29,9 +35,10 @@ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict:
def test_only_resource_specs_set():
"""Confirm that resource_spec env vars are set while launch prefixes are not
when enable_mpi_mode = False"""
resource_spec = {"NUM_NODES": 4,
"RANKS_PER_NODE": 2,
}
resource_spec = {
"NUM_NODES": 4,
"RANKS_PER_NODE": 2,
}

future = get_env_vars(parsl_resource_specification=resource_spec)

Expand Down
14 changes: 11 additions & 3 deletions parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import parsl
from parsl import python_app, bash_app
from parsl.tests.configs.htex_local import fresh_config
from parsl.executors.high_throughput.process_worker_pool import MpiTaskToNodesTracker

import os

EXECUTOR_LABEL = "MPI_TEST"
Expand All @@ -19,6 +19,11 @@ def local_setup():
parsl.load(config)


def local_teardown():
parsl.dfk().cleanup()
parsl.clear()


@python_app
def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict:
import os
Expand Down Expand Up @@ -103,7 +108,9 @@ def test_bash_multiple_set():

@bash_app
def bash_resource_spec(resource_specification: Dict, stdout=parsl.AUTO_LOGNAME):
total_ranks = resource_specification['RANKS_PER_NODE'] * resource_specification['NUM_NODES']
total_ranks = (
resource_specification["RANKS_PER_NODE"] * resource_specification["NUM_NODES"]
)
return f'echo "{total_ranks}"'


Expand All @@ -117,4 +124,5 @@ def test_bash_app_using_resource_spec():
assert future.result() == 0
with open(future.stdout) as f:
output = f.readlines()
assert int(output[0].strip()) == resource_spec["NUM_NODES"] * resource_spec["RANKS_PER_NODE"]
total_ranks = resource_spec["NUM_NODES"] * resource_spec["RANKS_PER_NODE"]
assert int(output[0].strip()) == total_ranks
16 changes: 6 additions & 10 deletions parsl/tests/test_mpi_apps/test_resource_spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import logging
import multiprocessing
import os
import typing

import pytest
import unittest
Expand All @@ -13,10 +15,9 @@
get_pbs_hosts_list,
get_slurm_hosts_list,
get_nodes_in_batchjob,
get_cobalt_hosts_list,
identify_scheduler,
MpiTaskToNodesTracker,
)
from parsl.executors.high_throughput.process_worker_pool import MpiTaskToNodesTracker

EXECUTOR_LABEL = "MPI_TEST"

Expand Down Expand Up @@ -51,7 +52,6 @@ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict:

@pytest.mark.local
def test_resource_spec_env_vars():

resource_spec = {
"NUM_NODES": 4,
"RANKS_PER_NODE": 2,
Expand All @@ -70,20 +70,16 @@ def test_resource_spec_env_vars():
@pytest.mark.local
@unittest.mock.patch("subprocess.check_output", return_value=b"c203-031\nc203-032\n")
def test_slurm_mocked_mpi_fetch(subprocess_check):

nodeinfo = get_slurm_hosts_list()
assert isinstance(nodeinfo, list)
assert len(nodeinfo) == 2


import os


@contextlib.contextmanager
def add_to_path(path: os.PathLike) -> None:
def add_to_path(path: os.PathLike) -> typing.Generator[None, None, None]:
old_path = os.environ["PATH"]
try:
os.environ["PATH"] += path
os.environ["PATH"] += str(path)
yield
finally:
os.environ["PATH"] = old_path
Expand All @@ -100,7 +96,7 @@ def test_slurm_mpi_fetch():


@contextlib.contextmanager
def mock_pbs_nodefile() -> None:
def mock_pbs_nodefile() -> typing.Generator[None, None, None]:
cwd = os.path.abspath(os.path.dirname(__file__))
filename = os.path.join(cwd, "mocks", "pbs_nodefile")
try:
Expand Down

0 comments on commit c27be97

Please sign in to comment.