Skip to content

Commit

Permalink
Add docstring to extension
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Oct 30, 2023
1 parent a5f8a63 commit 65b825f
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions torchmdnet/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,48 @@ def is_current_stream_capturing():
return _is_current_stream_capturing()

get_neighbor_pairs_kernel = torch.ops.torchmdnet_extensions.get_neighbor_pairs
get_neighbor_pairs_kernel.__doc__ = """
Computes the neighbor pairs for a given set of atomic positions.
The list is generated as a list of pairs (i,j) without any enforced ordering.
The list is padded with -1 to the maximum number of pairs.
Parameters
----------
strategy : str
Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
positions : torch.Tensor
A tensor with shape (N, 3) representing the atomic positions.
batch : torch.Tensor
A tensor with shape (N,). Specifies the batch for each atom.
box_vectors : torch.Tensor
The vectors defining the periodic box with shape `(3, 3)`.
use_periodic : bool
Whether to apply periodic boundary conditions.
cutoff_lower : float
Lower cutoff for the neighbor list.
cutoff_upper : float
Upper cutoff for the neighbor list.
max_num_pairs : int
Maximum number of pairs to store.
loop : bool
Whether to include self-interactions.
include_transpose : bool
Whether to include the transpose of the neighbor list (pair i,j and pair j,i).
Returns
-------
neighbors : torch.Tensor
List of neighbors for each atom. Shape (2, max_num_pairs).
distances : torch.Tensor
List of distances for each atom. Shape (max_num_pairs,).
distance_vecs : torch.Tensor
List of distance vectors for each atom. Shape (max_num_pairs, 3).
num_pairs : torch.Tensor
The number of pairs found.
Notes
-----
This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
"""
# For some unknown reason torch.compile is not able to compile this function
if int(torch.__version__.split('.')[0]) >= 2:
import torch._dynamo as dynamo
Expand Down

0 comments on commit 65b825f

Please sign in to comment.