From e66b3b1d92166b2c27d7d2b05f1c7f10fd61fb35 Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Thu, 2 Nov 2023 12:51:50 +0100
Subject: [PATCH 1/9] Update tensornet.py for support of total charge q

---
 torchmdnet/models/tensornet.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 64f9b9d36..74832408d 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -202,7 +202,7 @@ def forward(
         z: Tensor,
         pos: Tensor,
         batch: Tensor,
-        q: Optional[Tensor] = None,
+        q: Optional[Tensor],
         s: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         # Obtain graph, with distances and relative position vectors
@@ -227,8 +227,12 @@ def forward(
         # I avoid dividing by zero by setting the weight of self edges and self loops to 1
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+        if q is None:
+            q = 0
+        else:
+            q = q[...,None,None,None]
         for layer in self.layers:
-            X = layer(X, edge_index, edge_weight, edge_attr)
+            X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
@@ -379,7 +383,7 @@ def reset_parameters(self):
             linear.reset_parameters()
 
     def forward(
-        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor
+        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, q: Tensor
     ) -> Tensor:
 
         C = self.cutoff(edge_weight)
@@ -401,7 +405,7 @@ def forward(
         if self.equivariance_invariance_group == "O(3)":
             A = torch.matmul(msg, Y)
             B = torch.matmul(Y, msg)
-            I, A, S = decompose_tensor(A + B)
+            I, A, S = decompose_tensor((1 + 0.1*q)*(A + B))
         if self.equivariance_invariance_group == "SO(3)":
             B = torch.matmul(Y, msg)
             I, A, S = decompose_tensor(2 * B)

From f832d2b71e6947ffae089cd908bb6a2dd0ec964e Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:11:33 +0100
Subject: [PATCH 2/9] Update tensornet.py for q support

---
 torchmdnet/models/tensornet.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 74832408d..36ca8fa20 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -217,6 +217,12 @@ def forward(
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
             # I trick the model into thinking that the masked edges pertain to the extra atom
+            if q is None:
+                q = 0
+            else:
+                batchp = torch.cat((batch, batch[-1] + 1), dim=0)
+                qp = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
+                q = qp[batchp][...,None,None,None]
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
             edge_weight = edge_weight.masked_fill(mask[0], 0)
@@ -230,7 +236,7 @@ def forward(
         if q is None:
             q = 0
         else:
-            q = q[...,None,None,None]
+            q = q[batch][...,None,None,None]
         for layer in self.layers:
             X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
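The core device in these first two patches is the `batch` vector: indexing a molecule-wise quantity with it yields an atom-wise one. A minimal standalone sketch of that broadcasting, with illustrative shapes and values that are not part of the patch:

import torch

# One total charge per molecule, and the molecule index of each atom.
q = torch.tensor([0.0, 1.0, -1.0])         # 3 molecules in the batch
batch = torch.tensor([0, 0, 0, 1, 1, 2])   # 6 atoms

# Molecule-wise -> atom-wise: every atom inherits its molecule's total charge.
q_atom = q[batch]                           # tensor([0., 0., 0., 1., 1., -1.])

# To broadcast against per-atom rank-2 tensor features of shape
# (num_atoms, hidden_channels, 3, 3), trailing singleton dims are added:
q_bcast = q_atom[..., None, None, None]     # shape (6, 1, 1, 1)

The `batchp`/`qp` variant in patch 2 does the same thing, but first appends one extra molecule index and a zero charge so the ghost atom used for static shapes also gets a charge entry.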
From 20e0bd3d5921fd2c4b297c48940c801bc3c31b84 Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:12:30 +0100
Subject: [PATCH 3/9] fix

---
 torchmdnet/models/tensornet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 36ca8fa20..0fd0f0189 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -202,7 +202,7 @@ def forward(
         z: Tensor,
         pos: Tensor,
         batch: Tensor,
-        q: Optional[Tensor],
+        q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         # Obtain graph, with distances and relative position vectors

From 88f22f3e42bfc86ab65c6f724f27cb6f890c91b4 Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:22:03 +0100
Subject: [PATCH 4/9] fix

---
 torchmdnet/models/tensornet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 0fd0f0189..c9a1d85d3 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -421,5 +421,5 @@ def forward(
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        X = X + dX + torch.matrix_power(dX, 2)
+        X = X + dX + (1 + 0.1*q) * torch.matrix_power(dX, 2)
         return X

From fbb94e8439f8aff0bcc1244b5b954ee2abd28d97 Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Fri, 3 Nov 2023 08:49:44 +0100
Subject: [PATCH 5/9] fix comment

---
 torchmdnet/models/tensornet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index c9a1d85d3..f4fa35354 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -216,13 +216,13 @@ def forward(
         if self.static_shapes:
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
-            # I trick the model into thinking that the masked edges pertain to the extra atom
             if q is None:
                 q = 0
             else:
                 batchp = torch.cat((batch, batch[-1] + 1), dim=0)
                 qp = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
                 q = qp[batchp][...,None,None,None]
+            # I trick the model into thinking that the masked edges pertain to the extra atom
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
             edge_weight = edge_weight.masked_fill(mask[0], 0)
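Patches 3-5 are small fixes, but they leave in place the mechanism introduced in patches 1 and 4: the charge enters the interaction layer as the per-atom scalar gate `(1 + 0.1*q)` on tensor products. Since a per-atom scalar commutes with rotation matrices, the gated product transforms exactly like the ungated one, so O(3) equivariance is preserved. A sketch under assumed shapes; `msg` and `Y` here are stand-ins, not the actual layer internals:

import torch

num_atoms, channels = 4, 8
q = torch.tensor([1.0, 1.0, -1.0, 0.0])[..., None, None, None]  # (4, 1, 1, 1)
msg = torch.randn(num_atoms, channels, 3, 3)
Y = torch.randn(num_atoms, channels, 3, 3)

# Batched matrix products over the trailing (3, 3) blocks, as in the O(3) branch.
A = torch.matmul(msg, Y)
B = torch.matmul(Y, msg)

# The per-atom scalar gate injects total-charge information without
# changing how the result transforms under rotations.
out = (1 + 0.1 * q) * (A + B)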
From fa26101b158dc96f714595007e531c8dfe17e0a6 Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:01:01 +0100
Subject: [PATCH 6/9] initialize zero charge as tensor, move charge
 broadcasting to interaction module

---
 torchmdnet/models/tensornet.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index f4fa35354..0c7d4c918 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -217,11 +217,11 @@ def forward(
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
             if q is None:
-                q = 0
+                q = torch.zeros_like(zp, device=z.device, dtype=z.dtype)
             else:
                 batchp = torch.cat((batch, batch[-1] + 1), dim=0)
                 qp = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
-                q = qp[batchp][...,None,None,None]
+                q = qp[batchp]
             # I trick the model into thinking that the masked edges pertain to the extra atom
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
@@ -234,9 +234,9 @@ def forward(
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
         if q is None:
-            q = 0
+            q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
         else:
-            q = q[batch][...,None,None,None]
+            q = q[batch]
         for layer in self.layers:
             X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
@@ -411,7 +411,7 @@ def forward(
         if self.equivariance_invariance_group == "O(3)":
             A = torch.matmul(msg, Y)
             B = torch.matmul(Y, msg)
-            I, A, S = decompose_tensor((1 + 0.1*q)*(A + B))
+            I, A, S = decompose_tensor((1 + 0.1*q[...,None,None,None])*(A + B))
         if self.equivariance_invariance_group == "SO(3)":
             B = torch.matmul(Y, msg)
             I, A, S = decompose_tensor(2 * B)
@@ -421,5 +421,5 @@ def forward(
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        X = X + dX + (1 + 0.1*q) * torch.matrix_power(dX, 2)
+        X = X + dX + (1 + 0.1*q[...,None,None,None]) * torch.matrix_power(dX, 2)
         return X
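After patch 6, `q` stays flat with shape `(num_atoms,)` at the model level, and the interaction module adds the trailing singleton dimensions itself. A sketch of the charge-gated residual update at the end of the layer; the shapes are assumptions for illustration:

import torch

num_atoms, channels = 4, 8
q = torch.tensor([1.0, 1.0, -1.0, 0.0])     # flat atom-wise charge
X = torch.randn(num_atoms, channels, 3, 3)
dX = torch.randn(num_atoms, channels, 3, 3)

# Broadcast the scalar gate over the (channels, 3, 3) feature dims;
# torch.matrix_power squares each trailing (3, 3) block batch-wise.
gate = 1 + 0.1 * q[..., None, None, None]   # (4, 1, 1, 1)
X = X + dX + gate * torch.matrix_power(dX, 2)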
From 9c037da8b89270a06fe457f2cb0a33b4bab654ae Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:07:14 +0100
Subject: [PATCH 7/9] add clarification comment

---
 torchmdnet/models/tensornet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 0c7d4c918..7b8ae4734 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -216,6 +216,7 @@ def forward(
         if self.static_shapes:
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
+            # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
             if q is None:
                 q = torch.zeros_like(zp, device=z.device, dtype=z.dtype)
             else:
@@ -233,6 +234,7 @@ def forward(
         # I avoid dividing by zero by setting the weight of self edges and self loops to 1
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+        # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
         if q is None:
             q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
         else:
             q = q[batch]

From c030457b0c24246d123b75fcf678f2a991805ded Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:38:56 +0100
Subject: [PATCH 8/9] trying fix

---
 torchmdnet/models/tensornet.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 7b8ae4734..fb186cfd9 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -220,9 +220,9 @@ def forward(
             if q is None:
                 q = torch.zeros_like(zp, device=z.device, dtype=z.dtype)
             else:
-                batchp = torch.cat((batch, batch[-1] + 1), dim=0)
-                qp = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
-                q = qp[batchp]
+                #batchp = torch.cat((batch, batch[-1] + 1), dim=0)
+                #qp = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
+                q = torch.cat((q[batch], torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
             # I trick the model into thinking that the masked edges pertain to the extra atom
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
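Patch 8 replaces the `batchp`/`qp` detour with a direct construction: broadcast the charge to atoms first, then append a single zero entry so it stays aligned with the ghost atom appended to `zp` under static shapes. A toy alignment check with illustrative values, not part of the patch:

import torch

z = torch.tensor([1, 6, 8])       # atomic numbers of one molecule
batch = torch.tensor([0, 0, 0])   # all atoms belong to molecule 0
q = torch.tensor([1.0])           # total charge +1 for that molecule

# Static shapes: a ghost atom (z = 0) is appended to every per-atom tensor.
zp = torch.cat((z, torch.zeros(1, dtype=z.dtype)), dim=0)

# The atom-wise charge gets a matching zero entry for the ghost atom.
q_atom = torch.cat((q[batch], torch.zeros(1, dtype=q.dtype)), dim=0)

assert zp.shape == q_atom.shape   # both (4,)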
From eafd1347945dd882a1ac54bb63def4b6bb239b14 Mon Sep 17 00:00:00 2001
From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:55:02 +0100
Subject: [PATCH 9/9] try fix

---
 torchmdnet/models/tensornet.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index fb186cfd9..5c3158e80 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -212,17 +212,16 @@ def forward(
             edge_vec is not None
         ), "Distance module did not return directional information"
         # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom
+        # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
+        if q is None:
+            q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
+        else:
+            q = q[batch]
         zp = z
         if self.static_shapes:
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
-            # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
-            if q is None:
-                q = torch.zeros_like(zp, device=z.device, dtype=z.dtype)
-            else:
-                #batchp = torch.cat((batch, batch[-1] + 1), dim=0)
-                #qp = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
-                q = torch.cat((q[batch], torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
+            q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
             # I trick the model into thinking that the masked edges pertain to the extra atom
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
@@ -234,11 +233,6 @@ def forward(
         # I avoid dividing by zero by setting the weight of self edges and self loops to 1
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
-        # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
-        if q is None:
-            q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
-        else:
-            q = q[batch]
         for layer in self.layers:
            X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
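Patch 9 leaves the forward pass with one clean rule: broadcast the molecule-wise charge to atoms once, up front, and only pad for the ghost atom inside the static-shapes branch. A standalone sketch of that final control flow, using a hypothetical helper rather than the actual TensorNet code:

import torch
from typing import Optional

def atomwise_charge(
    z: torch.Tensor,             # (num_atoms,) atomic numbers
    batch: torch.Tensor,         # (num_atoms,) molecule index of each atom
    q: Optional[torch.Tensor],   # (num_molecules,) total charges, or None
    static_shapes: bool,
) -> torch.Tensor:
    # Molecule-wise -> atom-wise, defaulting to zero charge everywhere.
    if q is None:
        q = torch.zeros_like(z)
    else:
        q = q[batch]
    # Under static shapes, append a zero entry for the ghost atom so the
    # charge stays aligned with the padded per-atom tensors.
    if static_shapes:
        q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
    return q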