From 3b1643c3754b6e4aa008d2dbf1115a2934c013c4 Mon Sep 17 00:00:00 2001
From: Victor Prins <victor.prins@outlook.com>
Date: Fri, 15 Dec 2023 00:00:57 +0100
Subject: [PATCH] Add `@override` for files in `src/lightning/fabric/plugins/precision` (#19158)

---
 src/lightning/fabric/plugins/precision/amp.py        |  9 +++++++++
 .../fabric/plugins/precision/bitsandbytes.py         |  7 +++++++
 src/lightning/fabric/plugins/precision/deepspeed.py  |  9 ++++++++-
 src/lightning/fabric/plugins/precision/double.py     |  7 +++++++
 src/lightning/fabric/plugins/precision/fsdp.py       | 12 +++++++++++-
 src/lightning/fabric/plugins/precision/half.py       |  7 +++++++
 .../fabric/plugins/precision/transformer_engine.py   |  7 +++++++
 src/lightning/fabric/plugins/precision/xla.py        |  4 +++-
 8 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index 8f0bd8aa42241..b57455d6f4aac 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -18,6 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
+from typing_extensions import override
 
 from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
 from lightning.fabric.plugins.precision.precision import Precision
@@ -59,20 +60,25 @@ def __init__(
         self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16
 
+    @override
     def forward_context(self) -> ContextManager:
         return torch.autocast(self.device, dtype=self._desired_input_dtype)
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @override
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super().backward(tensor, model, *args, **kwargs)
 
+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -88,15 +94,18 @@ def optimizer_step(
         self.scaler.update()
         return step_output
 
+    @override
     def state_dict(self) -> Dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}
 
+    @override
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)
 
+    @override
     def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler = self.scaler
         if scaler is not None:
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index ab4db3a4d6b71..e3e3533e4a3b6 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -25,6 +25,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn.modules.module import _IncompatibleKeys
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -96,6 +97,7 @@ def __init__(
         self.dtype = dtype
         self.ignore_modules = ignore_modules or set()
 
+    @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid naive users thinking they quantized their model
         if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
@@ -116,9 +118,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
                 m.compute_type_is_set = False
         return module
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
@@ -136,12 +140,15 @@ def module_init_context(self) -> ContextManager:
         stack.enter_context(context_manager)
         return stack
 
+    @override
     def forward_context(self) -> ContextManager:
         return _DtypeContextManager(self.dtype)
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py
index 7f5e95bc15976..2fcaa38258e3a 100644
--- a/src/lightning/fabric/plugins/precision/deepspeed.py
+++ b/src/lightning/fabric/plugins/precision/deepspeed.py
@@ -18,7 +18,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from typing_extensions import get_args
+from typing_extensions import get_args, override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -61,29 +61,36 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         }
         self._desired_dtype = precision_to_type[self.precision]
 
+    @override
     def convert_module(self, module: Module) -> Module:
         if "true" in self.precision:
             return module.to(dtype=self._desired_dtype)
         return module
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         if "true" not in self.precision:
             return nullcontext()
         return _DtypeContextManager(self._desired_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @override
     def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
         model.backward(tensor, *args, **kwargs)
 
+    @override
     def optimizer_step(
         self,
         optimizer: Steppable,
diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py
index 3e38ccce67be1..0a857499f3d34 100644
--- a/src/lightning/fabric/plugins/precision/double.py
+++ b/src/lightning/fabric/plugins/precision/double.py
@@ -17,6 +17,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -27,20 +28,26 @@ class DoublePrecision(Precision):
 
     precision: Literal["64-true"] = "64-true"
 
+    @override
     def convert_module(self, module: Module) -> Module:
         return module.double()
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(torch.double)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def forward_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index ebdafcd651d93..9fa3eec33065f 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -18,7 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from typing_extensions import get_args
+from typing_extensions import get_args, override
 
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
@@ -103,28 +103,35 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
             buffer_dtype=buffer_dtype,
         )
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
 
+    @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @override
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = cast(Tensor, self.scaler.scale(tensor))
         super().backward(tensor, model, *args, **kwargs)
 
+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -138,6 +145,7 @@ def optimizer_step(
         self.scaler.update()
         return step_output
 
+    @override
     def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler = self.scaler
         if scaler is not None:
@@ -145,11 +153,13 @@ def unscale_gradients(self, optimizer: Optimizer) -> None:
                 raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
             scaler.unscale_(optimizer)
 
+    @override
     def state_dict(self) -> Dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}
 
+    @override
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)
diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py
index 77d02d0c000c2..32ca7da815213 100644
--- a/src/lightning/fabric/plugins/precision/half.py
+++ b/src/lightning/fabric/plugins/precision/half.py
@@ -17,6 +17,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -36,20 +37,26 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No
         self.precision = precision
         self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16
 
+    @override
     def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def forward_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index caeef6d32287c..1892d18cb3afb 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -19,6 +19,7 @@
 from lightning_utilities import apply_to_collection
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -89,6 +90,7 @@ def __init__(
         self.replace_layers = replace_layers
         self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype
 
+    @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid converting if any is found. assume the user took care of it
         if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
@@ -103,9 +105,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         module = module.to(dtype=self.weights_dtype)
         return module
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.weights_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         dtype_ctx = self.tensor_init_context()
         stack = ExitStack()
@@ -122,6 +126,7 @@ def module_init_context(self) -> ContextManager:
         stack.enter_context(dtype_ctx)
         return stack
 
+    @override
     def forward_context(self) -> ContextManager:
         dtype_ctx = _DtypeContextManager(self.weights_dtype)
         fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
@@ -135,9 +140,11 @@ def forward_context(self) -> ContextManager:
         stack.enter_context(autocast_ctx)
         return stack
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py
index 3aff7d5f6866b..fdb30032b3cdd 100644
--- a/src/lightning/fabric/plugins/precision/xla.py
+++ b/src/lightning/fabric/plugins/precision/xla.py
@@ -15,7 +15,7 @@
 from typing import Any, Literal
 
 import torch
-from typing_extensions import get_args
+from typing_extensions import get_args, override
 
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.precision.precision import Precision
@@ -56,6 +56,7 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         else:
             self._desired_dtype = torch.float32
 
+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -66,6 +67,7 @@ def optimizer_step(
         # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
         return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
 
+    @override
     def teardown(self) -> None:
         os.environ.pop("XLA_USE_BF16", None)
         os.environ.pop("XLA_USE_F16", None)
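
Note on the decorator applied throughout this patch: `typing_extensions.override` (PEP 698, also available as `typing.override` on Python 3.12+) is a static-analysis marker. At runtime it only sets `__override__ = True` on the decorated function; the benefit comes from type checkers such as mypy or pyright, which report an error when a decorated method does not actually override anything in a base class, catching renamed or mistyped hooks. A minimal, hypothetical sketch of that behavior (the class and method names below are illustrative, not taken from the Lightning code base):

    # Runs without error; the mismatch is reported only by a type checker.
    from typing_extensions import override


    class BasePlugin:
        def forward_context(self) -> str:
            return "base"


    class CustomPlugin(BasePlugin):
        @override
        def forward_context(self) -> str:  # OK: overrides BasePlugin.forward_context
            return "custom"

        @override
        def forward_contex(self) -> str:  # typo: a type checker flags "does not override"
            return "oops"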