
Commit 182c30b
Update Habana integration to 1.2 (#18877)
Co-authored-by: Jirka Borovec <[email protected]>
carmocca and Borda authored Oct 26, 2023
1 parent e50b68a commit 182c30b
Showing 8 changed files with 129 additions and 76 deletions.
2 changes: 1 addition & 1 deletion docs/source-pytorch/conf.py
@@ -92,7 +92,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
 assist_local.AssistantCLI.pull_docs_files(
     gh_user_repo="Lightning-AI/lightning-Habana",
     target_dir="docs/source-pytorch/integrations/hpu",
-    checkout="tags/1.1.0",
+    checkout="tags/1.2.0",
 )
 assist_local.AssistantCLI.pull_docs_files(
     gh_user_repo="Lightning-AI/lightning-Graphcore",
6 changes: 3 additions & 3 deletions requirements/_integrations/accelerators.txt
@@ -1,3 +1,3 @@
-# validation HPU connectors
-lightning-habana >=1.0.0
-lightning-graphcore >=0.1.0.rc4
+# validation accelerator connectors
+lightning-habana >=1.2.0, <1.3.0
+lightning-graphcore >=0.1.0, <0.2.0
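For context (not part of this commit): the new pins can be checked at runtime with lightning-utilities' RequirementCache, the same helper the imports.py changes below build on. A minimal sketch:

from lightning_utilities.core.imports import RequirementCache

# Each check is True only when the installed distribution satisfies the pinned range.
print(bool(RequirementCache("lightning-habana >=1.2.0, <1.3.0")))
print(bool(RequirementCache("lightning-graphcore >=0.1.0, <0.2.0")))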
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/configuration_validator.py
@@ -16,7 +16,7 @@
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _lightning_graphcore_available
+from lightning.pytorch.utilities.imports import _graphcore_available_and_importable
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -125,7 +125,7 @@ def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     datahook_selector = trainer._data_connector._datahook_selector
     assert datahook_selector is not None
     for hook in batch_transfer_hooks:
-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             # TODO: This code could be done in a hook in the IPUAccelerator as it's a simple error check
26 changes: 13 additions & 13 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -64,8 +64,8 @@
 from lightning.pytorch.utilities.imports import (
     _LIGHTNING_BAGUA_AVAILABLE,
     _LIGHTNING_COLOSSALAI_AVAILABLE,
-    _lightning_graphcore_available,
-    _lightning_habana_available,
+    _graphcore_available_and_importable,
+    _habana_available_and_importable,
 )
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

@@ -347,12 +347,12 @@ def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability."""
if XLAAccelerator.is_available():
return "tpu"
if _lightning_graphcore_available():
if _graphcore_available_and_importable():
from lightning_graphcore import IPUAccelerator

if IPUAccelerator.is_available():
return "ipu"
if _lightning_habana_available():
if _habana_available_and_importable():
from lightning_habana import HPUAccelerator

if HPUAccelerator.is_available():
@@ -435,7 +435,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

     def _choose_strategy(self) -> Union[Strategy, str]:
         if self._accelerator_flag == "ipu":
-            if not _lightning_graphcore_available():
+            if not _graphcore_available_and_importable():
                 raise ImportError(
                     "You have passed `accelerator='ipu'` but the IPU integration is not installed."
                     " Please run `pip install lightning-graphcore` or check out"
@@ -445,7 +445,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:

             return IPUStrategy.strategy_name
         if self._accelerator_flag == "hpu":
-            if not _lightning_habana_available():
+            if not _habana_available_and_importable():
                 raise ImportError(
                     "You have asked for HPU but you miss install related integration."
                     " Please run `pip install lightning-habana` or see for further instructions"
@@ -514,7 +514,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
         if isinstance(self._precision_plugin_flag, PrecisionPlugin):
             return self._precision_plugin_flag

-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator, IPUPrecision

             # TODO: For the strategies that have a fixed precision class, we don't really need this logic
@@ -524,7 +524,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
             if isinstance(self.accelerator, IPUAccelerator):
                 return IPUPrecision(self._precision_flag)

-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator, HPUPrecisionPlugin

             if isinstance(self.accelerator, HPUAccelerator):
@@ -571,7 +571,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, AMP type, and accelerator."""
-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator

             if isinstance(self.accelerator, HPUAccelerator) and self._precision_flag not in (
@@ -626,7 +626,7 @@ def _lazy_init_strategy(self) -> None:
f" found {self.strategy.__class__.__name__}."
)

if _lightning_habana_available():
if _habana_available_and_importable():
from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

if isinstance(self.accelerator, HPUAccelerator) and not isinstance(
@@ -645,7 +645,7 @@ def is_distributed(self) -> bool:
             DeepSpeedStrategy,
             XLAStrategy,
         ]
-        if _lightning_habana_available():
+        if _habana_available_and_importable():
             from lightning_habana import HPUParallelStrategy

             distributed_strategies.append(HPUParallelStrategy)
@@ -698,7 +698,7 @@ def _register_external_accelerators_and_strategies() -> None:
if "bagua" not in StrategyRegistry:
BaguaStrategy.register_strategies(StrategyRegistry)

if _lightning_habana_available():
if _habana_available_and_importable():
from lightning_habana import HPUAccelerator, HPUParallelStrategy, SingleHPUStrategy

# TODO: Prevent registering multiple times
@@ -709,7 +709,7 @@ def _register_external_accelerators_and_strategies() -> None:
if "hpu_single" not in StrategyRegistry:
SingleHPUStrategy.register_strategies(StrategyRegistry)

if _lightning_graphcore_available():
if _graphcore_available_and_importable():
from lightning_graphcore import IPUAccelerator, IPUStrategy

# TODO: Prevent registering multiple times
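Usage note (not part of this diff): with lightning-habana >= 1.2.0 installed, the connector above resolves the HPU accelerator and registers its strategies, so a Trainer can target Gaudi devices directly. A minimal sketch, assuming a machine with a single HPU available:

from lightning.pytorch import Trainer

# "hpu" is resolved through _habana_available_and_importable(); accelerator="auto"
# likewise falls back to "hpu" when a Gaudi device is detected.
trainer = Trainer(accelerator="hpu", devices=1)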
6 changes: 3 additions & 3 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -34,7 +34,7 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _lightning_graphcore_available
+from lightning.pytorch.utilities.imports import _graphcore_available_and_importable
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -165,7 +165,7 @@ def attach_datamodule(
         datamodule.trainer = trainer

     def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             # `DistributedSampler` is never used with `poptorch.DataLoader`
@@ -191,7 +191,7 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt
         if not isinstance(dataloader, DataLoader):
             return dataloader

-        if _lightning_graphcore_available():
+        if _graphcore_available_and_importable():
             from lightning_graphcore import IPUAccelerator

             # IPUs use a custom `poptorch.DataLoader` which we might need to convert to
10 changes: 5 additions & 5 deletions src/lightning/pytorch/trainer/setup.py
@@ -28,7 +28,7 @@
     XLAProfiler,
 )
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _lightning_graphcore_available, _lightning_habana_available
+from lightning.pytorch.utilities.imports import _graphcore_available_and_importable, _habana_available_and_importable
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn


@@ -158,7 +158,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
     num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
     rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")

-    if _lightning_graphcore_available():
+    if _graphcore_available_and_importable():
         from lightning_graphcore import IPUAccelerator

         num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0
@@ -168,7 +168,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
         ipu_available = False
     rank_zero_info(f"IPU available: {ipu_available}, using: {num_ipus} IPUs")

-    if _lightning_habana_available():
+    if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator

         num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
@@ -192,13 +192,13 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
     if XLAAccelerator.is_available() and not isinstance(trainer.accelerator, XLAAccelerator):
         rank_zero_warn("TPU available but not used. You can set it by doing `Trainer(accelerator='tpu')`.")

-    if _lightning_graphcore_available():
+    if _graphcore_available_and_importable():
         from lightning_graphcore import IPUAccelerator

         if IPUAccelerator.is_available() and not isinstance(trainer.accelerator, IPUAccelerator):
             rank_zero_warn("IPU available but not used. You can set it by doing `Trainer(accelerator='ipu')`.")

-    if _lightning_habana_available():
+    if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator

         if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
20 changes: 12 additions & 8 deletions src/lightning/pytorch/utilities/imports.py
@@ -35,21 +35,25 @@ def _try_import_module(module_name: str) -> bool:
     try:
         __import__(module_name)
         return True
-    # added also AttributeError fro case of impoerts like pl.LightningModule
+    # Also on AttributeError for failed imports like pl.LightningModule
     except (ImportError, AttributeError) as err:
-        rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues: \n{err}")
+        rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues:\n{err}")
         return False


-@functools.lru_cache(maxsize=1)
-def _lightning_graphcore_available() -> bool:
+_LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore>=0.1.0")
+
+
+def _graphcore_available_and_importable() -> bool:
     # This is defined as a function instead of a constant to avoid circular imports, because `lightning_graphcore`
     # also imports Lightning
-    return bool(RequirementCache("lightning-graphcore")) and _try_import_module("lightning_graphcore")
+    return bool(_LIGHTNING_GRAPHCORE_AVAILABLE) and _try_import_module("lightning_graphcore")


+_LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana>=1.2.0")
+
+
-@functools.lru_cache(maxsize=1)
-def _lightning_habana_available() -> bool:
+def _habana_available_and_importable() -> bool:
     # This is defined as a function instead of a constant to avoid circular imports, because `lightning_habana`
     # also imports Lightning
-    return bool(RequirementCache("lightning-habana")) and _try_import_module("lightning_habana")
+    return bool(_LIGHTNING_HABANA_AVAILABLE) and _try_import_module("lightning_habana")
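For illustration only (not part of this commit): the same availability-check pattern generalizes to any optional integration that itself imports Lightning. A minimal sketch for a hypothetical lightning_example package, reusing the _try_import_module helper shown above:

from lightning_utilities.core.imports import RequirementCache

# The requirement check is deferred until the object is converted to bool.
_LIGHTNING_EXAMPLE_AVAILABLE = RequirementCache("lightning-example>=0.1.0")


def _example_available_and_importable() -> bool:
    # Kept as a function rather than a module-level constant so that `lightning_example`,
    # which imports Lightning itself, is never imported while Lightning is still being imported.
    return bool(_LIGHTNING_EXAMPLE_AVAILABLE) and _try_import_module("lightning_example")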