Skip to content

Commit

Permalink
Start refact Graph
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoOlivo committed Feb 3, 2025
1 parent c9ec3da commit acd6e45
Showing 1 changed file with 54 additions and 18 deletions.
72 changes: 54 additions & 18 deletions pina/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ class Graph:
"""

def __init__(self,
x=None,
pos=None,
edge_index=None,
x,
pos,
edge_index,
edge_attr=None,
build_edge_attr=False,
undirected=False,
additional_params=None,
**kwargs):
additional_params=None):
"""
Constructor for the Graph class.
:param x: The node features.
Expand Down Expand Up @@ -165,35 +164,72 @@ def _check_input_consistency(x, pos):
raise ValueError("x and pos must have the same length.")
return max(len(x), len(pos))


class RadiusGraph:
def __new__(cls,
x,
pos,
r,
build_edge_attr=False,
undirected=False,
additional_params=None, ):
if isinstance(pos, (torch.Tensor, LabelTensor)):
if isinstance(pos, LabelTensor):
pos = pos.tensor
return RadiusGraph._radius_graph(pos, r)
return [RadiusGraph._radius_graph(p, r) for p in pos]
return Graph(x=x, pos=pos, method='radius',
build_edge_attr=build_edge_attr,
undirected=undirected, additional_params=additional_params)

@staticmethod
def _knn_graph(points, k):
def _radius_graph(points, r):
"""
Implementation of the k-nearest neighbors graph construction.
Implementation of the radius graph construction.
:param points: The input points.
:type points: torch.Tensor
:param k: The number of nearest neighbors.
:type k: int
:param r: The radius.
:type r: float
:return: The edge index.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
row = torch.arange(points.size(0)).repeat_interleave(k)
col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0)
edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
return edge_index


class KNNGraph:
def __new__(cls,
x,
pos,
k,
build_edge_attr=False,
undirected=False,
additional_params=None,
):
if isinstance(pos, (torch.Tensor, LabelTensor)):
if isinstance(pos, LabelTensor):
pos = pos.tensor
return KNNGraph._knn_graph(pos, k)
return [KNNGraph._knn_graph(p, k) for p in pos]
return Graph(x=x, pos=pos, method='radius',
build_edge_attr=build_edge_attr,
undirected=undirected, additional_params=additional_params)

@staticmethod
def _radius_graph(points, r):
def _knn_graph(points, k):
"""
Implementation of the radius graph construction.
Implementation of the k-nearest neighbors graph construction.
:param points: The input points.
:type points: torch.Tensor
:param r: The radius.
:type r: float
:param k: The number of nearest neighbors.
:type k: int
:return: The edge index.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
knn_indices = torch.topk(dist, k=k + 1, largest=False).indices[:, 1:]
row = torch.arange(points.size(0)).repeat_interleave(k)
col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0)
return edge_index

0 comments on commit acd6e45

Please sign in to comment.