Skip to content

Commit

Permalink
Small changes to I
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed May 10, 2024
1 parent 5212c8e commit c1bd2e9
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions torchmdnet/models/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,18 @@ def vector_to_skewtensor(vector):
def vector_to_symtensor(vector):
"""Creates a symmetric traceless tensor from the outer product of a vector with itself."""
tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
..., None, None
] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
S = 0.5 * (tensor + tensor.transpose(-2, -1))
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
return S


@nvtx_annotate("decompose_tensor")
def decompose_tensor(tensor):
"""Full tensor decomposition into irreducible components."""
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
A = 0.5 * (tensor - tensor.transpose(-2, -1))
S = tensor - A
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)
S.diagonal(offset=0, dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
return I, A, S

Expand Down

0 comments on commit c1bd2e9

Please sign in to comment.