Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Yao <[email protected]>
  • Loading branch information
yaox12 committed Oct 21, 2024
1 parent 3fe0c0f commit 0987833
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:

def torch_get_autocast_gpu_dtype() -> torch.dtype:
"""Get PyTorch autocast GPU dtype."""
return torch.get_autocast_dtype("cuda")

else:

def torch_get_autocast_gpu_dtype() -> torch.dtype:
"""Get PyTorch autocast GPU dtype."""
return torch.get_autocast_gpu_dtype()

0 comments on commit 0987833

Please sign in to comment.