From 3fe0c0f4007aafa0dfb7ea93b005cd89cfea571e Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Sun, 20 Oct 2024 19:51:43 -0700
Subject: [PATCH] fix autocast related deprecation warning

Signed-off-by: Xin Yao
---
 tests/pytorch/test_fused_optimizer.py        | 13 ++--
 transformer_engine/pytorch/distributed.py    | 61 +++++++++++++------
 transformer_engine/pytorch/jit.py            | 14 +++--
 transformer_engine/pytorch/module/base.py    |  3 +-
 .../pytorch/module/layernorm.py              |  4 +-
 transformer_engine/pytorch/module/rmsnorm.py |  4 +-
 transformer_engine/pytorch/transformer.py    |  3 +-
 transformer_engine/pytorch/utils.py          | 13 ++++
 8 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py
index d19fc5a521..f804754949 100644
--- a/tests/pytorch/test_fused_optimizer.py
+++ b/tests/pytorch/test_fused_optimizer.py
@@ -14,6 +14,7 @@
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.jit import gpu_autocast_ctx
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -333,7 +334,7 @@ def test_grad_scaler(self):
             gt_ = gt.clone()
 
             # Reference
-            with torch.cuda.amp.autocast(enabled=True):
+            with gpu_autocast_ctx(enabled=True):
                 y = self.model(x)
                 loss = ((gt - y) ** 2).mean()
 
@@ -342,7 +343,7 @@ def test_grad_scaler(self):
             scaler.update()
 
             # DUT
-            with torch.cuda.amp.autocast(enabled=True):
+            with gpu_autocast_ctx(enabled=True):
                 y = self.model_(x)
                 loss_ = ((gt_ - y) ** 2).mean()
 
@@ -384,7 +385,7 @@ def test_grad_scaler_capturable(self):
             gt_ = gt.clone()
 
             # Reference
-            with torch.cuda.amp.autocast(enabled=True):
+            with gpu_autocast_ctx(enabled=True):
                 y = self.model(x)
                 loss = ((gt - y) ** 2).mean()
 
@@ -393,7 +394,7 @@ def test_grad_scaler_capturable(self):
             scaler.update()
 
             # DUT
-            with torch.cuda.amp.autocast(enabled=True):
+            with gpu_autocast_ctx(enabled=True):
                 y = self.model_(x)
                 loss_ = ((gt_ - y) ** 2).mean()
 
@@ -442,7 +443,7 @@ def test_grad_scaler_capturable_master(self):
             gt_ = gt.clone()
 
             # Reference
-            with torch.cuda.amp.autocast(enabled=True):
+            with gpu_autocast_ctx(enabled=True):
                 y = self.model(x)
                 loss = ((gt - y) ** 2).mean()
 
@@ -451,7 +452,7 @@ def test_grad_scaler_capturable_master(self):
             scaler.update()
 
             # DUT
-            with torch.cuda.amp.autocast(enabled=True):
+            with gpu_autocast_ctx(enabled=True):
                 y = self.model_(x)
                 loss_ = ((gt_ - y) ** 2).mean()
 
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 490ac3b160..e3af4d60f8 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -245,26 +245,53 @@ def in_fp8_activation_recompute_phase() -> bool:
     return _FP8_ACTIVATION_RECOMPUTE_PHASE
 
 
-def _get_active_autocast_contexts():
-    """
-    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
-    at the time of this function's execution.
- """ - autocast_cached = torch.is_autocast_cache_enabled() +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) +if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: - gpu_autocast_enabled = torch.is_autocast_enabled() - gpu_autocast_dtype = torch.get_autocast_gpu_dtype() - gpu_autocast_ctx = torch.cuda.amp.autocast( - gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached - ) + def _get_active_autocast_contexts(): + """ + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state + at the time of this function's execution. + """ + autocast_cached = torch.is_autocast_cache_enabled() - cpu_autocast_enabled = torch.is_autocast_cpu_enabled() - cpu_autocast_dtype = torch.get_autocast_cpu_dtype() - cpu_autocast_ctx = torch.cpu.amp.autocast( - cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached - ) + gpu_autocast_enabled = torch.is_autocast_enabled("cuda") + gpu_autocast_dtype = torch.get_autocast_dtype("cuda") + gpu_autocast_ctx = torch.amp.autocast( + "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + ) + + cpu_autocast_enabled = torch.is_autocast_enabled("cpu") + cpu_autocast_dtype = torch.get_autocast_dtype("cpu") + cpu_autocast_ctx = torch.amp.autocast( + "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + ) + + return gpu_autocast_ctx, cpu_autocast_ctx + +else: + + def _get_active_autocast_contexts(): + """ + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state + at the time of this function's execution. + """ + autocast_cached = torch.is_autocast_cache_enabled() + + gpu_autocast_enabled = torch.is_autocast_enabled() + gpu_autocast_dtype = torch.get_autocast_gpu_dtype() + gpu_autocast_ctx = torch.cuda.amp.autocast( + gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + ) + + cpu_autocast_enabled = torch.is_autocast_cpu_enabled() + cpu_autocast_dtype = torch.get_autocast_cpu_dtype() + cpu_autocast_ctx = torch.cpu.amp.autocast( + cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + ) - return gpu_autocast_ctx, cpu_autocast_ctx + return gpu_autocast_ctx, cpu_autocast_ctx class _CheckpointFunction(torch.autograd.Function): diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index ed08627e95..0a060e8305 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -5,6 +5,7 @@ """NVFuser functions and JIT utilities""" import os from typing import Callable, Optional, Tuple +from functools import partial import torch @@ -33,6 +34,11 @@ # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True no_torch_dynamo = lambda recursive=True: torch._dynamo.disable +if torch.__version__ >= "2.4": + gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda") +else: + gpu_autocast_ctx = torch.cuda.amp.autocast + def set_jit_fusion_options() -> None: """Set PyTorch JIT layer fusion options.""" @@ -110,7 +116,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: """Disable native AMP for bias_gelu_fused_""" - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): if bias is not None and bias.numel() != 0: return bias_gelu_fused_(inp, bias) return gelu_fused_(inp) @@ -120,7 +126,7 @@ def bgrad_dgelu_fused( grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """Disable 
native AMP for `bgrad_dgelu_fused_`""" - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): if bias is not None and bias.numel() != 0: return bgrad_dgelu_fused_(grad_output, inp, bias) return None, dgelu_fused_(grad_output, inp) @@ -161,7 +167,7 @@ def bias_dropout_add_fused_train( ) -> torch.Tensor: """Disable native AMP and enable grad for BDA""" with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): return bias_dropout_add_fused_train_(x, bias, residual, prob) @@ -177,7 +183,7 @@ def bias_dropout_add_fused_inference( x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float ) -> torch.Tensor: """Disable native AMP for BDA""" - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): return bias_dropout_add_fused_inference_(x, bias, residual, prob) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 12ce5f0877..7bb81b8cd4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -38,6 +38,7 @@ ) from ..constants import dist_group_type from ..float8_tensor import Float8Tensor +from ..utils import torch_get_autocast_gpu_dtype __all__ = ["initialize_ub", "destroy_ub"] @@ -619,7 +620,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() + self.activation_dtype = torch_get_autocast_gpu_dtype() return # All checks after this have already been performed once, thus skip diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 0c439ac417..2ddc09adbd 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -16,7 +16,7 @@ layernorm_fwd_inf, ) from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype __all__ = ["LayerNorm"] @@ -193,7 +193,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: # Note: This will soon be deprecated with # https://github.com/NVIDIA/TransformerEngine/pull/1033 if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() + self.activation_dtype = torch_get_autocast_gpu_dtype() elif self.activation_dtype != inp.dtype: dtype = inp.dtype for name, param in self.named_parameters(): diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index fc6ec5746f..6041be47eb 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -13,7 +13,7 @@ from .. 
import cpp_extensions as tex from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype __all__ = ["RMSNorm"] @@ -190,7 +190,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: # Note: This will soon be deprecated with # https://github.com/NVIDIA/TransformerEngine/pull/1033 if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() + self.activation_dtype = torch_get_autocast_gpu_dtype() elif self.activation_dtype != inp.dtype: dtype = inp.dtype for name, param in self.named_parameters(): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ad5476450b..bf44c46658 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.utils import ( cast_if_needed, get_default_init_method, + torch_get_autocast_gpu_dtype, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -677,7 +678,7 @@ def forward( # For AMP if torch.is_autocast_enabled(): - hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) + hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype()) # Self attention. self_attention_outputs = self.self_attention( diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 947c642c2c..ac7819c9b0 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -305,3 +305,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: index2 = torch.cuda.current_device() return index1 == index2 return device1 == device2 + + +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) +if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: + + def torch_get_autocast_gpu_dtype() -> torch.dtype: + return torch.get_autocast_dtype("cuda") + +else: + + def torch_get_autocast_gpu_dtype() -> torch.dtype: + return torch.get_autocast_gpu_dtype()
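
A minimal usage sketch of the two helpers this patch introduces (gpu_autocast_ctx in jit.py and torch_get_autocast_gpu_dtype in utils.py), assuming PyTorch >= 2.4, a CUDA device, and a Transformer Engine build with this patch applied; the snippet is illustrative only and not part of the patch:

import torch
from transformer_engine.pytorch.jit import gpu_autocast_ctx
from transformer_engine.pytorch.utils import torch_get_autocast_gpu_dtype

x = torch.randn(8, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")

# On PyTorch >= 2.4 this expands to torch.amp.autocast(device_type="cuda", enabled=True),
# avoiding the FutureWarning that torch.cuda.amp.autocast(enabled=True) now emits;
# on older versions it falls back to torch.cuda.amp.autocast.
with gpu_autocast_ctx(enabled=True):
    y = x @ w  # matmul runs under CUDA autocast

# On PyTorch >= 2.4 this calls torch.get_autocast_dtype("cuda") instead of the
# deprecated torch.get_autocast_gpu_dtype(); both report the CUDA autocast dtype.
amp_dtype = torch_get_autocast_gpu_dtype()
print(y.dtype, amp_dtype)  # typically torch.float16 for both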