From d1d36448583ba98e8193c6bfee22a7795fe5e796 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 20 Oct 2024 19:51:43 -0700 Subject: [PATCH 1/7] fix autocast related deprecation warning Signed-off-by: Xin Yao --- tests/pytorch/test_fused_optimizer.py | 13 ++-- transformer_engine/pytorch/distributed.py | 61 +++++++++++++------ transformer_engine/pytorch/jit.py | 14 +++-- transformer_engine/pytorch/module/base.py | 3 +- .../pytorch/module/layernorm.py | 4 +- transformer_engine/pytorch/module/rmsnorm.py | 4 +- transformer_engine/pytorch/transformer.py | 3 +- transformer_engine/pytorch/utils.py | 13 ++++ 8 files changed, 82 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 4d4eb38342..dccf81829e 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -15,6 +15,7 @@ from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.jit import gpu_autocast_ctx # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -550,7 +551,7 @@ def test_grad_scaler(self): gt_ = gt.clone() # Reference - with torch.cuda.amp.autocast(enabled=True): + with gpu_autocast_ctx(enabled=True): y = self.model(x) loss = ((gt - y) ** 2).mean() @@ -559,7 +560,7 @@ def test_grad_scaler(self): scaler.update() # DUT - with torch.cuda.amp.autocast(enabled=True): + with gpu_autocast_ctx(enabled=True): y = self.model_(x) loss_ = ((gt_ - y) ** 2).mean() @@ -601,7 +602,7 @@ def test_grad_scaler_capturable(self): gt_ = gt.clone() # Reference - with torch.cuda.amp.autocast(enabled=True): + with gpu_autocast_ctx(enabled=True): y = self.model(x) loss = ((gt - y) ** 2).mean() @@ -610,7 +611,7 @@ def test_grad_scaler_capturable(self): scaler.update() # DUT - with torch.cuda.amp.autocast(enabled=True): + with gpu_autocast_ctx(enabled=True): y = self.model_(x) loss_ = ((gt_ - y) ** 2).mean() @@ -659,7 +660,7 @@ def test_grad_scaler_capturable_master(self): gt_ = gt.clone() # Reference - with torch.cuda.amp.autocast(enabled=True): + with gpu_autocast_ctx(enabled=True): y = self.model(x) loss = ((gt - y) ** 2).mean() @@ -668,7 +669,7 @@ def test_grad_scaler_capturable_master(self): scaler.update() # DUT - with torch.cuda.amp.autocast(enabled=True): + with gpu_autocast_ctx(enabled=True): y = self.model_(x) loss_ = ((gt_ - y) ** 2).mean() diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 490ac3b160..e3af4d60f8 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -245,26 +245,53 @@ def in_fp8_activation_recompute_phase() -> bool: return _FP8_ACTIVATION_RECOMPUTE_PHASE -def _get_active_autocast_contexts(): - """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state - at the time of this function's execution. 
- """ - autocast_cached = torch.is_autocast_cache_enabled() +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) +if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: - gpu_autocast_enabled = torch.is_autocast_enabled() - gpu_autocast_dtype = torch.get_autocast_gpu_dtype() - gpu_autocast_ctx = torch.cuda.amp.autocast( - gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached - ) + def _get_active_autocast_contexts(): + """ + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state + at the time of this function's execution. + """ + autocast_cached = torch.is_autocast_cache_enabled() - cpu_autocast_enabled = torch.is_autocast_cpu_enabled() - cpu_autocast_dtype = torch.get_autocast_cpu_dtype() - cpu_autocast_ctx = torch.cpu.amp.autocast( - cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached - ) + gpu_autocast_enabled = torch.is_autocast_enabled("cuda") + gpu_autocast_dtype = torch.get_autocast_dtype("cuda") + gpu_autocast_ctx = torch.amp.autocast( + "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + ) + + cpu_autocast_enabled = torch.is_autocast_enabled("cpu") + cpu_autocast_dtype = torch.get_autocast_dtype("cpu") + cpu_autocast_ctx = torch.amp.autocast( + "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + ) + + return gpu_autocast_ctx, cpu_autocast_ctx + +else: + + def _get_active_autocast_contexts(): + """ + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state + at the time of this function's execution. + """ + autocast_cached = torch.is_autocast_cache_enabled() + + gpu_autocast_enabled = torch.is_autocast_enabled() + gpu_autocast_dtype = torch.get_autocast_gpu_dtype() + gpu_autocast_ctx = torch.cuda.amp.autocast( + gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + ) + + cpu_autocast_enabled = torch.is_autocast_cpu_enabled() + cpu_autocast_dtype = torch.get_autocast_cpu_dtype() + cpu_autocast_ctx = torch.cpu.amp.autocast( + cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + ) - return gpu_autocast_ctx, cpu_autocast_ctx + return gpu_autocast_ctx, cpu_autocast_ctx class _CheckpointFunction(torch.autograd.Function): diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index ed08627e95..0a060e8305 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -5,6 +5,7 @@ """NVFuser functions and JIT utilities""" import os from typing import Callable, Optional, Tuple +from functools import partial import torch @@ -33,6 +34,11 @@ # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True no_torch_dynamo = lambda recursive=True: torch._dynamo.disable +if torch.__version__ >= "2.4": + gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda") +else: + gpu_autocast_ctx = torch.cuda.amp.autocast + def set_jit_fusion_options() -> None: """Set PyTorch JIT layer fusion options.""" @@ -110,7 +116,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: """Disable native AMP for bias_gelu_fused_""" - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): if bias is not None and bias.numel() != 0: return bias_gelu_fused_(inp, bias) return gelu_fused_(inp) @@ -120,7 +126,7 @@ def bgrad_dgelu_fused( grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """Disable 
native AMP for `bgrad_dgelu_fused_`""" - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): if bias is not None and bias.numel() != 0: return bgrad_dgelu_fused_(grad_output, inp, bias) return None, dgelu_fused_(grad_output, inp) @@ -161,7 +167,7 @@ def bias_dropout_add_fused_train( ) -> torch.Tensor: """Disable native AMP and enable grad for BDA""" with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): return bias_dropout_add_fused_train_(x, bias, residual, prob) @@ -177,7 +183,7 @@ def bias_dropout_add_fused_inference( x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float ) -> torch.Tensor: """Disable native AMP for BDA""" - with torch.cuda.amp.autocast(enabled=False): + with gpu_autocast_ctx(enabled=False): return bias_dropout_add_fused_inference_(x, bias, residual, prob) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 534174380f..5ca34f7597 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -38,6 +38,7 @@ ) from ..constants import dist_group_type from ..float8_tensor import Float8Tensor +from ..utils import torch_get_autocast_gpu_dtype __all__ = ["initialize_ub", "destroy_ub"] @@ -653,7 +654,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() + self.activation_dtype = torch_get_autocast_gpu_dtype() return # All checks after this have already been performed once, thus skip diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 0c439ac417..2ddc09adbd 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -16,7 +16,7 @@ layernorm_fwd_inf, ) from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype __all__ = ["LayerNorm"] @@ -193,7 +193,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: # Note: This will soon be deprecated with # https://github.com/NVIDIA/TransformerEngine/pull/1033 if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() + self.activation_dtype = torch_get_autocast_gpu_dtype() elif self.activation_dtype != inp.dtype: dtype = inp.dtype for name, param in self.named_parameters(): diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index fc6ec5746f..6041be47eb 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -13,7 +13,7 @@ from .. 
import cpp_extensions as tex from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from ..utils import cast_if_needed, torch_get_autocast_gpu_dtype __all__ = ["RMSNorm"] @@ -190,7 +190,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: # Note: This will soon be deprecated with # https://github.com/NVIDIA/TransformerEngine/pull/1033 if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() + self.activation_dtype = torch_get_autocast_gpu_dtype() elif self.activation_dtype != inp.dtype: dtype = inp.dtype for name, param in self.named_parameters(): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ad5476450b..bf44c46658 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.utils import ( cast_if_needed, get_default_init_method, + torch_get_autocast_gpu_dtype, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -677,7 +678,7 @@ def forward( # For AMP if torch.is_autocast_enabled(): - hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) + hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype()) # Self attention. self_attention_outputs = self.self_attention( diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 947c642c2c..ac7819c9b0 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -305,3 +305,16 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: index2 = torch.cuda.current_device() return index1 == index2 return device1 == device2 + + +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) +if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: + + def torch_get_autocast_gpu_dtype() -> torch.dtype: + return torch.get_autocast_dtype("cuda") + +else: + + def torch_get_autocast_gpu_dtype() -> torch.dtype: + return torch.get_autocast_gpu_dtype() From cac145ad48cf660aa45ccec0565e5603c1e52c40 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 20 Oct 2024 20:00:59 -0700 Subject: [PATCH 2/7] add docstring Signed-off-by: Xin Yao --- transformer_engine/pytorch/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index ac7819c9b0..9d7695675e 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -312,9 +312,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" return torch.get_autocast_dtype("cuda") else: def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" return torch.get_autocast_gpu_dtype() From f22029400065c01f0bfd59d5f678442b98fa2067 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 20 Oct 2024 20:51:01 -0700 Subject: [PATCH 3/7] fix kwargs for torch.amp.autocast Signed-off-by: Xin Yao --- transformer_engine/pytorch/distributed.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e3af4d60f8..5e311e8082 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -251,21 +251,27 @@ def 
in_fp8_activation_recompute_phase() -> bool: def _get_active_autocast_contexts(): """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state - at the time of this function's execution. + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast + state at the time of this function's execution. """ autocast_cached = torch.is_autocast_cache_enabled() gpu_autocast_enabled = torch.is_autocast_enabled("cuda") gpu_autocast_dtype = torch.get_autocast_dtype("cuda") gpu_autocast_ctx = torch.amp.autocast( - "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + "cuda", + enabled=gpu_autocast_enabled, + dtype=gpu_autocast_dtype, + cache_enabled=autocast_cached, ) cpu_autocast_enabled = torch.is_autocast_enabled("cpu") cpu_autocast_dtype = torch.get_autocast_dtype("cpu") cpu_autocast_ctx = torch.amp.autocast( - "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + "cpu", + enabled=cpu_autocast_enabled, + dtype=cpu_autocast_dtype, + cache_enabled=autocast_cached, ) return gpu_autocast_ctx, cpu_autocast_ctx @@ -274,8 +280,8 @@ def _get_active_autocast_contexts(): def _get_active_autocast_contexts(): """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state - at the time of this function's execution. + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast + state at the time of this function's execution. """ autocast_cached = torch.is_autocast_cache_enabled() From d2f5be81d68e5d3b305fd41b6a9576c43f7deb3c Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 22 Oct 2024 20:39:14 -0700 Subject: [PATCH 4/7] check torch version inside functions Signed-off-by: Xin Yao --- transformer_engine/pytorch/distributed.py | 34 ++++++++--------------- transformer_engine/pytorch/utils.py | 17 ++++-------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 5e311e8082..b659048b2b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -245,17 +245,16 @@ def in_fp8_activation_recompute_phase() -> bool: return _FP8_ACTIVATION_RECOMPUTE_PHASE -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) -if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: - - def _get_active_autocast_contexts(): - """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast - state at the time of this function's execution. - """ - autocast_cached = torch.is_autocast_cache_enabled() +def _get_active_autocast_contexts(): + """ + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state + at the time of this function's execution. + """ + autocast_cached = torch.is_autocast_cache_enabled() + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: gpu_autocast_enabled = torch.is_autocast_enabled("cuda") gpu_autocast_dtype = torch.get_autocast_dtype("cuda") gpu_autocast_ctx = torch.amp.autocast( @@ -273,18 +272,7 @@ def _get_active_autocast_contexts(): dtype=cpu_autocast_dtype, cache_enabled=autocast_cached, ) - - return gpu_autocast_ctx, cpu_autocast_ctx - -else: - - def _get_active_autocast_contexts(): - """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast - state at the time of this function's execution. 
- """ - autocast_cached = torch.is_autocast_cache_enabled() - + else: gpu_autocast_enabled = torch.is_autocast_enabled() gpu_autocast_dtype = torch.get_autocast_gpu_dtype() gpu_autocast_ctx = torch.cuda.amp.autocast( @@ -297,7 +285,7 @@ def _get_active_autocast_contexts(): cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached ) - return gpu_autocast_ctx, cpu_autocast_ctx + return gpu_autocast_ctx, cpu_autocast_ctx class _CheckpointFunction(torch.autograd.Function): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9d7695675e..935838ad3a 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -307,16 +307,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: return device1 == device2 -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) -if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: - - def torch_get_autocast_gpu_dtype() -> torch.dtype: - """Get PyTorch autocast GPU dtype.""" +def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: return torch.get_autocast_dtype("cuda") - -else: - - def torch_get_autocast_gpu_dtype() -> torch.dtype: - """Get PyTorch autocast GPU dtype.""" + else: return torch.get_autocast_gpu_dtype() From 64a506ca01b67a4691ef2aa26b7a7ddaa3492a24 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 22 Oct 2024 20:44:07 -0700 Subject: [PATCH 5/7] check torch version inside functions Signed-off-by: Xin Yao --- transformer_engine/pytorch/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 935838ad3a..90a0edc2c0 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -313,5 +313,4 @@ def torch_get_autocast_gpu_dtype() -> torch.dtype: TORCH_MINOR = int(torch.__version__.split(".")[1]) if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: return torch.get_autocast_dtype("cuda") - else: - return torch.get_autocast_gpu_dtype() + return torch.get_autocast_gpu_dtype() From 30f8cae6a8fb048e48cb0212defdd208125db336 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 4 Nov 2024 18:16:57 -0800 Subject: [PATCH 6/7] use packaging.version Signed-off-by: Xin Yao --- tests/pytorch/test_fused_optimizer.py | 2 +- transformer_engine/pytorch/distributed.py | 6 ++---- transformer_engine/pytorch/jit.py | 22 ++++++++-------------- transformer_engine/pytorch/transformer.py | 5 ++--- transformer_engine/pytorch/utils.py | 20 +++++++++++++++++--- 5 files changed, 30 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index dccf81829e..9d9dbf78e8 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -15,7 +15,7 @@ from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.jit import gpu_autocast_ctx +from transformer_engine.pytorch.utils import gpu_autocast_ctx # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 
b659048b2b..c3cd3d2bcb 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -17,7 +17,7 @@ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules -from .utils import safely_set_viewless_tensor_data +from .utils import safely_set_viewless_tensor_data, is_torch_min_version from .constants import dist_group_type from .fp8 import FP8GlobalStateManager from .float8_tensor import Float8Tensor @@ -252,9 +252,7 @@ def _get_active_autocast_contexts(): """ autocast_cached = torch.is_autocast_cache_enabled() - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: + if is_torch_min_version("2.4.0a0"): gpu_autocast_enabled = torch.is_autocast_enabled("cuda") gpu_autocast_dtype = torch.get_autocast_dtype("cuda") gpu_autocast_ctx = torch.amp.autocast( diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 0a060e8305..8cbe2a1327 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -5,28 +5,29 @@ """NVFuser functions and JIT utilities""" import os from typing import Callable, Optional, Tuple -from functools import partial import torch +from .utils import is_torch_min_version, gpu_autocast_ctx + # pylint: disable=unnecessary-lambda-assignment jit_fuser = torch.jit.script -if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): +if is_torch_min_version("2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): jit_fuser = torch.compile # See: https://github.com/NVIDIA/TransformerEngine/issues/597 dropout_fuser = torch.jit.script -if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): +if is_torch_min_version("2.2a0") and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): dropout_fuser = torch.compile # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 no_torch_dynamo = lambda recursive=True: lambda func: func -if torch.__version__ >= "2": +if is_torch_min_version("2a0"): import torch._dynamo - if torch.__version__ >= "2.1": + if is_torch_min_version("2.1a0"): no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable( f, recursive=recursive ) @@ -34,20 +35,13 @@ # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True no_torch_dynamo = lambda recursive=True: torch._dynamo.disable -if torch.__version__ >= "2.4": - gpu_autocast_ctx = partial(torch.amp.autocast, device_type="cuda") -else: - gpu_autocast_ctx = torch.cuda.amp.autocast - def set_jit_fusion_options() -> None: """Set PyTorch JIT layer fusion options.""" # flags required to enable jit fusion kernels - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - if TORCH_MAJOR == 2 and TORCH_MINOR >= 2: + if is_torch_min_version("2.2.0a0"): pass - elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + elif is_torch_min_version("1.10.0a0"): # nvfuser torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index bf44c46658..210dd62931 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -27,6 +27,7 @@ cast_if_needed, get_default_init_method, torch_get_autocast_gpu_dtype, + 
is_torch_min_version, ) from transformer_engine.pytorch.constants import ( AttnMaskTypes, @@ -431,9 +432,7 @@ def __init__( self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None # Set bias+dropout+add fusion grad_enable execution handler. - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + use_nvfuser = is_torch_min_version("1.10.0a0") and not is_torch_min_version("2.2.0a0") self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad if self.bias_dropout_fusion: diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 90a0edc2c0..fe2f087d48 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -6,11 +6,21 @@ from __future__ import annotations import functools import math +from packaging.version import Version as PkgVersion from typing import Any, Callable, Optional, Tuple import torch import transformer_engine.pytorch.cpp_extensions as ext +_torch_version = PkgVersion(torch.__version__) + + +def is_torch_min_version(version, check_equality=True): + """Check if minimum version of `torch` is installed.""" + if check_equality: + return _torch_version >= PkgVersion(version) + return _torch_version > PkgVersion(version) + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" @@ -309,8 +319,12 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: def torch_get_autocast_gpu_dtype() -> torch.dtype: """Get PyTorch autocast GPU dtype.""" - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: + if is_torch_min_version("2.4.0a0"): return torch.get_autocast_dtype("cuda") return torch.get_autocast_gpu_dtype() + + +if is_torch_min_version("2.4.0a0"): + gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda") +else: + gpu_autocast_ctx = torch.cuda.amp.autocast From e2b1184bc011ce845610d66dd82d88d72bc8d2a6 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 4 Nov 2024 18:33:42 -0800 Subject: [PATCH 7/7] fix lint Signed-off-by: Xin Yao --- transformer_engine/pytorch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index fe2f087d48..10f96ffdc6 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -6,8 +6,8 @@ from __future__ import annotations import functools import math -from packaging.version import Version as PkgVersion from typing import Any, Callable, Optional, Tuple +from packaging.version import Version as PkgVersion import torch import transformer_engine.pytorch.cpp_extensions as ext
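
---

Note on the end state: patches 4-7 converge on a single runtime version check built on packaging.version, with call sites going through is_torch_min_version, torch_get_autocast_gpu_dtype, and gpu_autocast_ctx from transformer_engine/pytorch/utils.py instead of the deprecated torch.cuda.amp / torch.get_autocast_gpu_dtype APIs. Below is a minimal, self-contained sketch of that pattern plus a hypothetical call site; the Linear layer, tensor shapes, and the __main__ block are demo-only assumptions and not part of the patch.

    import functools

    import torch
    from packaging.version import Version as PkgVersion

    _torch_version = PkgVersion(torch.__version__)


    def is_torch_min_version(version, check_equality=True):
        """Check if the installed torch is at least `version`."""
        if check_equality:
            return _torch_version >= PkgVersion(version)
        return _torch_version > PkgVersion(version)


    def torch_get_autocast_gpu_dtype() -> torch.dtype:
        """Query the autocast GPU dtype via the non-deprecated API when available."""
        if is_torch_min_version("2.4.0a0"):
            return torch.get_autocast_dtype("cuda")
        return torch.get_autocast_gpu_dtype()  # deprecated since torch 2.4


    # Context-manager factory replacing direct torch.cuda.amp.autocast(...) calls,
    # which emit deprecation warnings on torch >= 2.4.
    if is_torch_min_version("2.4.0a0"):
        gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
    else:
        gpu_autocast_ctx = torch.cuda.amp.autocast


    if __name__ == "__main__":
        # Hypothetical usage mirroring the updated call sites (e.g. the fused-optimizer
        # tests and jit.py kernels); layer and shapes here are illustrative only.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        layer = torch.nn.Linear(8, 8).to(device)
        x = torch.randn(4, 8, device=device)
        with gpu_autocast_ctx(enabled=torch.cuda.is_available()):
            y = layer(x)
        print(y.dtype, torch_get_autocast_gpu_dtype())
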