diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index ac7819c9b0..9d7695675e 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -312,9 +312,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" return torch.get_autocast_dtype("cuda") else: def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" return torch.get_autocast_gpu_dtype()