Skip to content

Commit

Permalink
Store I as a single number
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed May 10, 2024
1 parent ada1192 commit 5212c8e
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions torchmdnet/models/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ def vector_to_symtensor(vector):
@nvtx_annotate("decompose_tensor")
def decompose_tensor(tensor):
"""Full tensor decomposition into irreducible components."""
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
..., None, None
] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
A = 0.5 * (tensor - tensor.transpose(-2, -1))
S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
S = tensor - A
S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
return I, A, S


Expand Down Expand Up @@ -260,7 +259,7 @@ def _compute_neighbors(
def output(self, X: Tensor) -> Tensor:
I, A, S = decompose_tensor(X) # shape: (n_atoms, hidden_channels, 3, 3)
x = torch.cat(
(tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1
(3 * I**2, tensor_norm(A), tensor_norm(S)), dim=-1
) # shape: (n_atoms, 3*hidden_channels)
x = self.out_norm(x) # shape: (n_atoms, 3*hidden_channels)
x = self.act(self.linear((x))) # shape: (n_atoms, hidden_channels)
Expand Down Expand Up @@ -322,10 +321,7 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
.unsqueeze(-1)
).expand(-1, -1, 3)
I, A, S = decompose_tensor(X)
I = (
self.linearI(I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
* factor[..., 0, None, None]
)
I = self.linearI(I) * factor[..., 0]
A = (
self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
* factor[..., 1, None, None]
Expand All @@ -334,7 +330,8 @@ def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
* factor[..., 2, None, None]
)
dX = I + A + S
dX = A + S
dX.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1))
return dX


Expand Down Expand Up @@ -490,10 +487,11 @@ def forward(
@nvtx_annotate("compute_tensor_edge_features")
def compute_tensor_edge_features(X, edge_index, factor):
I, A, S = decompose_tensor(X)
msg = (
factor[..., 0, None, None] * I.index_select(0, edge_index[1])
+ factor[..., 1, None, None] * A.index_select(0, edge_index[1])
+ factor[..., 2, None, None] * S.index_select(0, edge_index[1])
msg = factor[..., 1, None, None] * A.index_select(0, edge_index[1]) + factor[
..., 2, None, None
] * S.index_select(0, edge_index[1])
msg.diagonal(dim1=-2, dim2=-1).add_(
factor[..., 0, None] * I.index_select(0, edge_index[1]).unsqueeze(-1)
)
return msg

Expand Down

0 comments on commit 5212c8e

Please sign in to comment.