Skip to content

Commit

Permalink
Refactor GNO model and enhance Graph class documentation and error ha…
Browse files Browse the repository at this point in the history
…ndling. Remove TemporalGraph class
  • Loading branch information
FilippoOlivo committed Feb 5, 2025
1 parent d79017e commit 6243895
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 65 deletions.
107 changes: 57 additions & 50 deletions pina/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,44 @@ def __init__(
additional_params=None
):
"""
Constructor for the Graph class.
:param x: The node features.
Constructor for the Graph class. This object creates a list of PyTorch Geometric Data objects.
Based on the input of x and pos there could be the following cases:
1. 1 pos, 1 x: a single graph will be created
2. N pos, 1 x: N graphs will be created with the same node features
3. 1 pos, N x: N graphs will be created with the same nodes but different nodes features
4. N pos, N x: N graph will be created
:param x: Node features. Can be a single 2D tensor of shape [num_nodes, num_node_features],
or a 3D tensor of shape [n_graphs, num_nodes, num_node_features]
or a list of such 2D tensors of shape [num_nodes ,num_node_features].
:type x: torch.Tensor or list[torch.Tensor]
:param pos: The node positions.
:param pos: Node coordinates. Can be a single 2D tensor of shape [num_nodes, num_coordinates],
or a 3D tensor of shape [n_graphs, num_nodes, num_coordinates]
or a list of such 2D tensors of shape [num_nodes ,num_coordinates].
:type pos: torch.Tensor or list[torch.Tensor]
:param edge_index: The edge index.
:param edge_index: The edge index defining connections between nodes.
It should be a 2D tensor of shape [2, num_edges]
or a 3D tensor of shape [n_graphs, 2, n_edges]
or a list of such tensors 2D tensors o shape [2, num_edges].
:type edge_index: torch.Tensor or list[torch.Tensor]
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor or list[torch.Tensor]
:param build_edge_attr: Whether to build the edge attributes.
:type build_edge_attr: bool
:param undirected: Whether to build an undirected graph.
:type undirected: bool
:param custom_build_edge_attr: Custom function to build the edge
attributes.
:type custom_build_edge_attr: function
:param additional_params: Additional parameters.
:type additional_params: dict
:param edge_attr: Edge features. If provided, should have the shape [num_edges, num_edge_features]
or be a list of such tensors for multiple graphs.
:type edge_attr: torch.Tensor or list[torch.Tensor], optional
:param build_edge_attr: Whether to compute edge attributes during initialization.
:type build_edge_attr: bool, default=False
:param undirected: If True, converts the graph(s) into an undirected graph by adding reciprocal edges.
:type undirected: bool, default=False
:param custom_build_edge_attr: A user-defined function to generate edge attributes dynamically.
The function should take (x, pos, edge_index) as input and return a tensor
of shape [num_edges, num_edge_features].
:type custom_build_edge_attr: function or callable, optional
:param additional_params: Dictionary containing extra attributes to be added to each Data object.
Keys represent attribute names, and values should be tensors or lists of tensors.
:type additional_params: dict, optional
Note: if
"""

self.data = []
x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)

Expand Down Expand Up @@ -85,7 +104,8 @@ def __init__(

# Build the edge attributes
edge_attr = self._check_and_build_edge_attr(edge_attr, build_edge_attr,
data_len, edge_index, pos, x)
data_len, edge_index, pos,
x)

# Perform the graph construction
self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
Expand Down Expand Up @@ -128,14 +148,32 @@ def _check_input_consistency(x, pos, edge_index=None):
# If x is a 3D tensor, we split it into a list of 2D tensors
if isinstance(x, torch.Tensor) and x.ndim == 3:
x = [x[i] for i in range(x.shape[0])]
elif (not (isinstance(x, list) and all(t.ndim == 2 for t in x)) and
not (isinstance(x, torch.Tensor) and x.ndim == 2)):
raise TypeError("x must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")

# If pos is a 3D tensor, we split it into a list of 2D tensors
if isinstance(pos, torch.Tensor) and pos.ndim == 3:
pos = [pos[i] for i in range(pos.shape[0])]
elif not (isinstance(pos, list) and all(
t.ndim == 2 for t in pos)) and not (
isinstance(pos, torch.Tensor) and pos.ndim == 2):
raise TypeError("pos must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")

# If edge_index is a 3D tensor, we split it into a list of 2D tensors
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
if edge_index is not None:
if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
elif not (isinstance(edge_index, list) and all(
t.ndim == 2 for t in edge_index)) and not (
isinstance(edge_index,
torch.Tensor) and edge_index.ndim == 2):
raise TypeError(
"edge_index must be either a list of 2D tensors or a 2D "
"tensor or a 3D tensor")

return x, pos, edge_index

@staticmethod
Expand Down Expand Up @@ -185,7 +223,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
return [edge_attr] * data_len

if build_edge_attr:
return [self._build_edge_attr(x,pos_, edge_index_) for
return [self._build_edge_attr(x, pos_, edge_index_) for
pos_, edge_index_ in zip(pos, edge_index)]


Expand Down Expand Up @@ -256,34 +294,3 @@ def _knn_graph(points, k):
col = knn_indices.flatten()
edge_index = torch.stack([row, col], dim=0)
return edge_index


class TemporalGraph(Graph):
def __init__(
self,
x,
pos,
t,
edge_index=None,
edge_attr=None,
build_edge_attr=False,
undirected=False,
r=None
):

x, pos, edge_index = self._check_input_consistency(x, pos, edge_index)
print(len(pos))
if edge_index is None:
edge_index = [RadiusGraph._radius_graph(p, r) for p in pos]
additional_params = {'t': t}
self._check_time_consistency(pos, t)
super().__init__(x=x, pos=pos, edge_index=edge_index,
edge_attr=edge_attr,
build_edge_attr=build_edge_attr,
undirected=undirected,
additional_params=additional_params)

@staticmethod
def _check_time_consistency(pos, times):
if len(pos) != len(times):
raise ValueError("pos and times must have the same length.")
16 changes: 1 addition & 15 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import torch
from pina import Graph
from pina.graph import RadiusGraph, KNNGraph, TemporalGraph
from pina.graph import RadiusGraph, KNNGraph


@pytest.mark.parametrize(
Expand Down Expand Up @@ -146,19 +145,6 @@ def test_additional_parameters_2(additional_parameters):
assert all(hasattr(d, 'y') for d in data)
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))


def test_temporal_graph():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
t = torch.rand(3)
graph = TemporalGraph(x=x, pos=pos, build_edge_attr=True, r=.3, t=t)
assert len(graph.data) == 3
data = graph.data
assert all(torch.isclose(d_.x, x_).all() for (d_, x_) in zip(data, x))
assert all(hasattr(d, 't') for d in data)
assert all(d_.t == t_ for (d_, t_) in zip(data, t))


def test_custom_build_edge_attr_func():
x = torch.rand(3, 10, 2)
pos = torch.rand(3, 10, 2)
Expand Down

0 comments on commit 6243895

Please sign in to comment.