From 1594898488a2802a3ce57d228619ae98793b632b Mon Sep 17 00:00:00 2001 From: caic99 Date: Tue, 24 Dec 2024 02:36:31 +0000 Subject: [PATCH] Perf: allow tf32 datatype for matmul --- deepmd/pt/entrypoints/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 1e5314a821..ce7fc136bd 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -103,7 +103,8 @@ def get_trainer( finetune_links=None, ): multi_task = "model_dict" in config.get("model", {}) - + # https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere + torch.backends.cuda.matmul.allow_tf32 = True # Initialize DDP local_rank = os.environ.get("LOCAL_RANK") if local_rank is not None: