[PyTorch] Fix autocast deprecation warnings #1277

Open · wants to merge 9 commits into base: main
13 changes: 7 additions & 6 deletions tests/pytorch/test_fused_optimizer.py
@@ -15,6 +15,7 @@
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import gpu_autocast_ctx

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -550,7 +551,7 @@ def test_grad_scaler(self):
gt_ = gt.clone()

# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()

@@ -559,7 +560,7 @@ def test_grad_scaler(self):
scaler.update()

# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

@@ -601,7 +602,7 @@ def test_grad_scaler_capturable(self):
gt_ = gt.clone()

# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()

@@ -610,7 +611,7 @@ def test_grad_scaler_capturable(self):
scaler.update()

# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

@@ -659,7 +660,7 @@ def test_grad_scaler_capturable_master(self):
gt_ = gt.clone()

# Reference
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()

@@ -668,7 +669,7 @@ def test_grad_scaler_capturable_master(self):
scaler.update()

# DUT
with torch.cuda.amp.autocast(enabled=True):
with gpu_autocast_ctx(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

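For context, a minimal sketch (not part of the diff) of the deprecation these test changes address: on PyTorch 2.4+, torch.cuda.amp.autocast(...) still works but emits a FutureWarning pointing at torch.amp.autocast("cuda", ...), which is what the new gpu_autocast_ctx helper wraps.

import warnings
import torch

# On torch >= 2.4 this emits a FutureWarning along the lines of
# "torch.cuda.amp.autocast(args...) is deprecated. Please use
# torch.amp.autocast('cuda', args...) instead."
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.cuda.amp.autocast(enabled=True):
        pass
print([str(w.message) for w in caught if issubclass(w.category, FutureWarning)])

# The warning-free equivalent used throughout this PR (via gpu_autocast_ctx):
with torch.amp.autocast("cuda", enabled=True):
    pass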
41 changes: 30 additions & 11 deletions transformer_engine/pytorch/distributed.py
@@ -17,7 +17,7 @@
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

from .utils import safely_set_viewless_tensor_data
from .utils import safely_set_viewless_tensor_data, is_torch_min_version
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
from .float8_tensor import Float8Tensor
@@ -252,17 +252,36 @@ def _get_active_autocast_contexts():
"""
autocast_cached = torch.is_autocast_cache_enabled()

gpu_autocast_enabled = torch.is_autocast_enabled()
gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
gpu_autocast_ctx = torch.cuda.amp.autocast(
gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
)
if is_torch_min_version("2.4.0a0"):
gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
gpu_autocast_ctx = torch.amp.autocast(
"cuda",
enabled=gpu_autocast_enabled,
dtype=gpu_autocast_dtype,
cache_enabled=autocast_cached,
)

cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
cpu_autocast_ctx = torch.cpu.amp.autocast(
cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
)
cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
cpu_autocast_ctx = torch.amp.autocast(
"cpu",
enabled=cpu_autocast_enabled,
dtype=cpu_autocast_dtype,
cache_enabled=autocast_cached,
)
else:
gpu_autocast_enabled = torch.is_autocast_enabled()
gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
gpu_autocast_ctx = torch.cuda.amp.autocast(
gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
)

cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
cpu_autocast_ctx = torch.cpu.amp.autocast(
cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
)

return gpu_autocast_ctx, cpu_autocast_ctx

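A minimal sketch (not part of the diff, assuming a CUDA device) of what _get_active_autocast_contexts is for: it snapshots the currently active GPU and CPU autocast settings as fresh context managers that can be re-entered later, for example when recomputing activations, so the replayed ops see the same dtypes as the original forward pass.

import torch
from transformer_engine.pytorch.distributed import _get_active_autocast_contexts

# Capture the autocast state that is active right now...
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    gpu_ctx, cpu_ctx = _get_active_autocast_contexts()

# ...and replay it later, outside the original region.
with gpu_ctx, cpu_ctx:
    a = torch.ones(8, 8, device="cuda")
    print((a @ a).dtype)  # torch.bfloat16 -- the matmul ran under the captured autocast state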
24 changes: 12 additions & 12 deletions transformer_engine/pytorch/jit.py
@@ -8,24 +8,26 @@

import torch

from .utils import is_torch_min_version, gpu_autocast_ctx

# pylint: disable=unnecessary-lambda-assignment

jit_fuser = torch.jit.script
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if is_torch_min_version("2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile

# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if is_torch_min_version("2.2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile

# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
if is_torch_min_version("2a0"):
import torch._dynamo

if torch.__version__ >= "2.1":
if is_torch_min_version("2.1a0"):
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
@@ -37,11 +39,9 @@
def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
if is_torch_min_version("2.2.0a0"):
pass
elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
elif is_torch_min_version("1.10.0a0"):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
@@ -110,7 +110,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:

def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
with gpu_autocast_ctx(enabled=False):
if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)
@@ -120,7 +120,7 @@ def bgrad_dgelu_fused(
grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False):
with gpu_autocast_ctx(enabled=False):
if bias is not None and bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp)
@@ -161,7 +161,7 @@ def bias_dropout_add_fused_train(
) -> torch.Tensor:
"""Disable native AMP and enable grad for BDA"""
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
with gpu_autocast_ctx(enabled=False):
return bias_dropout_add_fused_train_(x, bias, residual, prob)


@@ -177,7 +177,7 @@ def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
"""Disable native AMP for BDA"""
with torch.cuda.amp.autocast(enabled=False):
with gpu_autocast_ctx(enabled=False):
return bias_dropout_add_fused_inference_(x, bias, residual, prob)


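As an aside (not part of the diff), a small sketch of why the lexicographic torch.__version__ >= "2.2"-style checks removed above are fragile, and what the "a0" suffix in the new bounds buys for pre-release builds:

from packaging.version import Version

# String comparison orders character by character and misorders versions:
print("2.10.0" >= "2.2")                    # False -- '1' sorts before '2'
print(Version("2.10.0") >= Version("2.2"))  # True

# Pre-release builds (e.g. NGC-style "2.2.0a0+git...") sort below the final
# release, so a plain ">= 2.2.0" bound would reject them; ">= 2.2a0" does not.
print(Version("2.2.0a0+git1234abc") >= Version("2.2.0"))  # False
print(Version("2.2.0a0+git1234abc") >= Version("2.2a0"))  # True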
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -38,6 +38,7 @@
)
from ..constants import dist_group_type
from ..float8_tensor import Float8Tensor
from ..utils import torch_get_autocast_gpu_dtype

__all__ = ["initialize_ub", "destroy_ub"]

@@ -696,7 +697,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
self.activation_dtype = torch_get_autocast_gpu_dtype()
return

# All checks after this have already been performed once, thus skip
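A minimal usage sketch (not part of the diff; assumes a CUDA device with Transformer Engine installed, and uses te.Linear plus the internal activation_dtype attribute purely for illustration) of the behaviour this hunk preserves: under a native autocast region, the module's activation dtype follows the autocast GPU dtype, now read through the version-agnostic helper.

import torch
import transformer_engine.pytorch as te

layer = te.Linear(16, 16)
x = torch.randn(4, 16, device="cuda")
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    y = layer(x)
print(layer.activation_dtype)  # torch.bfloat16, obtained via torch_get_autocast_gpu_dtype()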
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/transformer.py
@@ -26,6 +26,8 @@
from transformer_engine.pytorch.utils import (
cast_if_needed,
get_default_init_method,
torch_get_autocast_gpu_dtype,
is_torch_min_version,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
@@ -430,9 +432,7 @@ def __init__(
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
use_nvfuser = is_torch_min_version("1.10.0a0") and not is_torch_min_version("2.2.0a0")
Collaborator Author comment:
We stop using nvfuser starting from torch 2.2, so I think this update is needed. Tell me if you have other concerns.

self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

if self.bias_dropout_fusion:
Expand Down Expand Up @@ -677,7 +677,7 @@ def forward(

# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

# Self attention.
self_attention_outputs = self.self_attention(
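A small sketch (not from the diff) of how the reworked use_nvfuser window evaluates across torch versions, matching the comment above that nvfuser is only wanted for 1.10 <= torch < 2.2:

from packaging.version import Version as PkgVersion

def in_nvfuser_window(v: str) -> bool:
    # Mirrors: is_torch_min_version("1.10.0a0") and not is_torch_min_version("2.2.0a0")
    return PkgVersion(v) >= PkgVersion("1.10.0a0") and not PkgVersion(v) >= PkgVersion("2.2.0a0")

for v in ("1.9.0", "1.13.1", "2.1.2", "2.2.0", "2.4.1"):
    print(v, in_nvfuser_window(v))  # True only for 1.13.1 and 2.1.2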
23 changes: 23 additions & 0 deletions transformer_engine/pytorch/utils.py
@@ -7,10 +7,20 @@
import functools
import math
from typing import Any, Callable, Optional, Tuple
from packaging.version import Version as PkgVersion

import torch
import transformer_engine.pytorch.cpp_extensions as ext

_torch_version = PkgVersion(torch.__version__)


def is_torch_min_version(version, check_equality=True):
"""Check if minimum version of `torch` is installed."""
if check_equality:
return _torch_version >= PkgVersion(version)
return _torch_version > PkgVersion(version)


def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""Check if any of the given tensors require gradient."""
@@ -305,3 +315,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
index2 = torch.cuda.current_device()
return index1 == index2
return device1 == device2


def torch_get_autocast_gpu_dtype() -> torch.dtype:
"""Get PyTorch autocast GPU dtype."""
if is_torch_min_version("2.4.0a0"):
return torch.get_autocast_dtype("cuda")
return torch.get_autocast_gpu_dtype()


if is_torch_min_version("2.4.0a0"):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else:
gpu_autocast_ctx = torch.cuda.amp.autocast
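A minimal usage sketch (not part of the diff, assuming a CUDA-enabled build) of the three helpers added here. gpu_autocast_ctx takes the same keyword arguments on every supported torch version, dispatching to torch.amp.autocast(device_type="cuda", ...) on 2.4+ and to torch.cuda.amp.autocast otherwise, so callers never trigger the deprecation warning.

import torch
from transformer_engine.pytorch.utils import (
    gpu_autocast_ctx,
    is_torch_min_version,
    torch_get_autocast_gpu_dtype,
)

print(is_torch_min_version("2.4.0a0"))  # True on builds that use the new torch.amp API

with gpu_autocast_ctx(enabled=True, dtype=torch.bfloat16):
    # Reports the active autocast GPU dtype regardless of torch version.
    print(torch_get_autocast_gpu_dtype())  # torch.bfloat16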