diff --git a/pina/graph.py b/pina/graph.py index d856ad03..44097431 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -25,25 +25,44 @@ def __init__( additional_params=None ): """ - Constructor for the Graph class. - :param x: The node features. + Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects. + Based on the input of x and pos there could be the following cases: + 1. 1 pos, 1 x: a single graph will be created + 2. N pos, 1 x: N graphs will be created with the same node features + 3. 1 pos, N x: N graphs will be created with the same nodes but different nodes features + 4. N pos, N x: N graph will be created + + :param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features], + or a 3D tensor of shape [n_graphs, num_nodes, num_node_features] + or a list of such 2D tensors of shape [num_nodes ,num_node_features]. :type x: torch.Tensor or list[torch.Tensor] - :param pos: The node positions. + :param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates], + or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates] + or a list of such 2D tensors of shape [num_nodes ,num_coordinates]. :type pos: torch.Tensor or list[torch.Tensor] - :param edge_index: The edge index. + :param edge_index: The edge index defining connections between nodes. + It should be a 2D tensor of shape [2, num_edges] + or a 3D tensor of shape [n_graphs, 2, n_edges] + or a list of such tensors 2D tensors o shape [2, num_edges]. :type edge_index: torch.Tensor or list[torch.Tensor] - :param edge_attr: The edge attributes. - :type edge_attr: torch.Tensor or list[torch.Tensor] - :param build_edge_attr: Whether to build the edge attributes. - :type build_edge_attr: bool - :param undirected: Whether to build an undirected graph. - :type undirected: bool - :param custom_build_edge_attr: Custom function to build the edge - attributes. - :type custom_build_edge_attr: function - :param additional_params: Additional parameters. - :type additional_params: dict + :param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features] + or be a list of such tensors for multiple graphs. + :type edge_attr: torch.Tensor or list[torch.Tensor], optional + :param build_edge_attr: Whether to compute edge attributes during initialization. + :type build_edge_attr: bool, default=False + :param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges. + :type undirected: bool, default=False + :param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically. + The function should take (x, pos, edge_index) as input and return a tensor + of shape [num_edges, num_edge_features]. + :type custom_build_edge_attr: function or callable, optional + :param additional_params: Dictionary containing extra attributes to be added to each Data object. + Keys represent attribute names, and values should be tensors or lists of tensors. + :type additional_params: dict, optional + + Note: if """ + self.data = [] x, pos, edge_index = self._check_input_consistency(x, pos, edge_index) @@ -85,7 +104,8 @@ def __init__( # Build the edge attributes edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr, - data_len, edge_index, pos, x) + data_len, edge_index, pos, + x) # Perform the graph construction self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) @@ -128,14 +148,32 @@ def _check_input_consistency(x, pos, edge_index=None): # If x is a 3D tensor, we split it into a list of 2D tensors if isinstance(x, torch.Tensor) and x.ndim == 3: x = [x[i] for i in range(x.shape[0])] + elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and + not (isinstance(x, torch.Tensor) and x.ndim == 2)): + raise TypeError("x must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor") # If pos is a 3D tensor, we split it into a list of 2D tensors if isinstance(pos, torch.Tensor) and pos.ndim == 3: pos = [pos[i] for i in range(pos.shape[0])] + elif not (isinstance(pos, list) and all( + t.ndim == 2 for t in pos)) and not ( + isinstance(pos, torch.Tensor) and pos.ndim == 2): + raise TypeError("pos must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor") # If edge_index is a 3D tensor, we split it into a list of 2D tensors - if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: - edge_index = [edge_index[i] for i in range(edge_index.shape[0])] + if edge_index is not None: + if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: + edge_index = [edge_index[i] for i in range(edge_index.shape[0])] + elif not (isinstance(edge_index, list) and all( + t.ndim == 2 for t in edge_index)) and not ( + isinstance(edge_index, + torch.Tensor) and edge_index.ndim == 2): + raise TypeError( + "edge_index must be either a list of 2D tensors or a 2D " + "tensor or a 3D tensor") + return x, pos, edge_index @staticmethod @@ -185,7 +223,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len, return [edge_attr] * data_len if build_edge_attr: - return [self._build_edge_attr(x,pos_, edge_index_) for + return [self._build_edge_attr(x, pos_, edge_index_) for pos_, edge_index_ in zip(pos, edge_index)] @@ -256,34 +294,3 @@ def _knn_graph(points, k): col = knn_indices.flatten() edge_index = torch.stack([row, col], dim=0) return edge_index - - -class TemporalGraph(Graph): - def __init__( - self, - x, - pos, - t, - edge_index=None, - edge_attr=None, - build_edge_attr=False, - undirected=False, - r=None - ): - - x, pos, edge_index = self._check_input_consistency(x, pos, edge_index) - print(len(pos)) - if edge_index is None: - edge_index = [RadiusGraph._radius_graph(p, r) for p in pos] - additional_params = {'t': t} - self._check_time_consistency(pos, t) - super().__init__(x=x, pos=pos, edge_index=edge_index, - edge_attr=edge_attr, - build_edge_attr=build_edge_attr, - undirected=undirected, - additional_params=additional_params) - - @staticmethod - def _check_time_consistency(pos, times): - if len(pos) != len(times): - raise ValueError("pos and times must have the same length.") diff --git a/tests/test_graph.py b/tests/test_graph.py index 5521be00..660ec342 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,6 @@ import pytest import torch -from pina import Graph -from pina.graph import RadiusGraph, KNNGraph, TemporalGraph +from pina.graph import RadiusGraph, KNNGraph @pytest.mark.parametrize( @@ -146,19 +145,6 @@ def test_additional_parameters_2(additional_parameters): assert all(hasattr(d, 'y') for d in data) assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - -def test_temporal_graph(): - x = torch.rand(3, 10, 2) - pos = torch.rand(3, 10, 2) - t = torch.rand(3) - graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t) - assert len(graph.data) == 3 - data = graph.data - assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x)) - assert all(hasattr(d, 't') for d in data) - assert all(d_.t == t_ for (d_, t_) in zip(data, t)) - - def test_custom_build_edge_attr_func(): x = torch.rand(3, 10, 2) pos = torch.rand(3, 10, 2)