From 079544a9029d85518a0a8fa40062eed846d3c624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 30 Oct 2023 21:53:13 +0100 Subject: [PATCH] Rename PrecisionPlugin -> Precision (#18840) --- docs/source-pytorch/api_references.rst | 18 ++-- .../common/precision_expert.rst | 6 +- .../common/precision_intermediate.rst | 8 +- docs/source-pytorch/extensions/plugins.rst | 18 ++-- .../fabric/plugins/precision/bitsandbytes.py | 2 +- src/lightning/pytorch/CHANGELOG.md | 2 +- src/lightning/pytorch/_graveyard/__init__.py | 1 + src/lightning/pytorch/_graveyard/precision.py | 86 ++++++++++++++++++ src/lightning/pytorch/_graveyard/tpu.py | 20 ++--- .../pytorch/callbacks/throughput_monitor.py | 40 ++++----- src/lightning/pytorch/plugins/__init__.py | 39 ++++---- .../pytorch/plugins/precision/__init__.py | 37 ++++---- .../pytorch/plugins/precision/amp.py | 4 +- .../pytorch/plugins/precision/bitsandbytes.py | 4 +- .../pytorch/plugins/precision/deepspeed.py | 4 +- .../pytorch/plugins/precision/double.py | 4 +- .../pytorch/plugins/precision/fsdp.py | 26 +----- .../pytorch/plugins/precision/half.py | 4 +- .../{precision_plugin.py => precision.py} | 4 +- .../plugins/precision/transformer_engine.py | 4 +- .../pytorch/plugins/precision/xla.py | 4 +- src/lightning/pytorch/strategies/ddp.py | 4 +- src/lightning/pytorch/strategies/deepspeed.py | 4 +- src/lightning/pytorch/strategies/fsdp.py | 20 ++--- src/lightning/pytorch/strategies/parallel.py | 4 +- .../pytorch/strategies/single_device.py | 4 +- .../pytorch/strategies/single_xla.py | 18 ++-- src/lightning/pytorch/strategies/strategy.py | 12 +-- src/lightning/pytorch/strategies/xla.py | 18 ++-- .../connectors/accelerator_connector.py | 48 +++++----- .../connectors/checkpoint_connector.py | 4 +- src/lightning/pytorch/trainer/trainer.py | 4 +- .../collectives/test_torch_collective.py | 1 + tests/tests_pytorch/accelerators/test_cpu.py | 6 +- tests/tests_pytorch/accelerators/test_xla.py | 10 +-- .../deprecated_api/test_no_removal_version.py | 80 ++++++++++++++++- tests/tests_pytorch/graveyard/__init__.py | 0 .../tests_pytorch/graveyard/test_precision.py | 90 +++++++++++++++++++ .../tests_pytorch/models/test_ddp_fork_amp.py | 4 +- tests/tests_pytorch/models/test_hooks.py | 2 +- .../plugins/precision/test_all.py | 18 ++-- .../plugins/precision/test_amp.py | 6 +- .../plugins/precision/test_amp_integration.py | 4 +- .../precision/test_deepspeed_precision.py | 10 +-- .../plugins/precision/test_double.py | 8 +- .../plugins/precision/test_fsdp.py | 30 +++---- .../plugins/precision/test_half.py | 10 +-- .../precision/test_transformer_engine.py | 6 +- .../plugins/precision/test_xla.py | 24 ++--- .../tests_pytorch/plugins/test_amp_plugins.py | 26 +++--- tests/tests_pytorch/strategies/test_common.py | 10 +-- tests/tests_pytorch/strategies/test_ddp.py | 10 +-- .../strategies/test_deepspeed.py | 8 +- tests/tests_pytorch/strategies/test_fsdp.py | 12 +-- .../connectors/test_accelerator_connector.py | 52 +++++------ tests/tests_pytorch/utilities/test_imports.py | 2 +- 56 files changed, 566 insertions(+), 338 deletions(-) create mode 100644 src/lightning/pytorch/_graveyard/precision.py rename src/lightning/pytorch/plugins/precision/{precision_plugin.py => precision.py} (97%) create mode 100644 tests/tests_pytorch/graveyard/__init__.py create mode 100644 tests/tests_pytorch/graveyard/test_precision.py diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 2f76871192991..3542001a6e973 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -114,15 +114,15 @@ precision :nosignatures: :template: classtemplate.rst - DeepSpeedPrecisionPlugin - DoublePrecisionPlugin - HalfPrecisionPlugin - FSDPPrecisionPlugin - MixedPrecisionPlugin - PrecisionPlugin - XLAPrecisionPlugin - TransformerEnginePrecisionPlugin - BitsandbytesPrecisionPlugin + DeepSpeedPrecision + DoublePrecision + HalfPrecision + FSDPPrecision + MixedPrecision + Precision + XLAPrecision + TransformerEnginePrecision + BitsandbytesPrecision environments """""""""""" diff --git a/docs/source-pytorch/common/precision_expert.rst b/docs/source-pytorch/common/precision_expert.rst index a4502972dbb18..40cdb7d31ed1f 100644 --- a/docs/source-pytorch/common/precision_expert.rst +++ b/docs/source-pytorch/common/precision_expert.rst @@ -12,17 +12,17 @@ N-Bit Precision (Expert) Precision Plugins ***************** -You can also customize and pass your own Precision Plugin by subclassing the :class:`~lightning.pytorch.plugins.precision.precision_plugin.PrecisionPlugin` class. +You can also customize and pass your own Precision Plugin by subclassing the :class:`~lightning.pytorch.plugins.precision.precision.Precision` class. - Perform pre and post backward/optimizer step operations such as scaling gradients. - Provide context managers for forward, training_step, etc. .. code-block:: python - class CustomPrecisionPlugin(PrecisionPlugin): + class CustomPrecision(Precision): precision = "16-mixed" ... - trainer = Trainer(plugins=[CustomPrecisionPlugin()]) + trainer = Trainer(plugins=[CustomPrecision()]) diff --git a/docs/source-pytorch/common/precision_intermediate.rst b/docs/source-pytorch/common/precision_intermediate.rst index bfa957b498b3b..41025ab1e8b09 100644 --- a/docs/source-pytorch/common/precision_intermediate.rst +++ b/docs/source-pytorch/common/precision_intermediate.rst @@ -182,18 +182,18 @@ This is configurable via the dtype argument in the plugin. Quantizing the model will dramatically reduce the weight's memory requirements but may have a negative impact on the model's performance or runtime. -The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecisionPlugin` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives. +The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecision` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives. .. code-block:: python - from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin + from lightning.pytorch.plugins import BitsandbytesPrecision # this will pick out the compute dtype automatically, by default `bfloat16` - precision = BitsandbytesPrecisionPlugin("nf4-dq") + precision = BitsandbytesPrecision("nf4-dq") trainer = Trainer(plugins=precision) # Customize the dtype, or skip some modules - precision = BitsandbytesPrecisionPlugin("int8-training", dtype=torch.float16, ignore_modules={"lm_head"}) + precision = BitsandbytesPrecision("int8-training", dtype=torch.float16, ignore_modules={"lm_head"}) trainer = Trainer(plugins=precision) diff --git a/docs/source-pytorch/extensions/plugins.rst b/docs/source-pytorch/extensions/plugins.rst index 9c358af40b860..2384c1a1c5b31 100644 --- a/docs/source-pytorch/extensions/plugins.rst +++ b/docs/source-pytorch/extensions/plugins.rst @@ -52,15 +52,15 @@ The full list of built-in precision plugins is listed below. :nosignatures: :template: classtemplate.rst - DeepSpeedPrecisionPlugin - DoublePrecisionPlugin - HalfPrecisionPlugin - FSDPPrecisionPlugin - MixedPrecisionPlugin - PrecisionPlugin - XLAPrecisionPlugin - TransformerEnginePrecisionPlugin - BitsandbytesPrecisionPlugin + DeepSpeedPrecision + DoublePrecision + HalfPrecision + FSDPPrecision + MixedPrecision + Precision + XLAPrecision + TransformerEnginePrecision + BitsandbytesPrecision More information regarding precision with Lightning can be found :ref:`here ` diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index a1e1d6bc20b6a..ab4db3a4d6b71 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -61,7 +61,7 @@ class BitsandbytesPrecision(Precision): # TODO: we could implement optimizer replacement with # - Fabric: Add `Precision.convert_optimizer` from `Strategy.setup_optimizer` - # - Trainer: Use `PrecisionPlugin.connect` + # - Trainer: Use `Precision.connect` def __init__( self, diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index eed52694a7282..88987e702d831 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated -- +- Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840)) ### Removed diff --git a/src/lightning/pytorch/_graveyard/__init__.py b/src/lightning/pytorch/_graveyard/__init__.py index 1baadda1e0e1f..3403bbd52aeec 100644 --- a/src/lightning/pytorch/_graveyard/__init__.py +++ b/src/lightning/pytorch/_graveyard/__init__.py @@ -14,4 +14,5 @@ import lightning.pytorch._graveyard._torchmetrics import lightning.pytorch._graveyard.hpu import lightning.pytorch._graveyard.ipu +import lightning.pytorch._graveyard.precision import lightning.pytorch._graveyard.tpu # noqa: F401 diff --git a/src/lightning/pytorch/_graveyard/precision.py b/src/lightning/pytorch/_graveyard/precision.py new file mode 100644 index 0000000000000..ee2d590e58134 --- /dev/null +++ b/src/lightning/pytorch/_graveyard/precision.py @@ -0,0 +1,86 @@ +import sys +from typing import TYPE_CHECKING, Any, Literal, Optional + +import lightning.pytorch as pl +from lightning.fabric.utilities.rank_zero import rank_zero_deprecation +from lightning.pytorch.plugins.precision import ( + BitsandbytesPrecision, + DeepSpeedPrecision, + DoublePrecision, + FSDPPrecision, + HalfPrecision, + MixedPrecision, + Precision, + TransformerEnginePrecision, + XLAPrecision, +) + +if TYPE_CHECKING: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + +def _patch_sys_modules() -> None: + sys.modules["lightning.pytorch.plugins.precision.precision_plugin"] = sys.modules[ + "lightning.pytorch.plugins.precision.precision" + ] + + +class FSDPMixedPrecisionPlugin(FSDPPrecision): + """AMP for Fully Sharded Data Parallel (FSDP) Training. + + .. deprecated:: Use :class:`FSDPPrecision` instead. + + .. warning:: This is an :ref:`experimental ` feature. + + """ + + def __init__( + self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None + ) -> None: + rank_zero_deprecation( + f"The `{type(self).__name__}` is deprecated." + " Use `lightning.pytorch.plugins.precision.FSDPPrecision` instead." + ) + super().__init__(precision=precision, scaler=scaler) + + +def _create_class(deprecated_name: str, new_class: type) -> type: + def init(self: type, *args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + f"The `{deprecated_name}` is deprecated." + f" Use `lightning.pytorch.plugins.precision.{new_class.__name__}` instead." + ) + super(type(self), self).__init__(*args, **kwargs) + + return type(deprecated_name, (new_class,), {"__init__": init}) + + +def _patch_classes() -> None: + classes_map = ( + # module name, old name, new class + ("bitsandbytes", "BitsandbytesPrecisionPlugin", BitsandbytesPrecision), + ("deepspeed", "DeepSpeedPrecisionPlugin", DeepSpeedPrecision), + ("double", "DoublePrecisionPlugin", DoublePrecision), + ("fsdp", "FSDPPrecisionPlugin", FSDPPrecision), + ("fsdp", "FSDPMixedPrecisionPlugin", FSDPPrecision), + ("half", "HalfPrecisionPlugin", HalfPrecision), + ("amp", "MixedPrecisionPlugin", MixedPrecision), + ("precision", "PrecisionPlugin", Precision), + ("transformer_engine", "TransformerEnginePrecisionPlugin", TransformerEnginePrecision), + ("xla", "XLAPrecisionPlugin", XLAPrecision), + ) + + for module_name, deprecated_name, new_class in classes_map: + deprecated_class = _create_class(deprecated_name, new_class) + setattr(getattr(pl.plugins.precision, module_name), deprecated_name, deprecated_class) + setattr(pl.plugins.precision, deprecated_name, deprecated_class) + setattr(pl.plugins, deprecated_name, deprecated_class) + + # special treatment for `FSDPMixedPrecisionPlugin` because it has a different signature + setattr(pl.plugins.precision.fsdp, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin) + setattr(pl.plugins.precision, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin) + setattr(pl.plugins, "FSDPMixedPrecisionPlugin", FSDPMixedPrecisionPlugin) + + +_patch_sys_modules() +_patch_classes() diff --git a/src/lightning/pytorch/_graveyard/tpu.py b/src/lightning/pytorch/_graveyard/tpu.py index eb1d493645c45..34008e3ee556d 100644 --- a/src/lightning/pytorch/_graveyard/tpu.py +++ b/src/lightning/pytorch/_graveyard/tpu.py @@ -18,7 +18,7 @@ import lightning.pytorch as pl from lightning.fabric.strategies import _StrategyRegistry from lightning.pytorch.accelerators.xla import XLAAccelerator -from lightning.pytorch.plugins.precision import XLAPrecisionPlugin +from lightning.pytorch.plugins.precision import XLAPrecision from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation @@ -63,47 +63,47 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) -class TPUPrecisionPlugin(XLAPrecisionPlugin): +class TPUPrecisionPlugin(XLAPrecision): """Legacy class. - Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead. + Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecision` instead. """ def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecisionPlugin`" + "The `TPUPrecisionPlugin` class is deprecated. Use `lightning.pytorch.plugins.precision.XLAPrecision`" " instead." ) super().__init__(precision="32-true") -class TPUBf16PrecisionPlugin(XLAPrecisionPlugin): +class TPUBf16PrecisionPlugin(XLAPrecision): """Legacy class. - Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead. + Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecision` instead. """ def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( "The `TPUBf16PrecisionPlugin` class is deprecated. Use" - " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead." + " `lightning.pytorch.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="bf16-true") -class XLABf16PrecisionPlugin(XLAPrecisionPlugin): +class XLABf16PrecisionPlugin(XLAPrecision): """Legacy class. - Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecisionPlugin` instead. + Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLAPrecision` instead. """ def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( "The `XLABf16PrecisionPlugin` class is deprecated. Use" - " `lightning.pytorch.plugins.precision.XLAPrecisionPlugin` instead." + " `lightning.pytorch.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="bf16-true") diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index ec32851a18e3a..79d3e67913d7b 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -16,21 +16,21 @@ import torch -from lightning.fabric.plugins import Precision +from lightning.fabric.plugins import Precision as FabricPrecision from lightning.fabric.utilities.throughput import Throughput, get_available_flops from lightning.fabric.utilities.throughput import _plugin_to_compute_dtype as fabric_plugin_to_compute_dtype from lightning.pytorch.callbacks import Callback from lightning.pytorch.plugins import ( - DoublePrecisionPlugin, - FSDPPrecisionPlugin, - MixedPrecisionPlugin, - PrecisionPlugin, - TransformerEnginePrecisionPlugin, + BitsandbytesPrecision, + DeepSpeedPrecision, + DoublePrecision, + FSDPPrecision, + HalfPrecision, + MixedPrecision, + Precision, + TransformerEnginePrecision, + XLAPrecision, ) -from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin -from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin -from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin -from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn @@ -227,24 +227,24 @@ def on_predict_batch_end( self._compute(trainer, iter_num) -def _plugin_to_compute_dtype(plugin: Union[Precision, PrecisionPlugin]) -> torch.dtype: +def _plugin_to_compute_dtype(plugin: Union[FabricPrecision, Precision]) -> torch.dtype: # TODO: integrate this into the precision plugins - if not isinstance(plugin, PrecisionPlugin): + if not isinstance(plugin, Precision): return fabric_plugin_to_compute_dtype(plugin) - if isinstance(plugin, BitsandbytesPrecisionPlugin): + if isinstance(plugin, BitsandbytesPrecision): return plugin.dtype - if isinstance(plugin, HalfPrecisionPlugin): + if isinstance(plugin, HalfPrecision): return plugin._desired_input_dtype - if isinstance(plugin, MixedPrecisionPlugin): + if isinstance(plugin, MixedPrecision): return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half - if isinstance(plugin, DoublePrecisionPlugin): + if isinstance(plugin, DoublePrecision): return torch.double - if isinstance(plugin, (XLAPrecisionPlugin, DeepSpeedPrecisionPlugin)): + if isinstance(plugin, (XLAPrecision, DeepSpeedPrecision)): return plugin._desired_dtype - if isinstance(plugin, TransformerEnginePrecisionPlugin): + if isinstance(plugin, TransformerEnginePrecision): return torch.int8 - if isinstance(plugin, FSDPPrecisionPlugin): + if isinstance(plugin, FSDPPrecision): return plugin.mixed_precision_config.reduce_dtype or torch.float32 - if isinstance(plugin, PrecisionPlugin): + if isinstance(plugin, Precision): return torch.float32 raise NotImplementedError(plugin) diff --git a/src/lightning/pytorch/plugins/__init__.py b/src/lightning/pytorch/plugins/__init__.py index 6e1d527907dcf..48fef3c136e62 100644 --- a/src/lightning/pytorch/plugins/__init__.py +++ b/src/lightning/pytorch/plugins/__init__.py @@ -3,17 +3,17 @@ from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm -from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin -from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin -from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin -from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin -from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin -from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin -from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin -from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin +from lightning.pytorch.plugins.precision.amp import MixedPrecision +from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision +from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision +from lightning.pytorch.plugins.precision.double import DoublePrecision +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.plugins.precision.half import HalfPrecision +from lightning.pytorch.plugins.precision.precision import Precision +from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision +from lightning.pytorch.plugins.precision.xla import XLAPrecision -PLUGIN = Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync] +PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync] PLUGIN_INPUT = Union[PLUGIN, str] __all__ = [ @@ -21,16 +21,15 @@ "CheckpointIO", "TorchCheckpointIO", "XLACheckpointIO", - "BitsandbytesPrecisionPlugin", - "DeepSpeedPrecisionPlugin", - "DoublePrecisionPlugin", - "HalfPrecisionPlugin", - "MixedPrecisionPlugin", - "PrecisionPlugin", - "TransformerEnginePrecisionPlugin", - "FSDPMixedPrecisionPlugin", - "FSDPPrecisionPlugin", - "XLAPrecisionPlugin", + "BitsandbytesPrecision", + "DeepSpeedPrecision", + "DoublePrecision", + "HalfPrecision", + "MixedPrecision", + "Precision", + "TransformerEnginePrecision", + "FSDPPrecision", + "XLAPrecision", "LayerSync", "TorchSyncBatchNorm", ] diff --git a/src/lightning/pytorch/plugins/precision/__init__.py b/src/lightning/pytorch/plugins/precision/__init__.py index b4d5e8cf5eb26..2f6a257333505 100644 --- a/src/lightning/pytorch/plugins/precision/__init__.py +++ b/src/lightning/pytorch/plugins/precision/__init__.py @@ -11,25 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin -from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin -from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin -from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin -from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin -from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin -from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin -from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin +from lightning.pytorch.plugins.precision.amp import MixedPrecision +from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision +from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision +from lightning.pytorch.plugins.precision.double import DoublePrecision +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.plugins.precision.half import HalfPrecision +from lightning.pytorch.plugins.precision.precision import Precision +from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision +from lightning.pytorch.plugins.precision.xla import XLAPrecision __all__ = [ - "BitsandbytesPrecisionPlugin", - "DeepSpeedPrecisionPlugin", - "DoublePrecisionPlugin", - "FSDPMixedPrecisionPlugin", - "FSDPPrecisionPlugin", - "HalfPrecisionPlugin", - "MixedPrecisionPlugin", - "PrecisionPlugin", - "TransformerEnginePrecisionPlugin", - "XLAPrecisionPlugin", + "BitsandbytesPrecision", + "DeepSpeedPrecision", + "DoublePrecision", + "FSDPPrecision", + "HalfPrecision", + "MixedPrecision", + "Precision", + "TransformerEnginePrecision", + "XLAPrecision", ] diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 4eb79b87c031f..9c24196b99e77 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -20,12 +20,12 @@ from lightning.fabric.accelerators.cuda import _patch_cuda_is_available from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling from lightning.fabric.utilities.types import Optimizable -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities import GradClipAlgorithmType from lightning.pytorch.utilities.exceptions import MisconfigurationException -class MixedPrecisionPlugin(PrecisionPlugin): +class MixedPrecision(Precision): """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``. Args: diff --git a/src/lightning/pytorch/plugins/precision/bitsandbytes.py b/src/lightning/pytorch/plugins/precision/bitsandbytes.py index aaf51079f63a9..62acc7bf77c8d 100644 --- a/src/lightning/pytorch/plugins/precision/bitsandbytes.py +++ b/src/lightning/pytorch/plugins/precision/bitsandbytes.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from lightning.fabric.plugins.precision.bitsandbytes import BitsandbytesPrecision as FabricBNBPrecision -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision -class BitsandbytesPrecisionPlugin(PrecisionPlugin, FabricBNBPrecision): +class BitsandbytesPrecision(Precision, FabricBNBPrecision): """Plugin for quantizing weights with `bitsandbytes `__. .. warning:: This is an :ref:`experimental ` feature. diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 7fa409f1d6b5f..26bf1d2734833 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -25,7 +25,7 @@ from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.fabric.utilities.types import Steppable -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities import GradClipAlgorithmType from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden @@ -37,7 +37,7 @@ warning_cache = WarningCache() -class DeepSpeedPrecisionPlugin(PrecisionPlugin): +class DeepSpeedPrecision(Precision): """Precision plugin for DeepSpeed integration. .. warning:: This is an :ref:`experimental ` feature. diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 72b5d0d6da2bf..b65e87290b835 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -22,11 +22,11 @@ import lightning.pytorch as pl from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation -class DoublePrecisionPlugin(PrecisionPlugin): +class DoublePrecision(Precision): """Plugin for training with double (``torch.float64``) precision.""" precision: Literal["64-true"] = "64-true" diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 5a124ab6b676d..46d11d2449acf 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional import torch from lightning_utilities import apply_to_collection @@ -23,9 +23,8 @@ from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 -from lightning.fabric.utilities.rank_zero import rank_zero_deprecation from lightning.fabric.utilities.types import Optimizable -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities.exceptions import MisconfigurationException if TYPE_CHECKING: @@ -33,7 +32,7 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -class FSDPPrecisionPlugin(PrecisionPlugin): +class FSDPPrecision(Precision): """Precision plugin for training with Fully Sharded Data Parallel (FSDP). .. warning:: This is an :ref:`experimental ` feature. @@ -170,22 +169,3 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) - - -class FSDPMixedPrecisionPlugin(FSDPPrecisionPlugin): - """AMP for Fully Sharded Data Parallel (FSDP) Training. - - .. deprecated:: Use :class:`FSDPPrecisionPlugin` instead. - - .. warning:: This is an :ref:`experimental ` feature. - - """ - - def __init__( - self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None - ) -> None: - rank_zero_deprecation( - f"The `{type(self).__name__}` is deprecated." - " Use `lightning.pytorch.plugins.precision.FSDPPrecisionPlugin` instead." - ) - super().__init__(precision=precision, scaler=scaler) diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index a7ef8c82afe86..80d633ba740f1 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -20,10 +20,10 @@ from torch.nn import Module from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision -class HalfPrecisionPlugin(PrecisionPlugin): +class HalfPrecision(Precision): """Plugin for training with half precision. Args: diff --git a/src/lightning/pytorch/plugins/precision/precision_plugin.py b/src/lightning/pytorch/plugins/precision/precision.py similarity index 97% rename from src/lightning/pytorch/plugins/precision/precision_plugin.py rename to src/lightning/pytorch/plugins/precision/precision.py index 15b28186b896a..b10afeb038c71 100644 --- a/src/lightning/pytorch/plugins/precision/precision_plugin.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -28,7 +28,7 @@ from lightning.pytorch.utilities import GradClipAlgorithmType -class PrecisionPlugin(FabricPrecision, CheckpointHooks): +class Precision(FabricPrecision, CheckpointHooks): """Base class for all plugins handling the precision-specific parts of the training. The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. @@ -98,7 +98,7 @@ def _wrap_closure( hook is called. The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is - consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly. + consistent with the ``Precision`` subclasses that cannot pass ``optimizer.step(closure)`` directly. """ closure_result = closure() diff --git a/src/lightning/pytorch/plugins/precision/transformer_engine.py b/src/lightning/pytorch/plugins/precision/transformer_engine.py index 858ef34ecf9a5..4ecbc64d7be0c 100644 --- a/src/lightning/pytorch/plugins/precision/transformer_engine.py +++ b/src/lightning/pytorch/plugins/precision/transformer_engine.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision as FabricTEPrecision -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision -class TransformerEnginePrecisionPlugin(PrecisionPlugin, FabricTEPrecision): +class TransformerEnginePrecision(Precision, FabricTEPrecision): """Plugin for training with fp8 precision via nvidia's `Transformer Engine `__. diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py index 00c3db5f9022d..bcdb427b14c90 100644 --- a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -22,11 +22,11 @@ from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.fabric.plugins.precision.xla import _PRECISION_INPUT from lightning.fabric.utilities.types import Optimizable -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities.exceptions import MisconfigurationException -class XLAPrecisionPlugin(PrecisionPlugin): +class XLAPrecision(Precision): """Plugin for training with XLA. Args: diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9bd24bc657b13..d3500a385b743 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -40,7 +40,7 @@ from lightning.fabric.utilities.types import ReduceOp from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.overrides.distributed import _register_ddp_comm_hook, _sync_module_states, prepare_for_backward -from lightning.pytorch.plugins.precision import PrecisionPlugin +from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import TBroadcast, _ForwardRedirection @@ -72,7 +72,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + precision_plugin: Optional[Precision] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[Callable] = None, ddp_comm_wrapper: Optional[Callable] = None, diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index fab3a3615bc48..380f8c74e449f 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -39,7 +39,7 @@ from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau from lightning.pytorch.accelerators.cuda import CUDAAccelerator from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers -from lightning.pytorch.plugins.precision import PrecisionPlugin +from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import GradClipAlgorithmType @@ -114,7 +114,7 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, - precision_plugin: Optional[PrecisionPlugin] = None, + precision_plugin: Optional[Precision] = None, process_group_backend: Optional[str] = None, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index ca85d784bad52..907497d218268 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -60,8 +60,8 @@ from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH, ReduceOp from lightning.pytorch.core.optimizer import LightningOptimizer -from lightning.pytorch.plugins.precision import PrecisionPlugin -from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin +from lightning.pytorch.plugins.precision import Precision +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import TBroadcast @@ -144,7 +144,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + precision_plugin: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, cpu_offload: Union[bool, "CPUOffload", None] = None, @@ -205,23 +205,23 @@ def mixed_precision_config(self) -> Optional["MixedPrecision"]: if self.mixed_precision: return self.mixed_precision plugin = self.precision_plugin - if isinstance(plugin, FSDPPrecisionPlugin): + if isinstance(plugin, FSDPPrecision): return plugin.mixed_precision_config return None @property # type: ignore[override] - def precision_plugin(self) -> FSDPPrecisionPlugin: + def precision_plugin(self) -> FSDPPrecision: plugin = self._precision_plugin if plugin is not None: - assert isinstance(plugin, FSDPPrecisionPlugin) + assert isinstance(plugin, FSDPPrecision) return plugin - return FSDPPrecisionPlugin("32-true") + return FSDPPrecision("32-true") @precision_plugin.setter - def precision_plugin(self, precision_plugin: Optional[FSDPPrecisionPlugin]) -> None: - if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecisionPlugin): + def precision_plugin(self, precision_plugin: Optional[FSDPPrecision]) -> None: + if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision): raise TypeError( - f"The FSDP strategy can only work with the `FSDPPrecisionPlugin` plugin, found {precision_plugin}" + f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}" ) self._precision_plugin = precision_plugin diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index 794ed836f0cd2..6a70d58f4f1d8 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -22,7 +22,7 @@ from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment from lightning.fabric.utilities.distributed import ReduceOp, _all_gather_ddp_if_available from lightning.pytorch.plugins import LayerSync -from lightning.pytorch.plugins.precision import PrecisionPlugin +from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.strategy import Strategy @@ -35,7 +35,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + precision_plugin: Optional[Precision] = None, ): super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices diff --git a/src/lightning/pytorch/strategies/single_device.py b/src/lightning/pytorch/strategies/single_device.py index 6736201dc8ace..e54bd12227104 100644 --- a/src/lightning/pytorch/strategies/single_device.py +++ b/src/lightning/pytorch/strategies/single_device.py @@ -22,7 +22,7 @@ from lightning.fabric.plugins import CheckpointIO from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.utilities.types import _DEVICE -from lightning.pytorch.plugins.precision import PrecisionPlugin +from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.strategy import Strategy, TBroadcast @@ -36,7 +36,7 @@ def __init__( device: _DEVICE = "cpu", accelerator: pl.accelerators.accelerator.Accelerator | None = None, checkpoint_io: CheckpointIO | None = None, - precision_plugin: PrecisionPlugin | None = None, + precision_plugin: Precision | None = None, ): super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) if not isinstance(device, torch.device): diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py index aa6c2a92adf83..40d44cd656cdc 100644 --- a/src/lightning/pytorch/strategies/single_xla.py +++ b/src/lightning/pytorch/strategies/single_xla.py @@ -22,7 +22,7 @@ from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.utilities.types import _DEVICE from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO -from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin +from lightning.pytorch.plugins.precision.xla import XLAPrecision from lightning.pytorch.strategies.single_device import SingleDeviceStrategy from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters @@ -35,7 +35,7 @@ def __init__( device: _DEVICE, accelerator: Optional["pl.accelerators.Accelerator"] = None, checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, - precision_plugin: Optional[XLAPrecisionPlugin] = None, + precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, ): if not _XLA_AVAILABLE: @@ -68,19 +68,17 @@ def checkpoint_io(self, io: Optional[Union[XLACheckpointIO, _WrappingCheckpointI self._checkpoint_io = io @property # type: ignore[override] - def precision_plugin(self) -> XLAPrecisionPlugin: + def precision_plugin(self) -> XLAPrecision: plugin = self._precision_plugin if plugin is not None: - assert isinstance(plugin, XLAPrecisionPlugin) + assert isinstance(plugin, XLAPrecision) return plugin - return XLAPrecisionPlugin() + return XLAPrecision() @precision_plugin.setter - def precision_plugin(self, precision_plugin: Optional[XLAPrecisionPlugin]) -> None: - if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecisionPlugin): - raise TypeError( - f"The XLA strategy can only work with the `XLAPrecisionPlugin` plugin, found {precision_plugin}" - ) + def precision_plugin(self, precision_plugin: Optional[XLAPrecision]) -> None: + if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision): + raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}") self._precision_plugin = precision_plugin def setup(self, trainer: "pl.Trainer") -> None: diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 3b6467d5bd886..5ea8d19310f0c 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -33,7 +33,7 @@ from lightning.pytorch.core.optimizer import LightningOptimizer, _init_optimizers_and_lr_schedulers from lightning.pytorch.plugins import TorchCheckpointIO from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO -from lightning.pytorch.plugins.precision import PrecisionPlugin +from lightning.pytorch.plugins.precision import Precision from lightning.pytorch.strategies.launchers.launcher import _Launcher from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.types import STEP_OUTPUT, LRSchedulerConfig @@ -51,11 +51,11 @@ def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + precision_plugin: Optional[Precision] = None, ) -> None: self._accelerator: Optional["pl.accelerators.Accelerator"] = accelerator self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io - self._precision_plugin: Optional[PrecisionPlugin] = None + self._precision_plugin: Optional[Precision] = None # Call the precision setter for input validation self.precision_plugin = precision_plugin # type: ignore[assignment] self._lightning_module: Optional[pl.LightningModule] = None @@ -92,11 +92,11 @@ def checkpoint_io(self, io: CheckpointIO) -> None: self._checkpoint_io = io @property - def precision_plugin(self) -> PrecisionPlugin: - return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin() + def precision_plugin(self) -> Precision: + return self._precision_plugin if self._precision_plugin is not None else Precision() @precision_plugin.setter - def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None: + def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: self._precision_plugin = precision_plugin @property diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index 8ff05f5e42d4a..00069fe6de3a8 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -26,7 +26,7 @@ from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.types import _PATH, ReduceOp -from lightning.pytorch.plugins import XLAPrecisionPlugin +from lightning.pytorch.plugins import XLAPrecision from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher @@ -50,7 +50,7 @@ def __init__( accelerator: Optional["pl.accelerators.Accelerator"] = None, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, - precision_plugin: Optional[XLAPrecisionPlugin] = None, + precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, sync_module_states: bool = True, **_: Any, @@ -84,19 +84,17 @@ def checkpoint_io(self, io: Optional[Union[XLACheckpointIO, _WrappingCheckpointI self._checkpoint_io = io @property # type: ignore[override] - def precision_plugin(self) -> XLAPrecisionPlugin: + def precision_plugin(self) -> XLAPrecision: plugin = self._precision_plugin if plugin is not None: - assert isinstance(plugin, XLAPrecisionPlugin) + assert isinstance(plugin, XLAPrecision) return plugin - return XLAPrecisionPlugin() + return XLAPrecision() @precision_plugin.setter - def precision_plugin(self, precision_plugin: Optional[XLAPrecisionPlugin]) -> None: - if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecisionPlugin): - raise TypeError( - f"The XLA strategy can only work with the `XLAPrecisionPlugin` plugin, found {precision_plugin}" - ) + def precision_plugin(self, precision_plugin: Optional[XLAPrecision]) -> None: + if precision_plugin is not None and not isinstance(precision_plugin, XLAPrecision): + raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision_plugin}") self._precision_plugin = precision_plugin @property diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index ec91b1a340ed4..f4e2a1ced7e93 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -38,14 +38,14 @@ from lightning.pytorch.plugins import ( PLUGIN_INPUT, CheckpointIO, - DeepSpeedPrecisionPlugin, - DoublePrecisionPlugin, - FSDPPrecisionPlugin, - HalfPrecisionPlugin, - MixedPrecisionPlugin, - PrecisionPlugin, - TransformerEnginePrecisionPlugin, - XLAPrecisionPlugin, + DeepSpeedPrecision, + DoublePrecision, + FSDPPrecision, + HalfPrecision, + MixedPrecision, + Precision, + TransformerEnginePrecision, + XLAPrecision, ) from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm from lightning.pytorch.strategies import ( @@ -130,7 +130,7 @@ def __init__( self._strategy_flag: Union[Strategy, str] = "auto" self._accelerator_flag: Union[Accelerator, str] = "auto" self._precision_flag: _PRECISION_INPUT_STR = "32-true" - self._precision_plugin_flag: Optional[PrecisionPlugin] = None + self._precision_plugin_flag: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None self._parallel_devices: List[Union[int, torch.device, str]] = [] self._layer_sync: Optional[LayerSync] = TorchSyncBatchNorm() if sync_batchnorm else None @@ -242,9 +242,9 @@ def _check_config_and_set_final_flags( if plugins: plugins_flags_types: Dict[str, int] = Counter() for plugin in plugins: - if isinstance(plugin, PrecisionPlugin): + if isinstance(plugin, Precision): self._precision_plugin_flag = plugin - plugins_flags_types[PrecisionPlugin.__name__] += 1 + plugins_flags_types[Precision.__name__] += 1 elif isinstance(plugin, CheckpointIO): self.checkpoint_io = plugin plugins_flags_types[CheckpointIO.__name__] += 1 @@ -261,7 +261,7 @@ def _check_config_and_set_final_flags( plugins_flags_types[TorchSyncBatchNorm.__name__] += 1 else: raise MisconfigurationException( - f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, " + f"Found invalid type for plugin {plugin}. Expected one of: Precision, " "CheckpointIO, ClusterEnviroment, or LayerSync." ) @@ -272,7 +272,7 @@ def _check_config_and_set_final_flags( " Expected one value for each type at most." ) - if plugins_flags_types.get(PrecisionPlugin.__name__) and precision_flag is not None: + if plugins_flags_types.get(Precision.__name__) and precision_flag is not None: raise ValueError( f"Received both `precision={precision_flag}` and `plugins={self._precision_plugin_flag}`." f" Choose one." @@ -509,9 +509,9 @@ def _init_strategy(self) -> None: else: self.strategy = self._strategy_flag - def _check_and_init_precision(self) -> PrecisionPlugin: + def _check_and_init_precision(self) -> Precision: self._validate_precision_choice() - if isinstance(self._precision_plugin_flag, PrecisionPlugin): + if isinstance(self._precision_plugin_flag, Precision): return self._precision_plugin_flag if _graphcore_available_and_importable(): @@ -537,21 +537,21 @@ def _check_and_init_precision(self) -> PrecisionPlugin: return ColossalAIPrecisionPlugin(self._precision_flag) if isinstance(self.strategy, (SingleDeviceXLAStrategy, XLAStrategy)): - return XLAPrecisionPlugin(self._precision_flag) # type: ignore + return XLAPrecision(self._precision_flag) # type: ignore if isinstance(self.strategy, DeepSpeedStrategy): - return DeepSpeedPrecisionPlugin(self._precision_flag) # type: ignore[arg-type] + return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type] if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecisionPlugin(self._precision_flag) # type: ignore[arg-type] + return FSDPPrecision(self._precision_flag) # type: ignore[arg-type] if self._precision_flag in ("16-true", "bf16-true"): - return HalfPrecisionPlugin(self._precision_flag) # type: ignore + return HalfPrecision(self._precision_flag) # type: ignore if self._precision_flag == "32-true": - return PrecisionPlugin() + return Precision() if self._precision_flag == "64-true": - return DoublePrecisionPlugin() + return DoublePrecision() if self._precision_flag == "transformer-engine": - return TransformerEnginePrecisionPlugin(dtype=torch.bfloat16) + return TransformerEnginePrecision(dtype=torch.bfloat16) if self._precision_flag == "transformer-engine-float16": - return TransformerEnginePrecisionPlugin(dtype=torch.float16) + return TransformerEnginePrecision(dtype=torch.float16) if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu": rank_zero_warn( @@ -565,7 +565,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin: f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" - return MixedPrecisionPlugin(self._precision_flag, device) # type: ignore[arg-type] + return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 19a73532dc210..a2f983d6f9cc7 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -26,7 +26,7 @@ from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem from lightning.fabric.utilities.types import _PATH from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.plugins.precision import MixedPrecision from lightning.pytorch.trainer import call from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -303,7 +303,7 @@ def restore_precision_plugin_state(self) -> None: prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__]) # old checkpoints compatibility - if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, MixedPrecisionPlugin): + if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, MixedPrecision): prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"]) def restore_callbacks(self) -> None: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index d0dcb8f437558..79070a3eb43bd 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -47,7 +47,7 @@ from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.loops.utilities import _parse_loop_limits, _reset_progress -from lightning.pytorch.plugins import PLUGIN_INPUT, PrecisionPlugin +from lightning.pytorch.plugins import PLUGIN_INPUT, Precision from lightning.pytorch.profilers import Profiler from lightning.pytorch.strategies import ParallelStrategy, Strategy from lightning.pytorch.trainer import call, setup @@ -1141,7 +1141,7 @@ def strategy(self) -> Strategy: return self._accelerator_connector.strategy @property - def precision_plugin(self) -> PrecisionPlugin: + def precision_plugin(self) -> Precision: return self.strategy.precision_plugin @property diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index b5ceeaa2771df..2f0edbcc1c3ad 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -231,6 +231,7 @@ def _test_distributed_collectives_fn(strategy, collective): @skip_distributed_unavailable +@pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( "n", [1, pytest.param(2, marks=[RunIf(skip_windows=True), pytest.mark.xfail(raises=TimeoutError, strict=False)])] ) diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index a75b20b06c7be..59e3e673dd7aa 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -10,7 +10,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin +from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.strategies import SingleDeviceStrategy from tests_pytorch.helpers.runif import RunIf @@ -19,7 +19,7 @@ def test_restore_checkpoint_after_pre_setup_default(): """Assert default for restore_checkpoint_after_setup is False.""" plugin = SingleDeviceStrategy( - accelerator=CPUAccelerator(), device=torch.device("cpu"), precision_plugin=PrecisionPlugin() + accelerator=CPUAccelerator(), device=torch.device("cpu"), precision_plugin=Precision() ) assert not plugin.restore_checkpoint_after_setup @@ -66,7 +66,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: plugin = TestPlugin( accelerator=CPUAccelerator(), - precision_plugin=PrecisionPlugin(), + precision_plugin=Precision(), device=torch.device("cpu"), checkpoint_io=TorchCheckpointIO(), ) diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py index 77db6c38cbc87..4241bf7836239 100644 --- a/tests/tests_pytorch/accelerators/test_xla.py +++ b/tests/tests_pytorch/accelerators/test_xla.py @@ -24,7 +24,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.plugins import PrecisionPlugin, XLACheckpointIO, XLAPrecisionPlugin +from lightning.pytorch.plugins import Precision, XLACheckpointIO, XLAPrecision from lightning.pytorch.strategies import DDPStrategy, XLAStrategy from lightning.pytorch.utilities import find_shared_parameters from torch import nn @@ -245,16 +245,16 @@ def test_auto_parameters_tying_tpus_nested_module(tmpdir): def test_tpu_invalid_raises(tpu_available, mps_count_0): - strategy = DDPStrategy(accelerator=XLAAccelerator(), precision_plugin=XLAPrecisionPlugin()) + strategy = DDPStrategy(accelerator=XLAAccelerator(), precision_plugin=XLAPrecision()) with pytest.raises(ValueError, match="XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy`"): Trainer(strategy=strategy, devices=8) accelerator = XLAAccelerator() - with pytest.raises(TypeError, match="can only work with the `XLAPrecisionPlugin` plugin"): - XLAStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) + with pytest.raises(TypeError, match="can only work with the `XLAPrecision` plugin"): + XLAStrategy(accelerator=accelerator, precision_plugin=Precision()) accelerator = XLAAccelerator() - strategy = DDPStrategy(accelerator=accelerator, precision_plugin=XLAPrecisionPlugin()) + strategy = DDPStrategy(accelerator=accelerator, precision_plugin=XLAPrecision()) with pytest.raises( ValueError, match="The `XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy` or `XLAStrategy" ): diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index d5e451b14f6a3..b442fe93853af 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -1,10 +1,12 @@ +import sys +from unittest.mock import Mock + import lightning.fabric import pytest import torch.nn from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.double import LightningDoublePrecisionModule -from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy from tests_pytorch.helpers.runif import RunIf @@ -53,5 +55,81 @@ def test_double_precision_wrapper(): def test_fsdp_mixed_precision_plugin(): + from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin + with pytest.deprecated_call(match=r"The `FSDPMixedPrecisionPlugin` is deprecated"): FSDPMixedPrecisionPlugin(precision="16-mixed", device="cuda") + + +def test_fsdp_precision_plugin(): + from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin + + with pytest.deprecated_call(match=r"The `FSDPPrecisionPlugin` is deprecated"): + FSDPPrecisionPlugin(precision="16-mixed") + + +def test_bitsandbytes_precision_plugin(monkeypatch): + monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True) + bitsandbytes_mock = Mock() + monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock) + + from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin + + with pytest.deprecated_call(match=r"The `BitsandbytesPrecisionPlugin` is deprecated"): + BitsandbytesPrecisionPlugin("nf4") + + +def test_deepspeed_precision_plugin(): + from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin + + with pytest.deprecated_call(match=r"The `DeepSpeedPrecisionPlugin` is deprecated"): + DeepSpeedPrecisionPlugin(precision="32-true") + + +def test_double_precision_plugin(): + from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin + + with pytest.deprecated_call(match=r"The `DoublePrecisionPlugin` is deprecated"): + DoublePrecisionPlugin() + + +def test_half_precision_plugin(): + from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin + + with pytest.deprecated_call(match=r"The `HalfPrecisionPlugin` is deprecated"): + HalfPrecisionPlugin() + + +def test_mixed_precision_plugin(): + from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin + + with pytest.deprecated_call(match=r"The `MixedPrecisionPlugin` is deprecated"): + MixedPrecisionPlugin(precision="16-mixed", device="cuda") + + +def test_precision_plugin(): + from lightning.pytorch.plugins.precision.precision import PrecisionPlugin + + with pytest.deprecated_call(match=r"The `PrecisionPlugin` is deprecated"): + PrecisionPlugin() + + +def test_transformer_engine_precision_plugin(monkeypatch): + monkeypatch.setattr(lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True) + transformer_engine_mock = Mock() + monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock) + monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock()) + recipe_mock = Mock() + monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", recipe_mock) + + from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin + + with pytest.deprecated_call(match=r"The `TransformerEnginePrecisionPlugin` is deprecated"): + TransformerEnginePrecisionPlugin() + + +def test_xla_precision_plugin(xla_available): + from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin + + with pytest.deprecated_call(match=r"The `XLAPrecisionPlugin` is deprecated"): + XLAPrecisionPlugin() diff --git a/tests/tests_pytorch/graveyard/__init__.py b/tests/tests_pytorch/graveyard/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_pytorch/graveyard/test_precision.py b/tests/tests_pytorch/graveyard/test_precision.py new file mode 100644 index 0000000000000..17bca3315ea64 --- /dev/null +++ b/tests/tests_pytorch/graveyard/test_precision.py @@ -0,0 +1,90 @@ +def test_precision_plugin_renamed_imports(): + # base class + from lightning.pytorch.plugins import PrecisionPlugin as PrecisionPlugin2 + from lightning.pytorch.plugins.precision import PrecisionPlugin as PrecisionPlugin1 + from lightning.pytorch.plugins.precision.precision import Precision + from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin as PrecisionPlugin0 + + assert issubclass(PrecisionPlugin0, Precision) + assert issubclass(PrecisionPlugin1, Precision) + assert issubclass(PrecisionPlugin2, Precision) + + # bitsandbytes + from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin as BnbPlugin2 + from lightning.pytorch.plugins.precision import BitsandbytesPrecisionPlugin as BnbPlugin1 + from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision + from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecisionPlugin as BnbPlugin0 + + assert issubclass(BnbPlugin0, BitsandbytesPrecision) + assert issubclass(BnbPlugin1, BitsandbytesPrecision) + assert issubclass(BnbPlugin2, BitsandbytesPrecision) + + # deepspeed + from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin as DeepSpeedPlugin2 + from lightning.pytorch.plugins.precision import DeepSpeedPrecisionPlugin as DeepSpeedPlugin1 + from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision + from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin as DeepSpeedPlugin0 + + assert issubclass(DeepSpeedPlugin0, DeepSpeedPrecision) + assert issubclass(DeepSpeedPlugin1, DeepSpeedPrecision) + assert issubclass(DeepSpeedPlugin2, DeepSpeedPrecision) + + # double + from lightning.pytorch.plugins import DoublePrecisionPlugin as DoublePlugin2 + from lightning.pytorch.plugins.precision import DoublePrecisionPlugin as DoublePlugin1 + from lightning.pytorch.plugins.precision.double import DoublePrecision + from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin as DoublePlugin0 + + assert issubclass(DoublePlugin0, DoublePrecision) + assert issubclass(DoublePlugin1, DoublePrecision) + assert issubclass(DoublePlugin2, DoublePrecision) + + # fsdp + from lightning.pytorch.plugins import FSDPPrecisionPlugin as FSDPPlugin2 + from lightning.pytorch.plugins.precision import FSDPPrecisionPlugin as FSDPPlugin1 + from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision + from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin as FSDPPlugin0 + + assert issubclass(FSDPPlugin0, FSDPPrecision) + assert issubclass(FSDPPlugin1, FSDPPrecision) + assert issubclass(FSDPPlugin2, FSDPPrecision) + + # half + from lightning.pytorch.plugins import HalfPrecisionPlugin as HalfPlugin2 + from lightning.pytorch.plugins.precision import HalfPrecisionPlugin as HalfPlugin1 + from lightning.pytorch.plugins.precision.half import HalfPrecision + from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin as HalfPlugin0 + + assert issubclass(HalfPlugin0, HalfPrecision) + assert issubclass(HalfPlugin1, HalfPrecision) + assert issubclass(HalfPlugin2, HalfPrecision) + + # mixed + from lightning.pytorch.plugins import MixedPrecisionPlugin as MixedPlugin2 + from lightning.pytorch.plugins.precision import MixedPrecisionPlugin as MixedPlugin1 + from lightning.pytorch.plugins.precision.amp import MixedPrecision + from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin as MixedPlugin0 + + assert issubclass(MixedPlugin0, MixedPrecision) + assert issubclass(MixedPlugin1, MixedPrecision) + assert issubclass(MixedPlugin2, MixedPrecision) + + # transformer_engine + from lightning.pytorch.plugins import TransformerEnginePrecisionPlugin as TEPlugin2 + from lightning.pytorch.plugins.precision import TransformerEnginePrecisionPlugin as TEPlugin1 + from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision + from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin as TEPlugin0 + + assert issubclass(TEPlugin0, TransformerEnginePrecision) + assert issubclass(TEPlugin1, TransformerEnginePrecision) + assert issubclass(TEPlugin2, TransformerEnginePrecision) + + # xla + from lightning.pytorch.plugins import XLAPrecisionPlugin as XLAPlugin2 + from lightning.pytorch.plugins.precision import XLAPrecisionPlugin as XLAPlugin1 + from lightning.pytorch.plugins.precision.xla import XLAPrecision + from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin as XLAPlugin0 + + assert issubclass(XLAPlugin0, XLAPrecision) + assert issubclass(XLAPlugin1, XLAPrecision) + assert issubclass(XLAPlugin2, XLAPrecision) diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index 032e72eaf0eb4..54d394948eeee 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -14,7 +14,7 @@ import multiprocessing import torch -from lightning.pytorch.plugins import MixedPrecisionPlugin +from lightning.pytorch.plugins import MixedPrecision from tests_pytorch.helpers.runif import RunIf @@ -24,7 +24,7 @@ def test_amp_gpus_ddp_fork(): """Ensure the use of AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA initialization errors.""" - _ = MixedPrecisionPlugin(precision="16-mixed", device="cuda") + _ = MixedPrecision(precision="16-mixed", device="cuda") with multiprocessing.get_context("fork").Pool(1) as pool: in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork) assert not in_bad_fork diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 1c546c69394ab..8b5836cd4d574 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -294,7 +294,7 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_ "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None}, }, # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates - # the actual call to `PrecisionPlugin.optimizer_step` + # the actual call to `Precision.optimizer_step` { "name": "optimizer_step", "args": (current_epoch, i, ANY, ANY), diff --git a/tests/tests_pytorch/plugins/precision/test_all.py b/tests/tests_pytorch/plugins/precision/test_all.py index 8b58ae9b0eebd..2668311c8b452 100644 --- a/tests/tests_pytorch/plugins/precision/test_all.py +++ b/tests/tests_pytorch/plugins/precision/test_all.py @@ -1,29 +1,29 @@ import pytest import torch from lightning.pytorch.plugins import ( - DeepSpeedPrecisionPlugin, - DoublePrecisionPlugin, - FSDPPrecisionPlugin, - HalfPrecisionPlugin, + DeepSpeedPrecision, + DoublePrecision, + FSDPPrecision, + HalfPrecision, ) @pytest.mark.parametrize( "precision", [ - DeepSpeedPrecisionPlugin("16-true"), - DoublePrecisionPlugin(), - HalfPrecisionPlugin(), + DeepSpeedPrecision("16-true"), + DoublePrecision(), + HalfPrecision(), "fsdp", ], ) def test_default_dtype_is_restored(precision): if precision == "fsdp": - precision = FSDPPrecisionPlugin("16-true") + precision = FSDPPrecision("16-true") contexts = ( (precision.module_init_context, precision.forward_context) - if not isinstance(precision, DeepSpeedPrecisionPlugin) + if not isinstance(precision, DeepSpeedPrecision) else (precision.module_init_context,) ) for context in contexts: diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index e83588f31398c..90ecc703c8945 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,7 +14,7 @@ from unittest.mock import Mock import pytest -from lightning.pytorch.plugins import MixedPrecisionPlugin +from lightning.pytorch.plugins import MixedPrecision from lightning.pytorch.utilities import GradClipAlgorithmType from torch.optim import Optimizer @@ -22,7 +22,7 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" optimizer = Mock(spec=Optimizer) - precision = MixedPrecisionPlugin(precision="16-mixed", device="cuda:0", scaler=Mock()) + precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() precision.clip_grad_by_norm = Mock() precision.clip_gradients(optimizer) @@ -46,7 +46,7 @@ def test_optimizer_amp_scaling_support_in_step_method(): gradient clipping (example: fused Adam).""" optimizer = Mock(_step_supports_amp_scaling=True) - precision = MixedPrecisionPlugin(precision="16-mixed", device="cuda:0", scaler=Mock()) + precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"): precision.clip_gradients(optimizer, clip_val=1.0) diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py index ccee6c0d73d69..759ea50eb9c37 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py +++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py @@ -17,7 +17,7 @@ from lightning.fabric import seed_everything from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.plugins.precision import MixedPrecision from tests_pytorch.helpers.runif import RunIf @@ -77,7 +77,7 @@ def training_step(self, batch, batch_idx): max_steps=5, gradient_clip_val=0.5, ) - assert isinstance(trainer.precision_plugin, MixedPrecisionPlugin) + assert isinstance(trainer.precision_plugin, MixedPrecision) assert trainer.precision_plugin.scaler is not None trainer.precision_plugin.scaler = Mock(wraps=trainer.precision_plugin.scaler) model = TestModel() diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 8c4d9a6b198e9..3e1aaa17763e9 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -14,12 +14,12 @@ import pytest import torch -from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin +from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision def test_invalid_precision_with_deepspeed_precision(): with pytest.raises(ValueError, match="is not supported. `precision` must be one of"): - DeepSpeedPrecisionPlugin(precision="64-true") + DeepSpeedPrecision(precision="64-true") @pytest.mark.parametrize( @@ -33,7 +33,7 @@ def test_invalid_precision_with_deepspeed_precision(): ], ) def test_selected_dtype(precision, expected_dtype): - plugin = DeepSpeedPrecisionPlugin(precision=precision) + plugin = DeepSpeedPrecision(precision=precision) assert plugin.precision == precision assert plugin._desired_dtype == expected_dtype @@ -49,7 +49,7 @@ def test_selected_dtype(precision, expected_dtype): ], ) def test_module_init_context(precision, expected_dtype): - plugin = DeepSpeedPrecisionPlugin(precision=precision) + plugin = DeepSpeedPrecision(precision=precision) with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == expected_dtype @@ -67,7 +67,7 @@ def test_module_init_context(precision, expected_dtype): ], ) def test_convert_module(precision, expected_dtype): - precision = DeepSpeedPrecisionPlugin(precision=precision) + precision = DeepSpeedPrecision(precision=precision) module = torch.nn.Linear(2, 2) assert module.weight.dtype == module.bias.dtype == torch.float32 module = precision.convert_module(module) diff --git a/tests/tests_pytorch/plugins/precision/test_double.py b/tests/tests_pytorch/plugins/precision/test_double.py index f295bd5601f45..1f9dbac9e79d6 100644 --- a/tests/tests_pytorch/plugins/precision/test_double.py +++ b/tests/tests_pytorch/plugins/precision/test_double.py @@ -18,7 +18,7 @@ import torch from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin +from lightning.pytorch.plugins.precision.double import DoublePrecision from torch.utils.data import DataLoader, Dataset from tests_pytorch.helpers.runif import RunIf @@ -171,13 +171,13 @@ def test_double_precision_ddp(tmpdir): def test_double_precision_pickle(): model = BoringModel() - plugin = DoublePrecisionPlugin() + plugin = DoublePrecision() model, _, __ = plugin.connect(model, MagicMock(), MagicMock()) pickle.dumps(model) def test_convert_module(): - plugin = DoublePrecisionPlugin() + plugin = DoublePrecision() model = BoringModel() assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float32 model = plugin.convert_module(model) @@ -185,7 +185,7 @@ def test_convert_module(): def test_module_init_context(): - plugin = DoublePrecisionPlugin() + plugin = DoublePrecision() with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == torch.double diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py index 1e81531cd1487..e4d652cb15864 100644 --- a/tests/tests_pytorch/plugins/precision/test_fsdp.py +++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py @@ -16,7 +16,7 @@ import pytest import torch from lightning.fabric.plugins.precision.utils import _DtypeContextManager -from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision from tests_pytorch.helpers.runif import RunIf @@ -48,7 +48,7 @@ ], ) def test_fsdp_precision_config(precision, expected): - plugin = FSDPPrecisionPlugin(precision=precision) + plugin = FSDPPrecision(precision=precision) config = plugin.mixed_precision_config assert config.param_dtype == expected[0] @@ -59,22 +59,22 @@ def test_fsdp_precision_config(precision, expected): def test_fsdp_precision_default_scaler(): from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler - precision = FSDPPrecisionPlugin(precision="16-mixed") + precision = FSDPPrecision(precision="16-mixed") assert isinstance(precision.scaler, ShardedGradScaler) def test_fsdp_precision_scaler_with_bf16(): with pytest.raises(ValueError, match="`precision='bf16-mixed'` does not use a scaler"): - FSDPPrecisionPlugin(precision="bf16-mixed", scaler=Mock()) + FSDPPrecision(precision="bf16-mixed", scaler=Mock()) - precision = FSDPPrecisionPlugin(precision="bf16-mixed") + precision = FSDPPrecision(precision="bf16-mixed") assert precision.scaler is None @RunIf(min_cuda_gpus=1) def test_fsdp_precision_forward_context(): """Test to ensure that the context manager correctly is set to bfloat16.""" - precision = FSDPPrecisionPlugin(precision="16-mixed") + precision = FSDPPrecision(precision="16-mixed") assert isinstance(precision.scaler, torch.cuda.amp.GradScaler) assert torch.get_default_dtype() == torch.float32 with precision.forward_context(): @@ -82,7 +82,7 @@ def test_fsdp_precision_forward_context(): assert isinstance(precision.forward_context(), torch.autocast) assert precision.forward_context().fast_dtype == torch.float16 - precision = FSDPPrecisionPlugin(precision="16-true") + precision = FSDPPrecision(precision="16-true") assert precision.scaler is None assert torch.get_default_dtype() == torch.float32 with precision.forward_context(): @@ -90,14 +90,14 @@ def test_fsdp_precision_forward_context(): assert isinstance(precision.forward_context(), _DtypeContextManager) assert precision.forward_context()._new_dtype == torch.float16 - precision = FSDPPrecisionPlugin(precision="bf16-mixed") + precision = FSDPPrecision(precision="bf16-mixed") assert precision.scaler is None with precision.forward_context(): assert torch.get_autocast_gpu_dtype() == torch.bfloat16 assert isinstance(precision.forward_context(), torch.autocast) assert precision.forward_context().fast_dtype == torch.bfloat16 - precision = FSDPPrecisionPlugin(precision="bf16-true") + precision = FSDPPrecision(precision="bf16-true") assert precision.scaler is None with precision.forward_context(): # forward context is not using autocast ctx manager assert torch.get_default_dtype() == torch.bfloat16 @@ -106,7 +106,7 @@ def test_fsdp_precision_forward_context(): def test_fsdp_precision_backward(): - precision = FSDPPrecisionPlugin(precision="16-mixed") + precision = FSDPPrecision(precision="16-mixed") precision.scaler = Mock() precision.scaler.scale = Mock(side_effect=(lambda x: x)) tensor = Mock() @@ -118,7 +118,7 @@ def test_fsdp_precision_backward(): def test_fsdp_precision_optimizer_step_with_scaler(): - precision = FSDPPrecisionPlugin(precision="16-mixed") + precision = FSDPPrecision(precision="16-mixed") precision.scaler = Mock() model = Mock(trainer=Mock(callbacks=[], profiler=MagicMock())) optimizer = Mock() @@ -130,7 +130,7 @@ def test_fsdp_precision_optimizer_step_with_scaler(): def test_fsdp_precision_optimizer_step_without_scaler(): - precision = FSDPPrecisionPlugin(precision="bf16-mixed") + precision = FSDPPrecision(precision="bf16-mixed") assert precision.scaler is None model = Mock(trainer=Mock(callbacks=[], profiler=MagicMock())) optimizer = Mock() @@ -141,8 +141,8 @@ def test_fsdp_precision_optimizer_step_without_scaler(): def test_invalid_precision_with_fsdp_precision(): - FSDPPrecisionPlugin("16-mixed") - FSDPPrecisionPlugin("bf16-mixed") + FSDPPrecision("16-mixed") + FSDPPrecision("bf16-mixed") with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"): - FSDPPrecisionPlugin(precision="64-true") + FSDPPrecision(precision="64-true") diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py index 89a7cddf13b2c..5712b5ba46f13 100644 --- a/tests/tests_pytorch/plugins/precision/test_half.py +++ b/tests/tests_pytorch/plugins/precision/test_half.py @@ -14,7 +14,7 @@ import pytest import torch -from lightning.pytorch.plugins import HalfPrecisionPlugin +from lightning.pytorch.plugins import HalfPrecision @pytest.mark.parametrize( @@ -25,7 +25,7 @@ ], ) def test_selected_dtype(precision, expected_dtype): - plugin = HalfPrecisionPlugin(precision=precision) + plugin = HalfPrecision(precision=precision) assert plugin.precision == precision assert plugin._desired_input_dtype == expected_dtype @@ -38,7 +38,7 @@ def test_selected_dtype(precision, expected_dtype): ], ) def test_module_init_context(precision, expected_dtype): - plugin = HalfPrecisionPlugin(precision=precision) + plugin = HalfPrecision(precision=precision) with plugin.module_init_context(): model = torch.nn.Linear(2, 2) assert torch.get_default_dtype() == expected_dtype @@ -53,7 +53,7 @@ def test_module_init_context(precision, expected_dtype): ], ) def test_forward_context(precision, expected_dtype): - precision = HalfPrecisionPlugin(precision=precision) + precision = HalfPrecision(precision=precision) assert torch.get_default_dtype() == torch.float32 with precision.forward_context(): assert torch.get_default_dtype() == expected_dtype @@ -68,7 +68,7 @@ def test_forward_context(precision, expected_dtype): ], ) def test_convert_module(precision, expected_dtype): - precision = HalfPrecisionPlugin(precision=precision) + precision = HalfPrecision(precision=precision) module = torch.nn.Linear(2, 2) assert module.weight.dtype == module.bias.dtype == torch.float32 module = precision.convert_module(module) diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py index 069ebd52f36e2..cf9e79a53ad5e 100644 --- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py +++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py @@ -16,7 +16,7 @@ import pytest import torch -from lightning.pytorch.plugins import TransformerEnginePrecisionPlugin +from lightning.pytorch.plugins import TransformerEnginePrecision from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector @@ -31,11 +31,11 @@ def test_transformer_engine_precision_plugin(monkeypatch): monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", Mock()) connector = _AcceleratorConnector(precision="transformer-engine") - assert isinstance(connector.precision_plugin, TransformerEnginePrecisionPlugin) + assert isinstance(connector.precision_plugin, TransformerEnginePrecision) assert connector.precision_plugin.dtype is torch.bfloat16 connector = _AcceleratorConnector(precision="transformer-engine-float16") assert connector.precision_plugin.dtype is torch.float16 - precision = TransformerEnginePrecisionPlugin() + precision = TransformerEnginePrecision() connector = _AcceleratorConnector(plugins=precision) assert connector.precision_plugin is precision diff --git a/tests/tests_pytorch/plugins/precision/test_xla.py b/tests/tests_pytorch/plugins/precision/test_xla.py index ccc4e521f74d4..97990b6380dab 100644 --- a/tests/tests_pytorch/plugins/precision/test_xla.py +++ b/tests/tests_pytorch/plugins/precision/test_xla.py @@ -18,7 +18,7 @@ import pytest import torch -from lightning.pytorch.plugins import XLAPrecisionPlugin +from lightning.pytorch.plugins import XLAPrecision from tests_pytorch.helpers.runif import RunIf @@ -26,7 +26,7 @@ @RunIf(tpu=True) @mock.patch.dict(os.environ, {}, clear=True) def test_optimizer_step_calls_mark_step(): - plugin = XLAPrecisionPlugin(precision="32-true") + plugin = XLAPrecision(precision="32-true") optimizer = Mock() with mock.patch("torch_xla.core.xla_model") as xm_mock: plugin.optimizer_step(optimizer=optimizer, model=Mock(), closure=Mock()) @@ -36,18 +36,18 @@ def test_optimizer_step_calls_mark_step(): @mock.patch.dict(os.environ, {}, clear=True) def test_precision_input_validation(xla_available): - XLAPrecisionPlugin(precision="32-true") - XLAPrecisionPlugin(precision="16-true") - XLAPrecisionPlugin(precision="bf16-true") + XLAPrecision(precision="32-true") + XLAPrecision(precision="16-true") + XLAPrecision(precision="bf16-true") with pytest.raises(ValueError, match=re.escape("`precision='16')` is not supported in XLA")): - XLAPrecisionPlugin("16") + XLAPrecision("16") with pytest.raises(ValueError, match=re.escape("`precision='16-mixed')` is not supported in XLA")): - XLAPrecisionPlugin("16-mixed") + XLAPrecision("16-mixed") with pytest.raises(ValueError, match=re.escape("`precision='bf16-mixed')` is not supported in XLA")): - XLAPrecisionPlugin("bf16-mixed") + XLAPrecision("bf16-mixed") with pytest.raises(ValueError, match=re.escape("`precision='64-true')` is not supported in XLA")): - XLAPrecisionPlugin("64-true") + XLAPrecision("64-true") @pytest.mark.parametrize( @@ -59,18 +59,18 @@ def test_precision_input_validation(xla_available): ) @mock.patch.dict(os.environ, {}, clear=True) def test_selected_dtype(precision, expected_dtype, xla_available): - plugin = XLAPrecisionPlugin(precision=precision) + plugin = XLAPrecision(precision=precision) assert plugin.precision == precision assert plugin._desired_dtype == expected_dtype def test_teardown(xla_available): - plugin = XLAPrecisionPlugin(precision="16-true") + plugin = XLAPrecision(precision="16-true") assert os.environ["XLA_USE_F16"] == "1" plugin.teardown() assert "XLA_USE_B16" not in os.environ - plugin = XLAPrecisionPlugin(precision="bf16-true") + plugin = XLAPrecision(precision="bf16-true") assert os.environ["XLA_USE_BF16"] == "1" plugin.teardown() assert "XLA_USE_BF16" not in os.environ diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index 2d0205b819fde..9fcba9515891d 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -20,13 +20,13 @@ import torch from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.plugins import MixedPrecisionPlugin +from lightning.pytorch.plugins import MixedPrecision from torch import Tensor from tests_pytorch.helpers.runif import RunIf -class MyAMP(MixedPrecisionPlugin): +class MyAMP(MixedPrecision): pass @@ -48,7 +48,7 @@ class MyAMP(MixedPrecisionPlugin): @pytest.mark.parametrize( ("custom_plugin", "plugin_cls"), [ - (False, MixedPrecisionPlugin), + (False, MixedPrecision), (True, MyAMP), ], ) @@ -190,7 +190,7 @@ def configure_optimizers(self): def test_cpu_amp_precision_context_manager(): """Test to ensure that the context manager correctly is set to CPU + bfloat16.""" - plugin = MixedPrecisionPlugin("bf16-mixed", "cpu") + plugin = MixedPrecision("bf16-mixed", "cpu") assert plugin.device == "cpu" assert plugin.scaler is None context_manager = plugin.autocast_context_manager() @@ -199,25 +199,23 @@ def test_cpu_amp_precision_context_manager(): def test_amp_precision_plugin_parameter_validation(): - MixedPrecisionPlugin("16-mixed", "cpu") # should not raise exception - MixedPrecisionPlugin("bf16-mixed", "cpu") + MixedPrecision("16-mixed", "cpu") # should not raise exception + MixedPrecision("bf16-mixed", "cpu") with pytest.raises( ValueError, - match=re.escape("Passed `MixedPrecisionPlugin(precision='16')`. Precision must be '16-mixed' or 'bf16-mixed'"), + match=re.escape("Passed `MixedPrecision(precision='16')`. Precision must be '16-mixed' or 'bf16-mixed'"), ): - MixedPrecisionPlugin("16", "cpu") + MixedPrecision("16", "cpu") with pytest.raises( ValueError, - match=re.escape("Passed `MixedPrecisionPlugin(precision=16)`. Precision must be '16-mixed' or 'bf16-mixed'"), + match=re.escape("Passed `MixedPrecision(precision=16)`. Precision must be '16-mixed' or 'bf16-mixed'"), ): - MixedPrecisionPlugin(16, "cpu") + MixedPrecision(16, "cpu") with pytest.raises( ValueError, - match=re.escape( - "Passed `MixedPrecisionPlugin(precision='bf16')`. Precision must be '16-mixed' or 'bf16-mixed'" - ), + match=re.escape("Passed `MixedPrecision(precision='bf16')`. Precision must be '16-mixed' or 'bf16-mixed'"), ): - MixedPrecisionPlugin("bf16", "cpu") + MixedPrecision("bf16", "cpu") diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py index fa6ba5f79963a..4ee89c1f7b3a7 100644 --- a/tests/tests_pytorch/strategies/test_common.py +++ b/tests/tests_pytorch/strategies/test_common.py @@ -17,7 +17,7 @@ import torch from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from lightning.pytorch import Trainer -from lightning.pytorch.plugins import DoublePrecisionPlugin, HalfPrecisionPlugin, PrecisionPlugin +from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import SingleDeviceStrategy from tests_pytorch.helpers.datamodules import ClassifDataModule @@ -66,10 +66,10 @@ def test_evaluate(tmpdir, trainer_kwargs): @pytest.mark.parametrize( ("precision", "dtype"), [ - (PrecisionPlugin(), torch.float32), - pytest.param(DoublePrecisionPlugin(), torch.float64, marks=RunIf(mps=False)), - (HalfPrecisionPlugin("16-true"), torch.float16), - pytest.param(HalfPrecisionPlugin("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)), + (Precision(), torch.float32), + pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)), + (HalfPrecision("16-true"), torch.float16), + pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) @pytest.mark.parametrize("empty_init", [None, True, False]) diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 172c32f0182bf..dadd49c359e06 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -21,7 +21,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.plugins import DoublePrecisionPlugin, HalfPrecisionPlugin, PrecisionPlugin +from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.trainer.states import TrainerFn from torch.nn.parallel import DistributedDataParallel @@ -92,10 +92,10 @@ def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs, mps_count_ @pytest.mark.parametrize( ("precision_plugin", "expected_dtype"), [ - (PrecisionPlugin(), torch.float32), - (DoublePrecisionPlugin(), torch.float64), - (HalfPrecisionPlugin("16-true"), torch.float16), - pytest.param(HalfPrecisionPlugin("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)), + (Precision(), torch.float32), + (DoublePrecision(), torch.float64), + (HalfPrecision("16-true"), torch.float16), + pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) @mock.patch.dict(os.environ, {"LOCAL_RANK": "1"}) diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 2004c1aea72a8..202c221ba946e 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -28,7 +28,7 @@ from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger -from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin +from lightning.pytorch.plugins import DeepSpeedPrecision from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 @@ -139,7 +139,7 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config): def test_deepspeed_precision_choice(cuda_count_1, tmpdir): """Test to ensure precision plugin is also correctly chosen. - DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin + DeepSpeed handles precision via Custom DeepSpeedPrecision """ trainer = Trainer( @@ -151,7 +151,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmpdir): ) assert isinstance(trainer.strategy, DeepSpeedStrategy) - assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) + assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecision) assert trainer.strategy.precision_plugin.precision == "16-mixed" @@ -1212,7 +1212,7 @@ def test_deepspeed_with_bfloat16_precision(tmpdir): ) trainer.fit(model) - assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) + assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecision) assert trainer.strategy.precision_plugin.precision == "bf16-mixed" assert trainer.strategy.config["zero_optimization"]["stage"] == 3 assert trainer.strategy.config["bf16"]["enabled"] diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index f8e6bb6633247..f094053e98fce 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -21,8 +21,8 @@ from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.plugins import HalfPrecisionPlugin -from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin +from lightning.pytorch.plugins import HalfPrecision +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision from lightning.pytorch.strategies import FSDPStrategy from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -81,7 +81,7 @@ def on_predict_batch_end(self, _, batch, batch_idx): def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.layer, FullyShardedDataParallel) - assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin) + assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision) if self.trainer.precision == "16-mixed": param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 @@ -144,7 +144,7 @@ def on_predict_batch_end(self, _, batch, batch_idx): def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.layer, torch.nn.Sequential) - assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin) + assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision) if self.trainer.precision == "16-mixed": param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32 @@ -412,11 +412,11 @@ def test_fsdp_activation_checkpointing_support(): def test_fsdp_forbidden_precision_raises(): with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"): - FSDPStrategy(precision_plugin=HalfPrecisionPlugin()) + FSDPStrategy(precision_plugin=HalfPrecision()) strategy = FSDPStrategy() with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"): - strategy.precision_plugin = HalfPrecisionPlugin() + strategy.precision_plugin = HalfPrecision() @RunIf(min_torch="1.13") diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index ca329c0995535..59492a4e05fe7 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -35,12 +35,12 @@ from lightning.pytorch.plugins.io import TorchCheckpointIO from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm from lightning.pytorch.plugins.precision import ( - DeepSpeedPrecisionPlugin, - DoublePrecisionPlugin, - FSDPPrecisionPlugin, - HalfPrecisionPlugin, - MixedPrecisionPlugin, - PrecisionPlugin, + DeepSpeedPrecision, + DoublePrecision, + FSDPPrecision, + HalfPrecision, + MixedPrecision, + Precision, ) from lightning.pytorch.strategies import ( DDPStrategy, @@ -97,7 +97,7 @@ def test_invalid_strategy_choice(invalid_strategy): def test_precision_and_precision_plugin_raises(): with pytest.raises(ValueError, match="both `precision=16-true` and `plugins"): - _AcceleratorConnector(precision="16-true", plugins=PrecisionPlugin()) + _AcceleratorConnector(precision="16-true", plugins=Precision()) @RunIf(skip_windows=True, standalone=True) @@ -202,7 +202,7 @@ def is_available() -> bool: def name() -> str: return "custom_acc_name" - class Prec(PrecisionPlugin): + class Prec(Precision): pass class Strat(SingleDeviceStrategy): @@ -765,8 +765,8 @@ def __init__(self, **kwargs): ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"), ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"), ( - [PrecisionPlugin(), DoublePrecisionPlugin(), LightningEnvironment(), SLURMEnvironment()], - "PrecisionPlugin, ClusterEnvironment", + [Precision(), DoublePrecision(), LightningEnvironment(), SLURMEnvironment()], + "Precision, ClusterEnvironment", ), ], ) @@ -1094,22 +1094,22 @@ def test_connector_num_nodes_input_validation(): @pytest.mark.parametrize( ("precision_str", "strategy_str", "expected_precision_cls"), [ - ("64-true", "auto", DoublePrecisionPlugin), - ("32-true", "auto", PrecisionPlugin), - ("16-true", "auto", HalfPrecisionPlugin), - ("bf16-true", "auto", HalfPrecisionPlugin), - ("16-mixed", "auto", MixedPrecisionPlugin), - ("bf16-mixed", "auto", MixedPrecisionPlugin), - pytest.param("32-true", "fsdp", FSDPPrecisionPlugin, marks=RunIf(min_cuda_gpus=1)), - pytest.param("16-true", "fsdp", FSDPPrecisionPlugin, marks=RunIf(min_cuda_gpus=1)), - pytest.param("bf16-true", "fsdp", FSDPPrecisionPlugin, marks=RunIf(min_cuda_gpus=1)), - pytest.param("16-mixed", "fsdp", FSDPPrecisionPlugin, marks=RunIf(min_cuda_gpus=1)), - pytest.param("bf16-mixed", "fsdp", FSDPPrecisionPlugin, marks=RunIf(min_cuda_gpus=1)), - pytest.param("32-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)), - pytest.param("16-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)), - pytest.param("bf16-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)), - pytest.param("16-mixed", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)), - pytest.param("bf16-mixed", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)), + ("64-true", "auto", DoublePrecision), + ("32-true", "auto", Precision), + ("16-true", "auto", HalfPrecision), + ("bf16-true", "auto", HalfPrecision), + ("16-mixed", "auto", MixedPrecision), + ("bf16-mixed", "auto", MixedPrecision), + pytest.param("32-true", "fsdp", FSDPPrecision, marks=RunIf(min_cuda_gpus=1)), + pytest.param("16-true", "fsdp", FSDPPrecision, marks=RunIf(min_cuda_gpus=1)), + pytest.param("bf16-true", "fsdp", FSDPPrecision, marks=RunIf(min_cuda_gpus=1)), + pytest.param("16-mixed", "fsdp", FSDPPrecision, marks=RunIf(min_cuda_gpus=1)), + pytest.param("bf16-mixed", "fsdp", FSDPPrecision, marks=RunIf(min_cuda_gpus=1)), + pytest.param("32-true", "deepspeed", DeepSpeedPrecision, marks=RunIf(deepspeed=True, mps=False)), + pytest.param("16-true", "deepspeed", DeepSpeedPrecision, marks=RunIf(deepspeed=True, mps=False)), + pytest.param("bf16-true", "deepspeed", DeepSpeedPrecision, marks=RunIf(deepspeed=True, mps=False)), + pytest.param("16-mixed", "deepspeed", DeepSpeedPrecision, marks=RunIf(deepspeed=True, mps=False)), + pytest.param("bf16-mixed", "deepspeed", DeepSpeedPrecision, marks=RunIf(deepspeed=True, mps=False)), ], ) def test_precision_selection(precision_str, strategy_str, expected_precision_cls): diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index c30840927e8e5..43a3fad916086 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -142,7 +142,7 @@ def test_import_deepspeed_lazily(): assert 'deepspeed' not in sys.modules from lightning.pytorch.strategies import DeepSpeedStrategy - from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin + from lightning.pytorch.plugins import DeepSpeedPrecision assert 'deepspeed' not in sys.modules import deepspeed