diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 64f9b9d36..5c3158e80 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -212,10 +212,16 @@ def forward(
             edge_vec is not None
         ), "Distance module did not return directional information"
         # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom
+        # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
+        if q is None:
+            q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
+        else:
+            q = q[batch]
         zp = z
         if self.static_shapes:
             mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
             zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
+            q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
             # I trick the model into thinking that the masked edges pertain to the extra atom
             # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
             edge_index = edge_index.masked_fill(mask, z.shape[0])
@@ -228,7 +234,7 @@ def forward(
             edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
         X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
         for layer in self.layers:
-            X = layer(X, edge_index, edge_weight, edge_attr)
+            X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
@@ -379,7 +385,7 @@ def reset_parameters(self):
             linear.reset_parameters()
 
     def forward(
-        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor
+        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, q: Tensor
     ) -> Tensor:
 
         C = self.cutoff(edge_weight)
@@ -401,7 +407,7 @@ def forward(
         if self.equivariance_invariance_group == "O(3)":
             A = torch.matmul(msg, Y)
             B = torch.matmul(Y, msg)
-            I, A, S = decompose_tensor(A + B)
+            I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B))
         if self.equivariance_invariance_group == "SO(3)":
             B = torch.matmul(Y, msg)
             I, A, S = decompose_tensor(2 * B)
@@ -411,5 +417,5 @@ def forward(
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        X = X + dX + torch.matrix_power(dX, 2)
+        X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2)
         return X
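
For reference, here is a minimal standalone sketch (not part of the patch) of how the charge factor in the diff broadcasts: `q[..., None, None, None]` reshapes the per-atom charge to `(num_atoms, 1, 1, 1)`, so the scalar factor `(1 + 0.1 * q)` multiplies every `(hidden_channels, 3, 3)` tensor feature of that atom, and neutral molecules (`q == 0`) recover the original update. The values of `num_atoms`, `hidden_channels`, and the sample charges are illustrative, not taken from the patch.

```python
import torch

# Illustrative sizes; the real model gets these from the batch and config.
num_atoms, hidden_channels = 5, 8
q = torch.tensor([0.0, 0.0, 1.0, -1.0, 0.0])       # per-atom charge, i.e. q[batch]
dX = torch.randn(num_atoms, hidden_channels, 3, 3)  # per-atom rank-2 tensor features

# q[..., None, None, None] has shape (num_atoms, 1, 1, 1); broadcasting scales
# each atom's entire (hidden_channels, 3, 3) block by its own charge factor.
# torch.matrix_power operates batch-wise on the trailing 3x3 matrices.
scaled = (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2)
print(scaled.shape)  # torch.Size([5, 8, 3, 3])
```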