fix autocast related deprecation warning
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Oct 21, 2024
1 parent 3ea7dd3 commit 3fe0c0f
Showing 8 changed files with 82 additions and 33 deletions.
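For context, the warning silenced here comes from PyTorch 2.4 deprecating the device-specific torch.cuda.amp.autocast entry point in favor of the device-agnostic torch.amp.autocast. A minimal sketch of the two spellings (illustrative only, not part of this commit; assumes a CUDA-capable PyTorch build):

import torch

x = torch.randn(8, 8, device="cuda")

# Deprecated spelling: emits a FutureWarning on PyTorch >= 2.4.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    y_old = x @ x

# Preferred spelling on PyTorch >= 2.4: pass the device type explicitly.
with torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
    y_new = x @ x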
13 changes: 7 additions & 6 deletions tests/pytorch/test_fused_optimizer.py
@@ -14,6 +14,7 @@
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.jit import gpu_autocast_ctx

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -333,7 +334,7 @@ def test_grad_scaler(self):
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

@@ -342,7 +343,7 @@ def test_grad_scaler(self):
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

@@ -384,7 +385,7 @@ def test_grad_scaler_capturable(self):
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

@@ -393,7 +394,7 @@ def test_grad_scaler_capturable(self):
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

@@ -442,7 +443,7 @@ def test_grad_scaler_capturable_master(self):
        gt_ = gt.clone()

        # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

@@ -451,7 +452,7 @@ def test_grad_scaler_capturable_master(self):
        scaler.update()

        # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

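The tests above exercise the standard AMP training-step pattern; a condensed sketch of that pattern using the new gpu_autocast_ctx helper (illustrative only — the model, optimizer, and torch.amp.GradScaler usage here are placeholders, not code from the repository):

import torch
from transformer_engine.pytorch.jit import gpu_autocast_ctx

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")

x = torch.randn(4, 16, device="cuda")
target = torch.randn(4, 16, device="cuda")

# Forward pass under autocast; gpu_autocast_ctx resolves to the
# non-deprecated torch.amp.autocast("cuda", ...) on PyTorch >= 2.4.
with gpu_autocast_ctx(enabled=True):
    y = model(x)
    loss = ((target - y) ** 2).mean()

# Scaled backward pass and optimizer step, as in the tests.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()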
61 changes: 44 additions & 17 deletions transformer_engine/pytorch/distributed.py
@@ -245,26 +245,53 @@ def in_fp8_activation_recompute_phase() -> bool:
    return _FP8_ACTIVATION_RECOMPUTE_PHASE


-def _get_active_autocast_contexts():
-    """
-    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
-    at the time of this function's execution.
-    """
-    autocast_cached = torch.is_autocast_cache_enabled()
-
-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
-
-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
-
-    return gpu_autocast_ctx, cpu_autocast_ctx
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+
+    def _get_active_autocast_contexts():
+        """
+        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+        at the time of this function's execution.
+        """
+        autocast_cached = torch.is_autocast_cache_enabled()
+
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )
+
+        return gpu_autocast_ctx, cpu_autocast_ctx
+
+else:
+
+    def _get_active_autocast_contexts():
+        """
+        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+        at the time of this function's execution.
+        """
+        autocast_cached = torch.is_autocast_cache_enabled()
+
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )
+
+        return gpu_autocast_ctx, cpu_autocast_ctx


class _CheckpointFunction(torch.autograd.Function):
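A sketch of what the refactored helper does (illustrative only; _get_active_autocast_contexts is a private function, and the recompute use below is a simplified stand-in for how the checkpointing code replays autocast state):

import torch
from transformer_engine.pytorch.distributed import _get_active_autocast_contexts

# Capture the autocast state that is active right now.
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    gpu_ctx, cpu_ctx = _get_active_autocast_contexts()

# Later, outside the original region, re-enter the same GPU/CPU autocast state.
with gpu_ctx, cpu_ctx:
    pass  # e.g. recompute a checkpointed forward pass here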
14 changes: 10 additions & 4 deletions transformer_engine/pytorch/jit.py
@@ -5,6 +5,7 @@
"""NVFuser functions and JIT utilities"""
import os
from typing import Callable, Optional, Tuple
+from functools import partial

import torch

@@ -33,6 +34,11 @@
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable

+if torch.__version__ >= "2.4":
+    gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
+

def set_jit_fusion_options() -> None:
    """Set PyTorch JIT layer fusion options."""
@@ -110,7 +116,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:

def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Disable native AMP for bias_gelu_fused_"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bias_gelu_fused_(inp, bias)
        return gelu_fused_(inp)
@@ -120,7 +126,7 @@ def bgrad_dgelu_fused(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bgrad_dgelu_fused_(grad_output, inp, bias)
        return None, dgelu_fused_(grad_output, inp)
@@ -161,7 +167,7 @@ def bias_dropout_add_fused_train(
) -> torch.Tensor:
    """Disable native AMP and enable grad for BDA"""
    with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
            return bias_dropout_add_fused_train_(x, bias, residual, prob)


@@ -177,7 +183,7 @@ def bias_dropout_add_fused_inference(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        return bias_dropout_add_fused_inference_(x, bias, residual, prob)


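gpu_autocast_ctx is used above purely as a drop-in replacement for torch.cuda.amp.autocast; a small sketch of the same wrapping pattern (illustrative only — run_fused_in_fp32 is a hypothetical helper, not part of the module):

from transformer_engine.pytorch.jit import gpu_autocast_ctx

def run_fused_in_fp32(fn, *args):
    # Disable GPU autocast so the fused kernel sees full-precision inputs,
    # mirroring the bias_gelu / bias_dropout_add wrappers above.
    with gpu_autocast_ctx(enabled=False):
        return fn(*args)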
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -38,6 +38,7 @@
)
from ..constants import dist_group_type
from ..float8_tensor import Float8Tensor
+from ..utils import torch_get_autocast_gpu_dtype

__all__ = ["initialize_ub", "destroy_ub"]

@@ -619,7 +620,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm.py
@@ -16,7 +16,7 @@
    layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype

__all__ = ["LayerNorm"]

@@ -193,7 +193,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Note: This will soon be deprecated with
        # https://github.com/NVIDIA/TransformerEngine/pull/1033
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
        elif self.activation_dtype != inp.dtype:
            dtype = inp.dtype
            for name, param in self.named_parameters():
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/rmsnorm.py
@@ -13,7 +13,7 @@

from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
-from ..utils import cast_if_needed
+from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype


__all__ = ["RMSNorm"]
@@ -190,7 +190,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Note: This will soon be deprecated with
        # https://github.com/NVIDIA/TransformerEngine/pull/1033
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
        elif self.activation_dtype != inp.dtype:
            dtype = inp.dtype
            for name, param in self.named_parameters():
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/transformer.py
@@ -26,6 +26,7 @@
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
+    torch_get_autocast_gpu_dtype,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
@@ -677,7 +678,7 @@ def forward(

        # For AMP
        if torch.is_autocast_enabled():
-            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
+            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

        # Self attention.
        self_attention_outputs = self.self_attention(
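The transformer.py change keeps the existing AMP cast pattern and only swaps in the helper; a sketch of that pattern (illustrative only; assumes a CUDA device and uses a plain tensor in place of real hidden states):

import torch
from transformer_engine.pytorch.utils import cast_if_needed, torch_get_autocast_gpu_dtype

hidden_states = torch.randn(2, 4, device="cuda", dtype=torch.float32)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    if torch.is_autocast_enabled():
        # Cast activations to the active autocast dtype, as in the forward pass above.
        hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
print(hidden_states.dtype)  # torch.bfloat16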
13 changes: 13 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -305,3 +305,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
            index2 = torch.cuda.current_device()
        return index1 == index2
    return device1 == device2
+
+
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+
+    def torch_get_autocast_gpu_dtype() -> torch.dtype:
+        return torch.get_autocast_dtype("cuda")
+
+else:
+
+    def torch_get_autocast_gpu_dtype() -> torch.dtype:
+        return torch.get_autocast_gpu_dtype()
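The new helper simply hides the version split between the old and new dtype queries; a usage sketch (illustrative only, not part of the diff):

import torch
from transformer_engine.pytorch.utils import torch_get_autocast_gpu_dtype

# Outside any autocast region this returns the default GPU autocast dtype (torch.float16).
print(torch_get_autocast_gpu_dtype())

with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    # Inside autocast it reflects the dtype requested for the region.
    print(torch_get_autocast_gpu_dtype())  # torch.bfloat16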
