Skip to content

Commit

Permalink
Add docstring for extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Oct 30, 2023
1 parent 65b825f commit 9b46e47
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions torchmdnet/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,58 @@
import os.path as osp
import torch
import importlib.machinery
from typing import Tuple


def _load_library(library):
""" Load a dynamic library containing torch extensions from the given path.
"""Load a dynamic library containing torch extensions from the given path.
Args:
library (str): The name of the library to load.
"""
# Find the specification for the library
spec = importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]
)
spec = importlib.machinery.PathFinder().find_spec(library, [osp.dirname(__file__)])
# Check if the specification is found and load the library
if spec is not None:
torch.ops.load_library(spec.origin)
else:
raise ImportError(f"Could not find module '{library}' in {osp.dirname(__file__)}")
raise ImportError(
f"Could not find module '{library}' in {osp.dirname(__file__)}"
)


_load_library("torchmdnet_extensions")

@torch.jit.script

def is_current_stream_capturing():
"""Returns True if the current CUDA stream is capturing.
Returns False if CUDA is not available or the current stream is not capturing.
This utility is required because the builtin torch function that does this is not scriptable.
"""
_is_current_stream_capturing = torch.ops.torchmdnet_extensions.is_current_stream_capturing
_is_current_stream_capturing = (
torch.ops.torchmdnet_extensions.is_current_stream_capturing
)
return _is_current_stream_capturing()

get_neighbor_pairs_kernel = torch.ops.torchmdnet_extensions.get_neighbor_pairs
get_neighbor_pairs_kernel.__doc__ = """
Computes the neighbor pairs for a given set of atomic positions.

def get_neighbor_pairs_kernel(
strategy: str,
positions: torch.Tensor,
batch: torch.Tensor,
box_vectors: torch.Tensor,
use_periodic: bool,
cutoff_lower: float,
cutoff_upper: float,
max_num_pairs: int,
loop: bool,
include_transpose: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes the neighbor pairs for a given set of atomic positions.
The list is generated as a list of pairs (i,j) without any enforced ordering.
The list is padded with -1 to the maximum number of pairs.
Parameters
----------
strategy : str
Expand Down Expand Up @@ -76,7 +94,22 @@ def is_current_stream_capturing():
-----
This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
"""
return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
strategy,
positions,
batch,
box_vectors,
use_periodic,
cutoff_lower,
cutoff_upper,
max_num_pairs,
loop,
include_transpose,
)


# For some unknown reason torch.compile is not able to compile this function
if int(torch.__version__.split('.')[0]) >= 2:
if int(torch.__version__.split(".")[0]) >= 2:
import torch._dynamo as dynamo

dynamo.disallow_in_graph(get_neighbor_pairs_kernel)

0 comments on commit 9b46e47

Please sign in to comment.