From 1420ea63a8e879dec870b90ecd83883d21c13879 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 20 Oct 2024 20:00:59 -0700 Subject: [PATCH] add docstring --- transformer_engine/pytorch/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index ac7819c9b0..9d7695675e 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -312,9 +312,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: if TORCH_MAJOR == 2 and TORCH_MINOR >= 4: def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" return torch.get_autocast_dtype("cuda") else: def torch_get_autocast_gpu_dtype() -> torch.dtype: + """Get PyTorch autocast GPU dtype.""" return torch.get_autocast_gpu_dtype()