From 78c869d1204e82721b8ca43b5a224d0facf2a2ee Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Wed, 8 Jan 2025 21:45:05 +0000
Subject: [PATCH] Fix fusible ops checkpoint

Signed-off-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/ops/op.py | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py
index 30367d2c5e..7a0de145da 100644
--- a/transformer_engine/pytorch/ops/op.py
+++ b/transformer_engine/pytorch/ops/op.py
@@ -527,13 +527,6 @@ def get_extra_state(self) -> torch.Tensor:
         # See: https://github.com/NVIDIA/TransformerEngine/pull/351
         # See: https://github.com/NVIDIA/TransformerEngine/pull/363
 
-        # Return immediately if op has no FP8 state
-        has_fp8_state = any(
-            self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
-        )
-        if not has_fp8_state:
-            return torch.Tensor()
-
         def to_cpu(src: torch.Tensor) -> torch.Tensor:
             """Helper function to make CPU copy of tensor
 
@@ -548,12 +541,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:
         # Store FP8 state
         state = {}
         for mode in ("input", "param", "grad_output"):
-
-            # Get state for a given FP8 tensor
-            if self.num_fp8_scales(mode) == 0:
-                state[mode] = None
-                continue
-            fp8_meta = self.get_fp8_meta(mode)
+            fp8_meta = self._fp8_metas
             if fp8_meta is None:
                 continue
             state[mode] = {}