diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index aa899b1d..3e11fee4 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -28,6 +28,7 @@ import warnings from dataclasses import asdict, dataclass from datetime import datetime +from subprocess import Popen from types import FrameType from typing import ( Any, @@ -696,12 +697,11 @@ def _popen( log.debug(f"Running {role_name} (replica {replica_id}):\n {args_pfmt}") env = self._get_replica_env(replica_params) - proc = subprocess.Popen( + proc = self.run_local_job( args=replica_params.args, env=env, stdout=stdout_, stderr=stderr_, - start_new_session=True, cwd=replica_params.cwd, ) return _LocalReplica( @@ -714,6 +714,23 @@ def _popen( error_file=env.get("TORCHELASTIC_ERROR_FILE", ""), ) + def run_local_job( + self, + args: List[str], + env: Dict[str, str], + stdout: Optional[io.FileIO], + stderr: Optional[io.FileIO], + cwd: Optional[str] = None, + ) -> Popen[bytes]: + return subprocess.Popen( + args=args, + env=env, + stdout=stdout, + stderr=stderr, + start_new_session=True, + cwd=cwd, + ) + def _get_replica_output_handles( self, replica_params: ReplicaParam,