diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index b659048b2b..c3cd3d2bcb 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -17,7 +17,7 @@
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
 
-from .utils import safely_set_viewless_tensor_data
+from .utils import safely_set_viewless_tensor_data, is_torch_min_version
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager
 from .float8_tensor import Float8Tensor
@@ -252,9 +252,7 @@ def _get_active_autocast_contexts():
     """
     autocast_cached = torch.is_autocast_cache_enabled()
 
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-    if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+    if is_torch_min_version("2.4.0a0"):
         gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
         gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
         gpu_autocast_ctx = torch.amp.autocast(
diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py
index 0a060e8305..8cbe2a1327 100644
--- a/transformer_engine/pytorch/jit.py
+++ b/transformer_engine/pytorch/jit.py
@@ -5,28 +5,29 @@
 """NVFuser functions and JIT utilities"""
 import os
 from typing import Callable, Optional, Tuple
-from functools import partial
 
 import torch
 
+from .utils import is_torch_min_version, gpu_autocast_ctx
+
 # pylint: disable=unnecessary-lambda-assignment
 
 jit_fuser = torch.jit.script
-if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if is_torch_min_version("2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = torch.compile
 
 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
-if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if is_torch_min_version("2.2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = torch.compile
 
 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
 no_torch_dynamo = lambda recursive=True: lambda func: func
-if torch.__version__ >= "2":
+if is_torch_min_version("2a0"):
     import torch._dynamo
 
-    if torch.__version__ >= "2.1":
+    if is_torch_min_version("2.1a0"):
         no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
             f, recursive=recursive
         )
@@ -34,20 +35,13 @@
         # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
         no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
 
-if torch.__version__ >= "2.4":
-    gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda")
-else:
-    gpu_autocast_ctx = torch.cuda.amp.autocast
-
 
 def set_jit_fusion_options() -> None:
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-    if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
+    if is_torch_min_version("2.2.0a0"):
         pass
-    elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+    elif is_torch_min_version("1.10.0a0"):
         # nvfuser
         torch._C._jit_set_profiling_executor(True)
         torch._C._jit_set_profiling_mode(True)
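Note on the `a0` suffixes used throughout this patch (illustration, not part of the diff): NGC and source builds of torch report PEP 440 pre-release version strings such as `2.4.0a0+41361538` (the local-label digits here are hypothetical), and under PEP 440 a pre-release sorts *before* its final release. Comparing against `"2.4.0a0"` therefore accepts such builds, while a naive `"2.4.0"` bound would reject them:

```python
# Minimal sketch of the PEP 440 ordering that motivates the "a0" bounds.
from packaging.version import Version

nightly = Version("2.4.0a0+41361538")  # hypothetical torch.__version__ string

print(nightly >= Version("2.4.0"))    # False: pre-release < final release
print(nightly >= Version("2.4.0a0"))  # True: the bound the patch uses
print(Version("2.4.0") >= Version("2.4.0a0"))  # True: releases still pass
```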
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index bf44c46658..210dd62931 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -27,6 +27,7 @@
     cast_if_needed,
     get_default_init_method,
     torch_get_autocast_gpu_dtype,
+    is_torch_min_version,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,
@@ -431,9 +432,7 @@ def __init__(
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
 
         # Set bias+dropout+add fusion grad_enable execution handler.
-        TORCH_MAJOR = int(torch.__version__.split(".")[0])
-        TORCH_MINOR = int(torch.__version__.split(".")[1])
-        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        use_nvfuser = is_torch_min_version("1.10.0a0") and not is_torch_min_version("2.2.0a0")
         self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
 
         if self.bias_dropout_fusion:
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 90a0edc2c0..fe2f087d48 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -6,11 +6,21 @@
 from __future__ import annotations
 import functools
 import math
+from packaging.version import Version as PkgVersion
 from typing import Any, Callable, Optional, Tuple
 
 import torch
 
 import transformer_engine.pytorch.cpp_extensions as ext
+_torch_version = PkgVersion(torch.__version__)
+
+
+def is_torch_min_version(version, check_equality=True):
+    """Check if minimum version of `torch` is installed."""
+    if check_equality:
+        return _torch_version >= PkgVersion(version)
+    return _torch_version > PkgVersion(version)
+
 
 def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """Check if any of the given tensors require gradient."""
@@ -309,8 +319,12 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
 
 def torch_get_autocast_gpu_dtype() -> torch.dtype:
     """Get PyTorch autocast GPU dtype."""
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-    if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
+    if is_torch_min_version("2.4.0a0"):
         return torch.get_autocast_dtype("cuda")
     return torch.get_autocast_gpu_dtype()
+
+
+if is_torch_min_version("2.4.0a0"):
+    gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
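For reference, a minimal usage sketch of the relocated `gpu_autocast_ctx` (assuming the patched `transformer_engine.pytorch.utils` is importable and a CUDA device is available): on torch >= 2.4 it resolves to `torch.amp.autocast` with `device_type="cuda"` pre-bound via `functools.partial`, and on older releases to the deprecated `torch.cuda.amp.autocast`, so call sites keep one spelling everywhere:

```python
import torch

from transformer_engine.pytorch.utils import gpu_autocast_ctx  # patched module

# Keyword arguments are forwarded to torch.amp.autocast(device_type="cuda", ...)
# on torch >= 2.4, and to torch.cuda.amp.autocast(...) on older releases.
with gpu_autocast_ctx(enabled=True, dtype=torch.bfloat16):
    a = torch.randn(8, 8, device="cuda")
    b = torch.randn(8, 8, device="cuda")
    out = a @ b  # matmul runs in bfloat16 under autocast

print(out.dtype)  # torch.bfloat16
```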