Refactor GNO
FilippoOlivo committed Feb 4, 2025
1 parent 8eb178f commit 2b4c0f2
Showing 3 changed files with 78 additions and 63 deletions.
107 changes: 63 additions & 44 deletions pina/model/gno.py
@@ -1,52 +1,71 @@
 import torch
 from torch.nn import Tanh
-from .layers import GraphIntegralKernel
-from .feed_forward import FeedForward
-
-class GNO(torch.nn.Module):
-    def __init__(self,
-                 lifting_operator,
-                 projection_operator,
-                 edge_features,
-                 n_layers=1,
-                 kernel_n_layers=0,
-                 kernel_inner_size=None,
-                 kernel_layers=None,
-                 common=False
-                 ):
-        super(GNO, self).__init__()
-        self.lifting_operator = lifting_operator
-        self.projection_operator = projection_operator
-        self.hidden_dim = lifting_operator.out_features
-        self.tanh = Tanh()
-        if common:
-            dense = FeedForward(input_dimensions=edge_features,
-                                output_dimensions=self.hidden_dim ** 2,
-                                n_layers=kernel_n_layers,
-                                inner_size=kernel_inner_size,
-                                layers=kernel_layers)
-            W = FeedForward(input_dimensions=self.hidden_dim,
-                            output_dimensions=self.hidden_dim,
-                            n_layers=1)
-            self.kernels = torch.nn.ModuleList(
-                [GraphIntegralKernel(width=lifting_operator.out_features,
-                                     kernel_width=edge_features,
-                                     W=W,
-                                     dense=dense) for _ in range(n_layers)])
-        else:
-            self.kernels = torch.nn.ModuleList(
-                [GraphIntegralKernel(width=lifting_operator.out_features,
-                                     kernel_width=edge_features,
-                                     n_layers=kernel_n_layers,
-                                     inner_size=kernel_inner_size,
-                                     layers=kernel_layers) for _ in
-                 range(n_layers)])
+from .layers import GraphIntegralLayer
+from .base_no import KernelNeuralOperator
+
+
+class GraphNeuralKernel(torch.nn.Module):
+    def __init__(
+        self,
+        width,
+        edge_features,
+        n_layers=2,
+        internal_n_layers=0,
+        inner_size=None,
+        internal_layers=None,
+        func=None
+    ):
+        super().__init__()
+        if func is None:
+            func = Tanh
+
+        self.layers = torch.nn.ModuleList(
+            [GraphIntegralLayer(width=width,
+                                edges_features=edge_features,
+                                n_layers=internal_n_layers,
+                                inner_size=inner_size,
+                                layers=internal_layers,
+                                func=func) for _ in range(n_layers)]
+        )
+
+    def forward(self, x, edge_index, edge_attr):
+        for layer in self.layers:
+            x = layer(x, edge_index, edge_attr)
+        return x
+
+
+class GNO(KernelNeuralOperator):
+    def __init__(
+        self,
+        lifting_operator,
+        projection_operator,
+        edge_features,
+        n_layers=10,
+        internal_n_layers=0,
+        inner_size=None,
+        internal_layers=None,
+        func=None
+    ):
+        if func is None:
+            func = Tanh
+
+        super().__init__(
+            lifting_operator=lifting_operator,
+            integral_kernels=GraphNeuralKernel(
+                width=lifting_operator.out_features,
+                edge_features=edge_features,
+                internal_n_layers=internal_n_layers,
+                inner_size=inner_size,
+                internal_layers=internal_layers,
+                func=func,
+                n_layers=n_layers,
+            ),
+            projection_operator=projection_operator
+        )
 
     def forward(self, batch):
         x, edge_index, edge_attr = batch.x, batch.edge_index, batch.edge_attr
         x = self.lifting_operator(x)
-        for kernel in self.kernels:
-            x = kernel(x, edge_index, edge_attr)
-            x = self.tanh(x)
+        x = self.integral_kernels(x, edge_index, edge_attr)
         x = self.projection_operator(x)
         return x
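The refactor replaces GNO's hand-rolled kernel list with a reusable GraphNeuralKernel wrapped by the KernelNeuralOperator base class. For reference, a minimal usage sketch, assuming pina.model exports GNO and that any module exposing out_features (e.g. torch.nn.Linear) can serve as lifting operator; the import path, toy graph, and all sizes below are illustrative assumptions, not part of this commit:

import torch
from torch_geometric.data import Data
from pina.model import GNO   # assumed export path

# Lift 3 node features to hidden width 16, project back to 1 output.
lifting = torch.nn.Linear(3, 16)
projection = torch.nn.Linear(16, 1)
model = GNO(lifting_operator=lifting,
            projection_operator=projection,
            edge_features=4,   # dimension of edge_attr
            n_layers=2)        # number of GraphIntegralLayer blocks

# Toy graph: 5 nodes, 4 directed edges, 4 features per edge.
batch = Data(x=torch.rand(5, 3),
             edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
             edge_attr=torch.rand(4, 4))
out = model(batch)   # expected shape: (5, 1)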
4 changes: 2 additions & 2 deletions pina/model/layers/__init__.py
@@ -15,7 +15,7 @@
     "AVNOBlock",
     "LowRankBlock",
     "RBFBlock",
-    "GraphIntegralKernel"
+    "GraphIntegralLayer"
 ]
 
 from .convolution_2d import ContinuousConvBlock
@@ -32,4 +32,4 @@
 from .avno_layer import AVNOBlock
 from .lowrank_layer import LowRankBlock
 from .rbf_layer import RBFBlock
-from .graph_integral_kernel import GraphIntegralKernel
+from .graph_integral_kernel import GraphIntegralLayer
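After the rename, downstream code imports the layer under its new name, e.g.:

from pina.model.layers import GraphIntegralLayer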
30 changes: 13 additions & 17 deletions pina/model/layers/graph_integral_kernel.py
@@ -2,32 +2,26 @@
 from torch_geometric.nn import MessagePassing
 
 
-class GraphIntegralKernel(MessagePassing):
+class GraphIntegralLayer(MessagePassing):
     def __init__(self,
                  width,
-                 kernel_width,
+                 edges_features,
                  n_layers=0,
                  inner_size=None,
                  layers=None,
-                 W=None,
-                 dense=None,
+                 func=None
                  ):
-        super(GraphIntegralKernel, self).__init__(aggr='mean')
-        from ..feed_forward import FeedForward
+        from pina.model import FeedForward
+        super(GraphIntegralLayer, self).__init__(aggr='mean')
         self.width = width
-        if dense is None:
-            self.dense = FeedForward(input_dimensions=kernel_width,
+        self.dense = FeedForward(input_dimensions=edges_features,
                                  output_dimensions=width ** 2,
                                  n_layers=n_layers,
                                  inner_size=inner_size,
-                                 layers=layers)
-        else:
-            self.dense = dense
-        if W is None:
-            self.W = FeedForward(input_dimensions=width, output_dimensions=width,
-                                 n_layers=1)
-        else:
-            self.W = W
+                                 layers=layers,
+                                 func=func)
+        self.W = torch.nn.Linear(width, width)
+        self.func = func()
 
     def message(self, x_j, edge_attr):
         x = self.dense(edge_attr).view(-1, self.width, self.width)
@@ -38,4 +32,6 @@ def update(self, aggr_out, x):
         return aggr_out
 
     def forward(self, x, edge_index, edge_attr):
-        return self.propagate(edge_index, x=x, edge_attr=edge_attr)
+        return self.func(
+            self.propagate(edge_index, x=x, edge_attr=edge_attr)
+        )
