diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py
index 30367d2c5e..7a0de145da 100644
--- a/transformer_engine/pytorch/ops/op.py
+++ b/transformer_engine/pytorch/ops/op.py
@@ -527,13 +527,6 @@ def get_extra_state(self) -> torch.Tensor:
         # See: https://github.com/NVIDIA/TransformerEngine/pull/351
         # See: https://github.com/NVIDIA/TransformerEngine/pull/363
 
-        # Return immediately if op has no FP8 state
-        has_fp8_state = any(
-            self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
-        )
-        if not has_fp8_state:
-            return torch.Tensor()
-
         def to_cpu(src: torch.Tensor) -> torch.Tensor:
             """Helper function to make CPU copy of tensor
 
@@ -548,12 +541,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:
         # Store FP8 state
         state = {}
         for mode in ("input", "param", "grad_output"):
-
-            # Get state for a given FP8 tensor
-            if self.num_fp8_scales(mode) == 0:
-                state[mode] = None
-                continue
-            fp8_meta = self.get_fp8_meta(mode)
+            fp8_meta = self._fp8_metas
             if fp8_meta is None:
                 continue
             state[mode] = {}
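
For context, a minimal self-contained sketch of the serialization pattern the new code follows: instead of an early return guarded by num_fp8_scales, the loop simply skips a mode when no FP8 metadata is present, so an op without FP8 state naturally produces an empty state dict. The _ToyOp class, its constructor, and the shape of the metadata dict are hypothetical stand-ins for illustration; this is not the Transformer Engine API or the actual implementation in op.py.

import torch


class _ToyOp:
    """Hypothetical stand-in for an op that may carry FP8 metadata."""

    def __init__(self, fp8_metas=None):
        # Maps mode -> dict of tensors, or None if the op has no FP8 state
        self._fp8_metas = fp8_metas

    def get_extra_state(self) -> dict:
        """Collect CPU copies of FP8 state, mirroring the diff's loop structure."""

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            # Make a CPU copy so the checkpoint does not reference GPU memory
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src)
            return dst

        state = {}
        for mode in ("input", "param", "grad_output"):
            fp8_meta = self._fp8_metas
            if fp8_meta is None:
                # No FP8 state for this op; nothing to store for this mode
                continue
            state[mode] = {
                key: to_cpu(val)
                for key, val in fp8_meta.get(mode, {}).items()
                if isinstance(val, torch.Tensor)
            }
        return state


# Usage: an op without FP8 metadata yields an empty state dict, which is why
# the early-return branch removed in the diff is no longer needed.
print(_ToyOp().get_extra_state())  # {}

metas = {"input": {"amax_history": torch.zeros(16)}}
print(_ToyOp(metas).get_extra_state()["input"]["amax_history"].device)  # cpu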