Skip to content

Commit

Permalink
Add @override for files in src/lightning/fabric/accelerators (#19068
Browse files Browse the repository at this point in the history
)
  • Loading branch information
VictorPrins authored Nov 27, 2023
1 parent 37081c1 commit e30401a
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/lightning/fabric/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import List, Union

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
Expand All @@ -22,6 +23,7 @@
class CPUAccelerator(Accelerator):
"""Accelerator for CPU devices."""

@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
Expand All @@ -31,31 +33,37 @@ def setup_device(self, device: torch.device) -> None:
if device.type != "cpu":
raise ValueError(f"Device should be CPU, got {device} instead.")

@override
def teardown(self) -> None:
pass

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices

@staticmethod
@override
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 1

@staticmethod
@override
def is_available() -> bool:
"""CPU is always available for execution."""
return True

@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
accelerator_registry.register(
"cpu",
Expand Down
8 changes: 8 additions & 0 deletions src/lightning/fabric/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Generator, List, Optional, Union, cast

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
Expand All @@ -28,6 +29,7 @@
class CUDAAccelerator(Accelerator):
"""Accelerator for NVIDIA CUDA devices."""

@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
Expand All @@ -39,31 +41,37 @@ def setup_device(self, device: torch.device) -> None:
_check_cuda_matmul_precision(device)
torch.cuda.set_device(device)

@override
def teardown(self) -> None:
_clear_cuda_memory()

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic."""
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

return _parse_gpu_ids(devices, include_cuda=True)

@staticmethod
@override
def get_parallel_devices(devices: List[int]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@staticmethod
@override
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return num_cuda_devices()

@staticmethod
@override
def is_available() -> bool:
return num_cuda_devices() > 0

@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
accelerator_registry.register(
"cuda",
Expand Down
8 changes: 8 additions & 0 deletions src/lightning/fabric/accelerators/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List, Optional, Union

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
Expand All @@ -28,6 +29,7 @@ class MPSAccelerator(Accelerator):
"""

@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
Expand All @@ -37,35 +39,41 @@ def setup_device(self, device: torch.device) -> None:
if device.type != "mps":
raise ValueError(f"Device should be MPS, got {device} instead.")

@override
def teardown(self) -> None:
pass

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic."""
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

return _parse_gpu_ids(devices, include_mps=True)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
parsed_devices = MPSAccelerator.parse_devices(devices)
assert parsed_devices is not None
return [torch.device("mps", i) for i in range(len(parsed_devices))]

@staticmethod
@override
def auto_device_count() -> int:
"""Get the devices when set to auto."""
return 1

@staticmethod
@override
@lru_cache(1)
def is_available() -> bool:
"""MPS is only available on a machine with the ARM-based Apple Silicon processors."""
return torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")

@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
accelerator_registry.register(
"mps",
Expand Down
3 changes: 3 additions & 0 deletions src/lightning/fabric/accelerators/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional

from typing_extensions import override

from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.registry import _register_classes

Expand Down Expand Up @@ -82,6 +84,7 @@ def do_register(name: str, accelerator: Callable) -> Callable:

return do_register

@override
def get(self, name: str, default: Optional[Any] = None) -> Any:
"""Calls the registered accelerator with the required parameters and returns the accelerator object.
Expand Down
8 changes: 8 additions & 0 deletions src/lightning/fabric/accelerators/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
Expand All @@ -34,18 +35,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

@override
def setup_device(self, device: torch.device) -> None:
pass

@override
def teardown(self) -> None:
pass

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
"""Accelerator device parsing logic."""
return _parse_tpu_devices(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_tpu_devices(devices)
Expand All @@ -62,6 +67,7 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy

@staticmethod
@override
# XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
# https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/distributed/xla_multiprocessing.py#L280
@functools.lru_cache(maxsize=1)
Expand All @@ -84,6 +90,7 @@ def auto_device_count() -> int:
return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)

@staticmethod
@override
@functools.lru_cache(maxsize=1)
def is_available() -> bool:
try:
Expand All @@ -94,6 +101,7 @@ def is_available() -> bool:
return False

@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
accelerator_registry.register("tpu", cls, description=cls.__name__)

Expand Down

0 comments on commit e30401a

Please sign in to comment.