Skip to content

Commit

Permalink
Blacken
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Nov 20, 2023
1 parent 5a13d78 commit e21cc94
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
27 changes: 21 additions & 6 deletions torchmdnet/models/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True


def vector_to_skewtensor(vector):
"""Creates a skew-symmetric tensor from a vector."""
batch_size = vector.size(0)
Expand All @@ -33,6 +34,7 @@ def vector_to_skewtensor(vector):
tensor = tensor.view(-1, 3, 3)
return tensor.squeeze(0)


def vector_to_symtensor(vector):
"""Creates a symmetric traceless tensor from the outer product of a vector with itself."""
tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
Expand All @@ -42,6 +44,7 @@ def vector_to_symtensor(vector):
S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
return S


def decompose_tensor(tensor):
"""Full tensor decomposition into irreducible components."""
I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
Expand All @@ -51,10 +54,12 @@ def decompose_tensor(tensor):
S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
return I, A, S


def tensor_norm(tensor):
"""Computes Frobenius norm."""
return (tensor**2).sum((-2, -1))


class TensorNet(nn.Module):
r"""TensorNet's architecture.
From TensorNet: Cartesian Tensor Representations for Efficient Learning of Molecular Potentials; G. Simeon and G. de Fabritiis.
Expand Down Expand Up @@ -237,7 +242,9 @@ def forward(
# WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
edge_index = edge_index.masked_fill(mask, z.shape[0])
edge_weight = edge_weight.masked_fill(mask[0], 0)
edge_vec = edge_vec.masked_fill(mask[0].unsqueeze(-1).expand_as(edge_vec), 0)
edge_vec = edge_vec.masked_fill(
mask[0].unsqueeze(-1).expand_as(edge_vec), 0
)
edge_attr = self.distance_expansion(edge_weight)
mask = edge_index[0] == edge_index[1]
# Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
Expand All @@ -261,6 +268,7 @@ class TensorEmbedding(nn.Module):
:meta private:
"""

def __init__(
self,
hidden_channels,
Expand Down Expand Up @@ -377,7 +385,9 @@ def forward(
return X


def tensor_message_passing(edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int) -> Tensor:
def tensor_message_passing(
edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int
) -> Tensor:
"""Message passing for tensors."""
msg = factor * tensor.index_select(0, edge_index[1])
shape = (natoms, tensor.shape[1], tensor.shape[2], tensor.shape[3])
Expand All @@ -391,6 +401,7 @@ class Interaction(nn.Module):
:meta private:
"""

def __init__(
self,
num_rbf,
Expand Down Expand Up @@ -432,9 +443,13 @@ def reset_parameters(self):
linear.reset_parameters()

def forward(
self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, q: Tensor
self,
X: Tensor,
edge_index: Tensor,
edge_weight: Tensor,
edge_attr: Tensor,
q: Tensor,
) -> Tensor:

C = self.cutoff(edge_weight)
for linear_scalar in self.linears_scalar:
edge_attr = self.act(linear_scalar(edge_attr))
Expand All @@ -460,7 +475,7 @@ def forward(
if self.equivariance_invariance_group == "O(3)":
A = torch.matmul(msg, Y)
B = torch.matmul(Y, msg)
I, A, S = decompose_tensor((1 + 0.1*q[...,None,None,None])*(A + B))
I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B))
if self.equivariance_invariance_group == "SO(3)":
B = torch.matmul(Y, msg)
I, A, S = decompose_tensor(2 * B)
Expand All @@ -470,5 +485,5 @@ def forward(
A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
dX = I + A + S
X = X + dX + (1 + 0.1*q[...,None,None,None]) * torch.matrix_power(dX, 2)
X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2)
return X
2 changes: 1 addition & 1 deletion torchmdnet/models/torchmd_et.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def forward(
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:

x = self.embedding(z)

edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
Expand Down Expand Up @@ -234,6 +233,7 @@ class EquivariantMultiHeadAttention(nn.Module):
:meta private:
"""

def __init__(
self,
hidden_channels,
Expand Down
3 changes: 3 additions & 0 deletions torchmdnet/models/torchmd_gn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
scatter,
)


class TorchMD_GN(nn.Module):
r"""The TorchMD Graph Network architecture.
Code adapted from https://github.com/rusty1s/pytorch_geometric/blob/d7d8e5e2edada182d820bbb1eec5f016f50db1e0/torch_geometric/nn/models/schnet.py#L38
Expand Down Expand Up @@ -224,6 +225,7 @@ class InteractionBlock(nn.Module):
:meta private:
"""

def __init__(
self,
hidden_channels,
Expand Down Expand Up @@ -284,6 +286,7 @@ class CFConv(nn.Module):
:meta private:
"""

def __init__(
self,
in_channels,
Expand Down

0 comments on commit e21cc94

Please sign in to comment.