
Commit 3b1643c

Add @override for files in src/lightning/fabric/plugins/precision
VictorPrins authored Dec 14, 2023
1 parent c985400 commit 3b1643c
Showing 8 changed files with 59 additions and 3 deletions.
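
For context on the change (not part of the diff): `typing_extensions.override` implements PEP 698. Decorating a method with `@override` declares that it is intended to override a base-class method, so a static type checker (e.g. mypy or pyright) can report an error if the base method is renamed, removed, or the override is misspelled. A minimal sketch, with hypothetical class names that are illustrative and not from this repository:

from typing_extensions import override


class BasePrecision:
    def optimizer_step(self, optimizer):
        # Base-class hook that subclasses are expected to override.
        return optimizer.step()


class CustomPrecision(BasePrecision):
    @override
    def optimizer_step(self, optimizer):
        # OK: a matching method exists on BasePrecision.
        return super().optimizer_step(optimizer)

    # @override
    # def optimzer_step(self, optimizer):
    #     # Misspelled name: with @override applied, a type checker flags this
    #     # because no such method exists on the base class.
    #     ...

The runtime effect of the decorator is minimal (it sets `__override__ = True` on the function where possible); the value is in static checking.
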
9 changes: 9 additions & 0 deletions src/lightning/fabric/plugins/precision/amp.py
@@ -18,6 +18,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import override

from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.precision import Precision
@@ -59,20 +60,25 @@ def __init__(

self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16

@override
def forward_context(self) -> ContextManager:
return torch.autocast(self.device, dtype=self._desired_input_dtype)

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

@override
def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
super().backward(tensor, model, *args, **kwargs)

@override
def optimizer_step(
self,
optimizer: Optimizable,
@@ -88,15 +94,18 @@ def optimizer_step(
self.scaler.update()
return step_output

@override
def state_dict(self) -> Dict[str, Any]:
if self.scaler is not None:
return self.scaler.state_dict()
return {}

@override
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)

@override
def unscale_gradients(self, optimizer: Optimizer) -> None:
scaler = self.scaler
if scaler is not None:
7 changes: 7 additions & 0 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -25,6 +25,7 @@
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn.modules.module import _IncompatibleKeys
from typing_extensions import override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import (
@@ -96,6 +97,7 @@ def __init__(
self.dtype = dtype
self.ignore_modules = ignore_modules or set()

@override
def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
# avoid naive users thinking they quantized their model
if not any(isinstance(m, torch.nn.Linear) for m in module.modules()):
@@ -116,9 +118,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
m.compute_type_is_set = False
return module

@override
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self.dtype)

@override
def module_init_context(self) -> ContextManager:
if self.ignore_modules:
# cannot patch the Linear class if the user wants to skip some submodules
@@ -136,12 +140,15 @@ def module_init_context(self) -> ContextManager:
stack.enter_context(context_manager)
return stack

@override
def forward_context(self) -> ContextManager:
return _DtypeContextManager(self.dtype)

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

9 changes: 8 additions & 1 deletion src/lightning/fabric/plugins/precision/deepspeed.py
@@ -18,7 +18,7 @@
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import get_args
from typing_extensions import get_args, override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -61,29 +61,36 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
}
self._desired_dtype = precision_to_type[self.precision]

@override
def convert_module(self, module: Module) -> Module:
if "true" in self.precision:
return module.to(dtype=self._desired_dtype)
return module

@override
def tensor_init_context(self) -> ContextManager:
if "true" not in self.precision:
return nullcontext()
return _DtypeContextManager(self._desired_dtype)

@override
def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

@override
def backward(self, tensor: Tensor, model: "DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
"""Performs back-propagation using DeepSpeed's engine."""
model.backward(tensor, *args, **kwargs)

@override
def optimizer_step(
self,
optimizer: Steppable,
7 changes: 7 additions & 0 deletions src/lightning/fabric/plugins/precision/double.py
@@ -17,6 +17,7 @@
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -27,20 +28,26 @@ class DoublePrecision(Precision):

precision: Literal["64-true"] = "64-true"

@override
def convert_module(self, module: Module) -> Module:
return module.double()

@override
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(torch.double)

@override
def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

@override
def forward_context(self) -> ContextManager:
return self.tensor_init_context()

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
12 changes: 11 additions & 1 deletion src/lightning/fabric/plugins/precision/fsdp.py
@@ -18,7 +18,7 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import get_args
from typing_extensions import get_args, override

from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.precision import Precision
@@ -103,28 +103,35 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
buffer_dtype=buffer_dtype,
)

@override
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)

@override
def module_init_context(self) -> ContextManager:
return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

@override
def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return self.tensor_init_context()

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

@override
def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
if self.scaler is not None:
tensor = cast(Tensor, self.scaler.scale(tensor))
super().backward(tensor, model, *args, **kwargs)

@override
def optimizer_step(
self,
optimizer: Optimizable,
@@ -138,18 +145,21 @@ def optimizer_step(
self.scaler.update()
return step_output

@override
def unscale_gradients(self, optimizer: Optimizer) -> None:
scaler = self.scaler
if scaler is not None:
if _optimizer_handles_unscaling(optimizer):
raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
scaler.unscale_(optimizer)

@override
def state_dict(self) -> Dict[str, Any]:
if self.scaler is not None:
return self.scaler.state_dict()
return {}

@override
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)
7 changes: 7 additions & 0 deletions src/lightning/fabric/plugins/precision/half.py
@@ -17,6 +17,7 @@
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
@@ -36,20 +37,26 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
self.precision = precision
self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16

@override
def convert_module(self, module: Module) -> Module:
return module.to(dtype=self._desired_input_dtype)

@override
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)

@override
def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

@override
def forward_context(self) -> ContextManager:
return self.tensor_init_context()

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
7 changes: 7 additions & 0 deletions src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -19,6 +19,7 @@
from lightning_utilities import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import (
@@ -89,6 +90,7 @@ def __init__(
self.replace_layers = replace_layers
self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

@override
def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
# avoid converting if any is found. assume the user took care of it
if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
@@ -103,9 +105,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
module = module.to(dtype=self.weights_dtype)
return module

@override
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self.weights_dtype)

@override
def module_init_context(self) -> ContextManager:
dtype_ctx = self.tensor_init_context()
stack = ExitStack()
@@ -122,6 +126,7 @@ def module_init_context(self) -> ContextManager:
stack.enter_context(dtype_ctx)
return stack

@override
def forward_context(self) -> ContextManager:
dtype_ctx = _DtypeContextManager(self.weights_dtype)
fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
@@ -135,9 +140,11 @@ def forward_context(self) -> ContextManager:
stack.enter_context(autocast_ctx)
return stack

@override
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

@override
def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

4 changes: 3 additions & 1 deletion src/lightning/fabric/plugins/precision/xla.py
@@ -15,7 +15,7 @@
from typing import Any, Literal

import torch
from typing_extensions import get_args
from typing_extensions import get_args, override

from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.precision.precision import Precision
@@ -56,6 +56,7 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
else:
self._desired_dtype = torch.float32

@override
def optimizer_step(
self,
optimizer: Optimizable,
@@ -66,6 +67,7 @@ def optimizer_step(
# you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)

@override
def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
os.environ.pop("XLA_USE_F16", None)
