Skip to content

Commit

Permalink
Add @override for files in `src/lightning/fabric/plugins/environmen…
Browse files Browse the repository at this point in the history
…ts` (#19098)
  • Loading branch information
VictorPrins authored Dec 1, 2023
1 parent c5363af commit 520c1e4
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/lightning/fabric/plugins/environments/kubeflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import logging
import os

from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)
Expand All @@ -32,35 +34,45 @@ class KubeflowEnvironment(ClusterEnvironment):
"""

@property
@override
def creates_processes_externally(self) -> bool:
    """Always ``True`` for this environment."""
    return True

@property
@override
def main_address(self) -> str:
    """Return the value of the ``MASTER_ADDR`` environment variable.

    Raises:
        KeyError: If ``MASTER_ADDR`` is not set.
    """
    address = os.environ["MASTER_ADDR"]
    return address

@property
@override
def main_port(self) -> int:
    """Return the ``MASTER_PORT`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``MASTER_PORT`` is not set.
        ValueError: If its value cannot be parsed as an integer.
    """
    raw_port = os.environ["MASTER_PORT"]
    return int(raw_port)

@staticmethod
@override
def detect() -> bool:
    """Not supported for this environment; every call raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "The Kubeflow environment can't be detected automatically."
    raise NotImplementedError(message)

@override
def world_size(self) -> int:
    """Return the ``WORLD_SIZE`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``WORLD_SIZE`` is not set.
    """
    raw_size = os.environ["WORLD_SIZE"]
    return int(raw_size)

@override
def set_world_size(self, size: int) -> None:
    """Ignore ``size`` and only emit a debug log message."""
    message = "KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored."
    log.debug(message)

@override
def global_rank(self) -> int:
    """Return the ``RANK`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``RANK`` is not set.
    """
    raw_rank = os.environ["RANK"]
    return int(raw_rank)

@override
def set_global_rank(self, rank: int) -> None:
    """Ignore ``rank`` and only emit a debug log message."""
    message = "KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
    log.debug(message)

@override
def local_rank(self) -> int:
    """Always ``0`` for this environment."""
    return 0

@override
def node_rank(self) -> int:
    """Return the same value as :meth:`global_rank`."""
    rank = self.global_rank()
    return rank
13 changes: 13 additions & 0 deletions src/lightning/fabric/plugins/environments/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
import socket

from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.utilities.rank_zero import rank_zero_only

Expand Down Expand Up @@ -42,6 +44,7 @@ def __init__(self) -> None:
self._world_size: int = 1

@property
@override
def creates_processes_externally(self) -> bool:
"""Returns whether the cluster creates the processes or not.
Expand All @@ -52,10 +55,12 @@ def creates_processes_externally(self) -> bool:
return "LOCAL_RANK" in os.environ

@property
@override
def main_address(self) -> str:
    """Return ``MASTER_ADDR`` from the environment, or ``"127.0.0.1"`` when it is unset."""
    fallback_address = "127.0.0.1"
    return os.environ.get("MASTER_ADDR", fallback_address)

@property
@override
def main_port(self) -> int:
if self._main_port == -1:
self._main_port = (
Expand All @@ -64,29 +69,37 @@ def main_port(self) -> int:
return self._main_port

@staticmethod
@override
def detect() -> bool:
    """Always ``True`` for this environment."""
    return True

@override
def world_size(self) -> int:
    """Return the world size currently stored on this instance."""
    size = self._world_size
    return size

@override
def set_world_size(self, size: int) -> None:
    """Store ``size`` as the world size later returned by :meth:`world_size`."""
    self._world_size = size

@override
def global_rank(self) -> int:
    """Return the global rank currently stored on this instance."""
    rank = self._global_rank
    return rank

@override
def set_global_rank(self, rank: int) -> None:
    """Store ``rank`` on this instance and mirror it onto ``rank_zero_only.rank``."""
    # Chained assignment binds left to right, so the instance attribute is
    # written first and ``rank_zero_only.rank`` second.
    self._global_rank = rank_zero_only.rank = rank

@override
def local_rank(self) -> int:
    """Return ``LOCAL_RANK`` from the environment as ``int``, defaulting to ``0`` when unset."""
    raw_rank = os.environ.get("LOCAL_RANK", 0)
    return int(raw_rank)

@override
def node_rank(self) -> int:
    """Return the node rank read from the environment.

    ``NODE_RANK`` takes precedence; when it is unset, ``GROUP_RANK`` is used,
    and when both are unset the result is ``0``.
    """
    raw_rank = os.environ.get("NODE_RANK", os.environ.get("GROUP_RANK", 0))
    return int(raw_rank)

@override
def teardown(self) -> None:
if "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]
Expand Down
12 changes: 12 additions & 0 deletions src/lightning/fabric/plugins/environments/lsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import socket
from typing import Dict, List

from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.utilities.cloud_io import get_filesystem

Expand Down Expand Up @@ -62,27 +64,32 @@ def _set_init_progress_group_env_vars(self) -> None:
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

@property
@override
def creates_processes_externally(self) -> bool:
    """Always ``True``: LSF creates the subprocesses, so PyTorch Lightning does not need to spawn them."""
    return True

@property
@override
def main_address(self) -> str:
    """Return the main address stored on this instance.

    The stored value is read from an OpenMPI host rank file named by the
    environment variable ``LSB_DJOB_RANKFILE``.
    """
    address = self._main_address
    return address

@property
@override
def main_port(self) -> int:
    """Return the main port stored on this instance; the value is calculated from the LSF job ID."""
    port = self._main_port
    return port

@staticmethod
@override
def detect() -> bool:
    """Return ``True`` if the current process was launched using the ``jsrun`` command.

    Detection succeeds only when every one of the LSF/JSM environment
    variables listed below is set.
    """
    needed = ("LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
    return all(name in os.environ for name in needed)

@override
def world_size(self) -> int:
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
world_size = os.environ.get("JSM_NAMESPACE_SIZE")
Expand All @@ -93,9 +100,11 @@ def world_size(self) -> int:
)
return int(world_size)

@override
def set_world_size(self, size: int) -> None:
    """Ignore ``size`` and only emit a debug log message."""
    message = "LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored."
    log.debug(message)

@override
def global_rank(self) -> int:
"""The global rank is read from the environment variable ``JSM_NAMESPACE_RANK``."""
global_rank = os.environ.get("JSM_NAMESPACE_RANK")
Expand All @@ -106,9 +115,11 @@ def global_rank(self) -> int:
)
return int(global_rank)

@override
def set_global_rank(self, rank: int) -> None:
    """Ignore ``rank`` and only emit a debug log message."""
    message = "LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
    log.debug(message)

@override
def local_rank(self) -> int:
"""The local rank is read from the environment variable ``JSM_NAMESPACE_LOCAL_RANK``."""
local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
Expand All @@ -119,6 +130,7 @@ def local_rank(self) -> int:
)
return int(local_rank)

@override
def node_rank(self) -> int:
"""The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored in
``LSB_DJOB_RANKFILE``."""
Expand Down
11 changes: 11 additions & 0 deletions src/lightning/fabric/plugins/environments/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional

from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.environments.lightning import find_free_network_port
Expand Down Expand Up @@ -47,22 +48,26 @@ def __init__(self) -> None:
self._main_port: Optional[int] = None

@property
@override
def creates_processes_externally(self) -> bool:
    """Always ``True`` for this environment."""
    return True

@property
@override
def main_address(self) -> str:
    """Return the main address, resolving it on first access.

    The value from ``self._get_main_address()`` is stored on the instance and
    reused by later accesses.
    """
    address = self._main_address
    if address is None:
        address = self._main_address = self._get_main_address()
    return address

@property
@override
def main_port(self) -> int:
    """Return the main port, resolving it on first access.

    The value from ``self._get_main_port()`` is stored on the instance and
    reused by later accesses.
    """
    port = self._main_port
    if port is None:
        port = self._main_port = self._get_main_port()
    return port

@staticmethod
@override
def detect() -> bool:
"""Returns ``True`` if the `mpi4py` package is installed and MPI returns a world size greater than 1."""
if not _MPI4PY_AVAILABLE:
Expand All @@ -72,27 +77,33 @@ def detect() -> bool:

return MPI.COMM_WORLD.Get_size() > 1

@override
@lru_cache(maxsize=1)
def world_size(self) -> int:
    """Return the size reported by ``self._comm_world.Get_size()``.

    The result is memoized by ``lru_cache`` with a single slot keyed on ``self``.
    """
    # NOTE(review): caching on a method keeps a reference to ``self`` inside
    # the cache — confirm this is acceptable for the instance's lifetime.
    communicator = self._comm_world
    return communicator.Get_size()

@override
def set_world_size(self, size: int) -> None:
    """Ignore ``size`` and only emit a debug log message."""
    message = "MPIEnvironment.set_world_size was called, but setting world size is not allowed. Ignored."
    log.debug(message)

@override
@lru_cache(maxsize=1)
def global_rank(self) -> int:
    """Return the rank reported by ``self._comm_world.Get_rank()``.

    The result is memoized by ``lru_cache`` with a single slot keyed on ``self``.
    """
    # NOTE(review): caching on a method keeps a reference to ``self`` inside
    # the cache — confirm this is acceptable for the instance's lifetime.
    communicator = self._comm_world
    return communicator.Get_rank()

@override
def set_global_rank(self, rank: int) -> None:
    """Ignore ``rank`` and only emit a debug log message."""
    message = "MPIEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
    log.debug(message)

@override
@lru_cache(maxsize=1)
def local_rank(self) -> int:
    """Return the rank reported by the local communicator.

    ``self._init_comm_local()`` is called first when the local communicator
    has not been set yet. The result is memoized by ``lru_cache`` with a
    single slot keyed on ``self``.
    """
    # NOTE(review): caching on a method keeps a reference to ``self`` inside
    # the cache — confirm this is acceptable for the instance's lifetime.
    if self._comm_local is None:
        self._init_comm_local()
    local_comm = self._comm_local
    assert local_comm is not None
    return local_comm.Get_rank()

@override
def node_rank(self) -> int:
if self._node_rank is None:
self._init_comm_local()
Expand Down
13 changes: 13 additions & 0 deletions src/lightning/fabric/plugins/environments/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import sys
from typing import Optional

from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.rank_zero import rank_zero_warn
Expand Down Expand Up @@ -52,10 +54,12 @@ def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Si
self._validate_srun_variables()

@property
@override
def creates_processes_externally(self) -> bool:
    """Always ``True`` for this environment."""
    return True

@property
@override
def main_address(self) -> str:
root_node = os.environ.get("MASTER_ADDR")
if root_node is None:
Expand All @@ -67,6 +71,7 @@ def main_address(self) -> str:
return root_node

@property
@override
def main_port(self) -> int:
# -----------------------
# SLURM JOB = PORT number
Expand Down Expand Up @@ -94,6 +99,7 @@ def main_port(self) -> int:
return default_port

@staticmethod
@override
def detect() -> bool:
"""Returns ``True`` if the current process was launched on a SLURM cluster.
Expand Down Expand Up @@ -124,24 +130,31 @@ def job_id() -> Optional[int]:
except ValueError:
return None

@override
def world_size(self) -> int:
    """Return the ``SLURM_NTASKS`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``SLURM_NTASKS`` is not set.
    """
    raw_size = os.environ["SLURM_NTASKS"]
    return int(raw_size)

@override
def set_world_size(self, size: int) -> None:
    """Ignore ``size`` and only emit a debug log message."""
    message = "SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored."
    log.debug(message)

@override
def global_rank(self) -> int:
    """Return the ``SLURM_PROCID`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``SLURM_PROCID`` is not set.
    """
    raw_rank = os.environ["SLURM_PROCID"]
    return int(raw_rank)

@override
def set_global_rank(self, rank: int) -> None:
    """Ignore ``rank`` and only emit a debug log message."""
    message = "SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
    log.debug(message)

@override
def local_rank(self) -> int:
    """Return the ``SLURM_LOCALID`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``SLURM_LOCALID`` is not set.
    """
    raw_rank = os.environ["SLURM_LOCALID"]
    return int(raw_rank)

@override
def node_rank(self) -> int:
    """Return the ``SLURM_NODEID`` environment variable converted to ``int``.

    Raises:
        KeyError: If ``SLURM_NODEID`` is not set.
    """
    raw_rank = os.environ["SLURM_NODEID"]
    return int(raw_rank)

@override
def validate_settings(self, num_devices: int, num_nodes: int) -> None:
if _is_slurm_interactive_mode():
return
Expand Down
Loading

0 comments on commit 520c1e4

Please sign in to comment.