Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract env set up/output handling in local_scheduler for easier subclassing #817

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 54 additions & 21 deletions torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,17 @@
from dataclasses import asdict, dataclass
from datetime import datetime
from types import FrameType
from typing import Any, BinaryIO, Callable, Dict, Iterable, List, Optional, TextIO
from typing import (
Any,
BinaryIO,
Callable,
Dict,
Iterable,
List,
Optional,
TextIO,
Tuple,
)

from torchx.schedulers.api import (
AppDryRunInfo,
Expand Down Expand Up @@ -658,29 +668,11 @@ def _popen(
as file name ``str`` rather than a file-like obj.
"""

stdout_ = self._get_file_io(replica_params.stdout)
stderr_ = self._get_file_io(replica_params.stderr)
combined_: Optional[Tee] = None
combined_file = self._get_file_io(replica_params.combined)
if combined_file:
combined_ = Tee(
combined_file,
none_throws(replica_params.stdout),
none_throws(replica_params.stderr),
)

# inherit parent's env vars since 99.9% of the time we want this behavior
# just make sure we override the parent's env vars with the user_defined ones
env = os.environ.copy()
env.update(replica_params.env)
# PATH is a special one, instead of overriding, append
env["PATH"] = _join_PATH(replica_params.env.get("PATH"), os.getenv("PATH"))

# default to unbuffered python for faster responsiveness locally
env.setdefault("PYTHONUNBUFFERED", "x")
stdout_, stderr_, combined_ = self._get_replica_output_handles(replica_params)

args_pfmt = pprint.pformat(asdict(replica_params), indent=2, width=80)
log.debug(f"Running {role_name} (replica {replica_id}):\n {args_pfmt}")
env = self._get_replica_env(replica_params)

proc = subprocess.Popen(
args=replica_params.args,
Expand All @@ -700,6 +692,47 @@ def _popen(
error_file=env.get("TORCHELASTIC_ERROR_FILE", "<N/A>"),
)

def _get_replica_output_handles(
self,
replica_params: ReplicaParam,
) -> Tuple[Optional[io.FileIO], Optional[io.FileIO], Optional[Tee]]:
"""
Returns the stdout, stderr, and combined outputs of the replica.
If the combined output file is not specified, then the combined output is ``None``.
"""

stdout_ = self._get_file_io(replica_params.stdout)
stderr_ = self._get_file_io(replica_params.stderr)
combined_: Optional[Tee] = None
combined_file = self._get_file_io(replica_params.combined)
if combined_file:
combined_ = Tee(
combined_file,
none_throws(replica_params.stdout),
none_throws(replica_params.stderr),
)
return stdout_, stderr_, combined_

def _get_replica_env(
self,
replica_params: ReplicaParam,
) -> Dict[str, str]:
"""
Returns environment variables for the ``_LocalReplica``
"""

# inherit parent's env vars since 99.9% of the time we want this behavior
# just make sure we override the parent's env vars with the user_defined ones
env = os.environ.copy()
env.update(replica_params.env)
# PATH is a special one, instead of overriding, append
env["PATH"] = _join_PATH(replica_params.env.get("PATH"), os.getenv("PATH"))

# default to unbuffered python for faster responsiveness locally
env.setdefault("PYTHONUNBUFFERED", "x")

return env

def _get_app_log_dir(self, app_id: str, cfg: LocalOpts) -> str:
"""
Returns the log dir. We redirect stdout/err
Expand Down
Loading