From 9b46e47b8fec8cb57006d7eb7741aa59097b97d5 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 30 Oct 2023 16:07:20 +0100 Subject: [PATCH] Add docstring for extensions --- torchmdnet/extensions/__init__.py | 55 ++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py index 68398fbd5..2ba936c1b 100644 --- a/torchmdnet/extensions/__init__.py +++ b/torchmdnet/extensions/__init__.py @@ -4,25 +4,28 @@ import os.path as osp import torch import importlib.machinery +from typing import Tuple + def _load_library(library): - """ Load a dynamic library containing torch extensions from the given path. + """Load a dynamic library containing torch extensions from the given path. Args: library (str): The name of the library to load. """ # Find the specification for the library - spec = importlib.machinery.PathFinder().find_spec( - library, [osp.dirname(__file__)] - ) + spec = importlib.machinery.PathFinder().find_spec(library, [osp.dirname(__file__)]) # Check if the specification is found and load the library if spec is not None: torch.ops.load_library(spec.origin) else: - raise ImportError(f"Could not find module '{library}' in {osp.dirname(__file__)}") + raise ImportError( + f"Could not find module '{library}' in {osp.dirname(__file__)}" + ) + _load_library("torchmdnet_extensions") -@torch.jit.script + def is_current_stream_capturing(): """Returns True if the current CUDA stream is capturing. @@ -30,14 +33,29 @@ def is_current_stream_capturing(): This utility is required because the builtin torch function that does this is not scriptable. """ - _is_current_stream_capturing = torch.ops.torchmdnet_extensions.is_current_stream_capturing + _is_current_stream_capturing = ( + torch.ops.torchmdnet_extensions.is_current_stream_capturing + ) return _is_current_stream_capturing() -get_neighbor_pairs_kernel = torch.ops.torchmdnet_extensions.get_neighbor_pairs -get_neighbor_pairs_kernel.__doc__ = """ - Computes the neighbor pairs for a given set of atomic positions. + +def get_neighbor_pairs_kernel( + strategy: str, + positions: torch.Tensor, + batch: torch.Tensor, + box_vectors: torch.Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes the neighbor pairs for a given set of atomic positions. + The list is generated as a list of pairs (i,j) without any enforced ordering. The list is padded with -1 to the maximum number of pairs. + Parameters ---------- strategy : str @@ -76,7 +94,22 @@ def is_current_stream_capturing(): ----- This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`. """ + return torch.ops.torchmdnet_extensions.get_neighbor_pairs( + strategy, + positions, + batch, + box_vectors, + use_periodic, + cutoff_lower, + cutoff_upper, + max_num_pairs, + loop, + include_transpose, + ) + + # For some unknown reason torch.compile is not able to compile this function -if int(torch.__version__.split('.')[0]) >= 2: +if int(torch.__version__.split(".")[0]) >= 2: import torch._dynamo as dynamo + dynamo.disallow_in_graph(get_neighbor_pairs_kernel)