From e30401af80ba777459d0ab6b948ac5814fa34ae5 Mon Sep 17 00:00:00 2001
From: Victor Prins
Date: Mon, 27 Nov 2023 13:22:29 +0100
Subject: [PATCH] Add `@override` for files in
 `src/lightning/fabric/accelerators` (#19068)

---
 src/lightning/fabric/accelerators/cpu.py      | 8 ++++++++
 src/lightning/fabric/accelerators/cuda.py     | 8 ++++++++
 src/lightning/fabric/accelerators/mps.py      | 8 ++++++++
 src/lightning/fabric/accelerators/registry.py | 3 +++
 src/lightning/fabric/accelerators/xla.py      | 8 ++++++++
 5 files changed, 35 insertions(+)

diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py
index 0ab242eb7505a..1bcec1b2ac278 100644
--- a/src/lightning/fabric/accelerators/cpu.py
+++ b/src/lightning/fabric/accelerators/cpu.py
@@ -14,6 +14,7 @@
 from typing import List, Union
 
 import torch
+from typing_extensions import override
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
@@ -22,6 +23,7 @@
 class CPUAccelerator(Accelerator):
     """Accelerator for CPU devices."""
 
+    @override
     def setup_device(self, device: torch.device) -> None:
         """
         Raises:
@@ -31,31 +33,37 @@ def setup_device(self, device: torch.device) -> None:
         if device.type != "cpu":
             raise ValueError(f"Device should be CPU, got {device} instead.")
 
+    @override
     def teardown(self) -> None:
         pass
 
     @staticmethod
+    @override
     def parse_devices(devices: Union[int, str, List[int]]) -> int:
         """Accelerator device parsing logic."""
         return _parse_cpu_cores(devices)
 
     @staticmethod
+    @override
     def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
 
     @staticmethod
+    @override
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return 1
 
     @staticmethod
+    @override
     def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
     @classmethod
+    @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "cpu",
diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py
index c1d0ee10525fd..da0bf7a7e720e 100644
--- a/src/lightning/fabric/accelerators/cuda.py
+++ b/src/lightning/fabric/accelerators/cuda.py
@@ -18,6 +18,7 @@
 from typing import Generator, List, Optional, Union, cast
 
 import torch
+from typing_extensions import override
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
@@ -28,6 +29,7 @@
 class CUDAAccelerator(Accelerator):
     """Accelerator for NVIDIA CUDA devices."""
 
+    @override
     def setup_device(self, device: torch.device) -> None:
         """
         Raises:
@@ -39,10 +41,12 @@ def setup_device(self, device: torch.device) -> None:
         _check_cuda_matmul_precision(device)
         torch.cuda.set_device(device)
 
+    @override
     def teardown(self) -> None:
         _clear_cuda_memory()
 
     @staticmethod
+    @override
     def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         """Accelerator device parsing logic."""
         from lightning.fabric.utilities.device_parser import _parse_gpu_ids
@@ -50,20 +54,24 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         return _parse_gpu_ids(devices, include_cuda=True)
 
     @staticmethod
+    @override
     def get_parallel_devices(devices: List[int]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         return [torch.device("cuda", i) for i in devices]
 
     @staticmethod
+    @override
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return num_cuda_devices()
 
     @staticmethod
+    @override
     def is_available() -> bool:
         return num_cuda_devices() > 0
 
     @classmethod
+    @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "cuda",
diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py
index efc825261a167..d0f36698616d4 100644
--- a/src/lightning/fabric/accelerators/mps.py
+++ b/src/lightning/fabric/accelerators/mps.py
@@ -16,6 +16,7 @@
 from typing import List, Optional, Union
 
 import torch
+from typing_extensions import override
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
@@ -28,6 +29,7 @@ class MPSAccelerator(Accelerator):
 
     """
 
+    @override
     def setup_device(self, device: torch.device) -> None:
         """
         Raises:
@@ -37,10 +39,12 @@ def setup_device(self, device: torch.device) -> None:
         if device.type != "mps":
             raise ValueError(f"Device should be MPS, got {device} instead.")
 
+    @override
     def teardown(self) -> None:
         pass
 
     @staticmethod
+    @override
     def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         """Accelerator device parsing logic."""
         from lightning.fabric.utilities.device_parser import _parse_gpu_ids
@@ -48,6 +52,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         return _parse_gpu_ids(devices, include_mps=True)
 
     @staticmethod
+    @override
     def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         parsed_devices = MPSAccelerator.parse_devices(devices)
@@ -55,17 +60,20 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
         return [torch.device("mps", i) for i in range(len(parsed_devices))]
 
     @staticmethod
+    @override
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
         return 1
 
     @staticmethod
+    @override
     @lru_cache(1)
     def is_available() -> bool:
         """MPS is only available on a machine with the ARM-based Apple Silicon processors."""
         return torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
 
     @classmethod
+    @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "mps",
diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py
index 4b85fdf6ecb4c..1299b1e148aa8 100644
--- a/src/lightning/fabric/accelerators/registry.py
+++ b/src/lightning/fabric/accelerators/registry.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import Any, Callable, Dict, List, Optional
 
+from typing_extensions import override
+
 from lightning.fabric.utilities.exceptions import MisconfigurationException
 from lightning.fabric.utilities.registry import _register_classes
 
@@ -82,6 +84,7 @@ def do_register(name: str, accelerator: Callable) -> Callable:
 
         return do_register
 
+    @override
     def get(self, name: str, default: Optional[Any] = None) -> Any:
         """Calls the registered accelerator with the required parameters and returns the accelerator object.
 
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index 45169f2ce8ee7..4fc3e2248e64b 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -16,6 +16,7 @@
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
@@ -34,18 +35,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    @override
     def setup_device(self, device: torch.device) -> None:
         pass
 
+    @override
     def teardown(self) -> None:
         pass
 
     @staticmethod
+    @override
     def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
         """Accelerator device parsing logic."""
         return _parse_tpu_devices(devices)
 
     @staticmethod
+    @override
     def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_tpu_devices(devices)
@@ -62,6 +67,7 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
         # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
 
     @staticmethod
+    @override
     # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
     # https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/distributed/xla_multiprocessing.py#L280
     @functools.lru_cache(maxsize=1)
@@ -84,6 +90,7 @@ def auto_device_count() -> int:
         return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)
 
     @staticmethod
+    @override
     @functools.lru_cache(maxsize=1)
     def is_available() -> bool:
         try:
@@ -94,6 +101,7 @@ def is_available() -> bool:
             return False
 
     @classmethod
+    @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register("tpu", cls, description=cls.__name__)
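Note (not part of the patch): a minimal sketch of what `typing_extensions.override` buys in these files. The decorator has no runtime effect beyond tagging the function; static type checkers such as mypy or pyright use it to verify that each decorated method really overrides something in a base class, so a renamed or misspelled accelerator hook is flagged instead of silently becoming a new method. The `Base`/`Child` names below are illustrative placeholders, not Lightning code:

    from typing_extensions import override

    class Base:
        def teardown(self) -> None: ...

    class Child(Base):
        @override
        def teardown(self) -> None:  # OK: matches a method defined on Base
            pass

        # @override
        # def teardwon(self) -> None:  # a type checker would flag this typo,
        #     pass                     # since Base defines no `teardwon` method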