diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py index 464eaade7..68398fbd5 100644 --- a/torchmdnet/extensions/__init__.py +++ b/torchmdnet/extensions/__init__.py @@ -34,7 +34,48 @@ def is_current_stream_capturing(): return _is_current_stream_capturing() get_neighbor_pairs_kernel = torch.ops.torchmdnet_extensions.get_neighbor_pairs +get_neighbor_pairs_kernel.__doc__ = """ + Computes the neighbor pairs for a given set of atomic positions. + The list is generated as a list of pairs (i,j) without any enforced ordering. + The list is padded with -1 to the maximum number of pairs. + Parameters + ---------- + strategy : str + Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`. + positions : torch.Tensor + A tensor with shape (N, 3) representing the atomic positions. + batch : torch.Tensor + A tensor with shape (N,). Specifies the batch for each atom. + box_vectors : torch.Tensor + The vectors defining the periodic box with shape `(3, 3)`. + use_periodic : bool + Whether to apply periodic boundary conditions. + cutoff_lower : float + Lower cutoff for the neighbor list. + cutoff_upper : float + Upper cutoff for the neighbor list. + max_num_pairs : int + Maximum number of pairs to store. + loop : bool + Whether to include self-interactions. + include_transpose : bool + Whether to include the transpose of the neighbor list (pair i,j and pair j,i). + Returns + ------- + neighbors : torch.Tensor + List of neighbors for each atom. Shape (2, max_num_pairs). + distances : torch.Tensor + List of distances for each atom. Shape (max_num_pairs,). + distance_vecs : torch.Tensor + List of distance vectors for each atom. Shape (max_num_pairs, 3). + num_pairs : torch.Tensor + The number of pairs found. + + Notes + ----- + This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`. + """ # For some unknown reason torch.compile is not able to compile this function if int(torch.__version__.split('.')[0]) >= 2: import torch._dynamo as dynamo