From 3b1643c3754b6e4aa008d2dbf1115a2934c013c4 Mon Sep 17 00:00:00 2001
From: Victor Prins <victor.prins@outlook.com>
Date: Fri, 15 Dec 2023 00:00:57 +0100
Subject: [PATCH] Add `@override` for files in
 `src/lightning/fabric/plugins/precision` (#19158)

---
 src/lightning/fabric/plugins/precision/amp.py        |  9 +++++++++
 .../fabric/plugins/precision/bitsandbytes.py         |  7 +++++++
 src/lightning/fabric/plugins/precision/deepspeed.py  |  9 ++++++++-
 src/lightning/fabric/plugins/precision/double.py     |  7 +++++++
 src/lightning/fabric/plugins/precision/fsdp.py       | 12 +++++++++++-
 src/lightning/fabric/plugins/precision/half.py       |  7 +++++++
 .../fabric/plugins/precision/transformer_engine.py   |  7 +++++++
 src/lightning/fabric/plugins/precision/xla.py        |  4 +++-
 8 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index 8f0bd8aa42241..b57455d6f4aac 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -18,6 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
+from typing_extensions import override
 
 from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
 from lightning.fabric.plugins.precision.precision import Precision
@@ -59,20 +60,25 @@ def __init__(
 
         self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16
 
+    @override
     def forward_context(self) -> ContextManager:
         return torch.autocast(self.device, dtype=self._desired_input_dtype)
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @override
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super().backward(tensor, model, *args, **kwargs)
 
+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -88,15 +94,18 @@ def optimizer_step(
         self.scaler.update()
         return step_output
 
+    @override
     def state_dict(self) -> Dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}
 
+    @override
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)
 
+    @override
     def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler = self.scaler
         if scaler is not None:
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index ab4db3a4d6b71..e3e3533e4a3b6 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -25,6 +25,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn.modules.module import _IncompatibleKeys
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -96,6 +97,7 @@ def __init__(
         self.dtype = dtype
         self.ignore_modules = ignore_modules or set()
 
+    @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid naive users thinking they quantized their model
         if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
@@ -116,9 +118,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
                 m.compute_type_is_set = False
         return module
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         if self.ignore_modules:
             # cannot patch the Linear class if the user wants to skip some submodules
@@ -136,12 +140,15 @@ def module_init_context(self) -> ContextManager:
         stack.enter_context(context_manager)
         return stack
 
+    @override
     def forward_context(self) -> ContextManager:
         return _DtypeContextManager(self.dtype)
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py
index 7f5e95bc15976..2fcaa38258e3a 100644
--- a/src/lightning/fabric/plugins/precision/deepspeed.py
+++ b/src/lightning/fabric/plugins/precision/deepspeed.py
@@ -18,7 +18,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from typing_extensions import get_args
+from typing_extensions import get_args, override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -61,29 +61,36 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         }
         self._desired_dtype = precision_to_type[self.precision]
 
+    @override
     def convert_module(self, module: Module) -> Module:
         if "true" in self.precision:
             return module.to(dtype=self._desired_dtype)
         return module
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         if "true" not in self.precision:
             return nullcontext()
         return _DtypeContextManager(self._desired_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @override
     def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
         model.backward(tensor, *args, **kwargs)
 
+    @override
     def optimizer_step(
         self,
         optimizer: Steppable,
diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py
index 3e38ccce67be1..0a857499f3d34 100644
--- a/src/lightning/fabric/plugins/precision/double.py
+++ b/src/lightning/fabric/plugins/precision/double.py
@@ -17,6 +17,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -27,20 +28,26 @@ class DoublePrecision(Precision):
 
     precision: Literal["64-true"] = "64-true"
 
+    @override
     def convert_module(self, module: Module) -> Module:
         return module.double()
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(torch.double)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def forward_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index ebdafcd651d93..9fa3eec33065f 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -18,7 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from typing_extensions import get_args
+from typing_extensions import get_args, override
 
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
@@ -103,28 +103,35 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
             buffer_dtype=buffer_dtype,
         )
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)
 
+    @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @override
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = cast(Tensor, self.scaler.scale(tensor))
         super().backward(tensor, model, *args, **kwargs)
 
+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -138,6 +145,7 @@ def optimizer_step(
         self.scaler.update()
         return step_output
 
+    @override
     def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler = self.scaler
         if scaler is not None:
@@ -145,11 +153,13 @@ def unscale_gradients(self, optimizer: Optimizer) -> None:
                 raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
             scaler.unscale_(optimizer)
 
+    @override
     def state_dict(self) -> Dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}
 
+    @override
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)
diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py
index 77d02d0c000c2..32ca7da815213 100644
--- a/src/lightning/fabric/plugins/precision/half.py
+++ b/src/lightning/fabric/plugins/precision/half.py
@@ -17,6 +17,7 @@
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -36,20 +37,26 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No
         self.precision = precision
         self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16
 
+    @override
     def convert_module(self, module: Module) -> Module:
         return module.to(dtype=self._desired_input_dtype)
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self._desired_input_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def forward_context(self) -> ContextManager:
         return self.tensor_init_context()
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index caeef6d32287c..1892d18cb3afb 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -19,6 +19,7 @@
 from lightning_utilities import apply_to_collection
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import override
 
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import (
@@ -89,6 +90,7 @@ def __init__(
         self.replace_layers = replace_layers
         self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype
 
+    @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid converting if any is found. assume the user took care of it
         if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
@@ -103,9 +105,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         module = module.to(dtype=self.weights_dtype)
         return module
 
+    @override
     def tensor_init_context(self) -> ContextManager:
         return _DtypeContextManager(self.weights_dtype)
 
+    @override
     def module_init_context(self) -> ContextManager:
         dtype_ctx = self.tensor_init_context()
         stack = ExitStack()
@@ -122,6 +126,7 @@ def module_init_context(self) -> ContextManager:
         stack.enter_context(dtype_ctx)
         return stack
 
+    @override
     def forward_context(self) -> ContextManager:
         dtype_ctx = _DtypeContextManager(self.weights_dtype)
         fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
@@ -135,9 +140,11 @@ def forward_context(self) -> ContextManager:
         stack.enter_context(autocast_ctx)
         return stack
 
+    @override
     def convert_input(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)
 
+    @override
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
diff --git a/src/lightning/fabric/plugins/precision/xla.py b/src/lightning/fabric/plugins/precision/xla.py
index 3aff7d5f6866b..fdb30032b3cdd 100644
--- a/src/lightning/fabric/plugins/precision/xla.py
+++ b/src/lightning/fabric/plugins/precision/xla.py
@@ -15,7 +15,7 @@
 from typing import Any, Literal
 
 import torch
-from typing_extensions import get_args
+from typing_extensions import get_args, override
 
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.precision.precision import Precision
@@ -56,6 +56,7 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         else:
             self._desired_dtype = torch.float32
 
+    @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
@@ -66,6 +67,7 @@ def optimizer_step(
         # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
         return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
 
+    @override
     def teardown(self) -> None:
         os.environ.pop("XLA_USE_BF16", None)
         os.environ.pop("XLA_USE_F16", None)