diff --git a/tests/test_model.py b/tests/test_model.py
index b792595b..b29f4471 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -227,3 +227,14 @@ def test_gradients(model_name):
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )
+
+
+@mark.parametrize("model_name", models.__all_models__)
+@mark.parametrize("use_batch", [True, False])
+def test_extra_embedding(model_name, use_batch):
+    z, pos, batch = create_example_batch()
+    args = load_example_args(model_name, prior_model=None)
+    args["extra_embedding"] = ["atomic", "global"]
+    model = create_model(args)
+    batch = batch if use_batch else None
+    model(z, pos, batch=batch, extra_args={'atomic': torch.rand(6), 'global': torch.rand(2)})
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index a2a80f90..913d967d 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -38,6 +38,12 @@ def create_model(args, prior_model=None, mean=None, std=None):
         args["static_shapes"] = False
     if "vector_cutoff" not in args:
         args["vector_cutoff"] = False
+    if "extra_embedding" not in args:
+        extra_embedding = None
+    elif isinstance(args["extra_embedding"], str):
+        extra_embedding = [args["extra_embedding"]]
+    else:
+        extra_embedding = args["extra_embedding"]

     shared_args = dict(
         hidden_channels=args["embedding_dimension"],
@@ -57,6 +63,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
             else None
         ),
         dtype=dtype,
+        extra_embedding=extra_embedding
     )

     # representation network
@@ -370,7 +377,7 @@ def forward(
                 If this is omitted, periodic boundary conditions are not applied.
             q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
             s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
-            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
+            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the model.

         Returns:
             Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -380,9 +387,19 @@ def forward(

         if self.derivative:
             pos.requires_grad_(True)
+        if self.representation_model.extra_embedding is None:
+            extra_embedding_args = None
+        else:
+            extra = []
+            for arg in self.representation_model.extra_embedding:
+                t = extra_args[arg]
+                if t.shape != z.shape:
+                    t = t[batch]
+                extra.append(t)
+            extra_embedding_args = tuple(extra)
         # run the potentially wrapped representation model
         x, v, z, pos, batch = self.representation_model(
-            z, pos, batch, box=box, q=q, s=s
+            z, pos, batch, box=box, q=q, s=s, extra_embedding_args=extra_embedding_args
         )
         # apply the output network
         x = self.output_model.pre_reduce(x, v, z, pos, batch)
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index a2006c9e..3a40ec6f 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -120,6 +120,9 @@ class TensorNet(nn.Module):
             (default: :obj:`True`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
+        extra_embedding (tuple, optional): The names of extra fields to append to the embedding
+            vector for each atom.
+            (default: :obj:`None`)
     """

     def __init__(
@@ -139,6 +142,7 @@ def __init__(
         check_errors=True,
         dtype=torch.float32,
         box_vecs=None,
+        extra_embedding=None
     ):
         super(TensorNet, self).__init__()

@@ -163,6 +167,7 @@ def __init__(
         self.activation = activation
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
+        self.extra_embedding = extra_embedding
         act_class = act_class_mapping[activation]
         self.distance_expansion = rbf_class_mapping[rbf_type](
             cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
@@ -176,6 +181,7 @@ def __init__(
             trainable_rbf,
             max_z,
             dtype,
+            extra_embedding
         )

         self.layers = nn.ModuleList()
@@ -228,6 +234,7 @@ def forward(
         box: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         # Obtain graph, with distances and relative position vectors
         edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -258,7 +265,7 @@ def forward(
         # Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
         # I avoid dividing by zero by setting the weight of self edges and self loops to 1
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
-        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_embedding_args)
         for layer in self.layers:
             X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
@@ -287,6 +294,7 @@ def __init__(
         trainable_rbf=False,
         max_z=128,
         dtype=torch.float32,
+        extra_embedding=None
     ):
         super(TensorEmbedding, self).__init__()

@@ -297,6 +305,10 @@ def __init__(
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
         self.max_z = max_z
         self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
         self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
         self.linears_tensor = nn.ModuleList()
@@ -319,6 +331,8 @@ def reset_parameters(self):
         self.distance_proj2.reset_parameters()
         self.distance_proj3.reset_parameters()
         self.emb.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.emb2.reset_parameters()
         for linear in self.linears_tensor:
             linear.reset_parameters()
@@ -326,8 +340,11 @@ def reset_parameters(self):
             linear.reset_parameters()
         self.init_norm.reset_parameters()

-    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
+    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_embedding_args: Optional[Tuple[Tensor]]) -> Tensor:
         Z = self.emb(z)
+        if self.reshape_embedding is not None:
+            Z = torch.cat((Z,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            Z = self.reshape_embedding(Z)
         Zij = self.emb2(
             Z.index_select(0, edge_index.t().reshape(-1)).view(
                 -1, self.hidden_channels * 2
@@ -362,8 +379,9 @@ def forward(
         edge_weight: Tensor,
         edge_vec_norm: Tensor,
         edge_attr: Tensor,
+        extra_embedding_args: Optional[Tuple[Tensor]]
     ) -> Tensor:
-        Zij = self._get_atomic_number_message(z, edge_index)
+        Zij = self._get_atomic_number_message(z, edge_index, extra_embedding_args)
         Iij, Aij, Sij = self._get_tensor_messages(
             Zij, edge_weight, edge_vec_norm, edge_attr
         )
diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py
index 5ff168d5..6fe211b2 100644
--- a/torchmdnet/models/torchmd_et.py
+++ b/torchmdnet/models/torchmd_et.py
@@ -79,7 +79,9 @@ class TorchMD_ET(nn.Module):
             (default: :obj:`False`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the embedding
+            vector for each atom.
+            (default: :obj:`None`)
     """

     def __init__(
@@ -102,6 +104,7 @@ def __init__(
         box_vecs=None,
         vector_cutoff=False,
         dtype=torch.float32,
+        extra_embedding=None
     ):
         super(TorchMD_ET, self).__init__()

@@ -133,10 +136,15 @@ def __init__(
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
         self.dtype = dtype
+        self.extra_embedding = extra_embedding

         act_class = act_class_mapping[activation]

         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None

         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -181,6 +189,8 @@ def __init__(

     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -196,8 +206,12 @@ def forward(
         box: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None
     ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None:
+            x = torch.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            x = self.reshape_embedding(x)

         edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
         # This assert must be here to convince TorchScript that edge_vec is not None
diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py
index 31d68ae0..690e34e7 100644
--- a/torchmdnet/models/torchmd_gn.py
+++ b/torchmdnet/models/torchmd_gn.py
@@ -86,7 +86,9 @@ class TorchMD_GN(nn.Module):
             (default: :obj:`None`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the embedding
+            vector for each atom.
+            (default: :obj:`None`)
     """

     def __init__(
@@ -107,6 +109,7 @@ def __init__(
         aggr="add",
         dtype=torch.float32,
         box_vecs=None,
+        extra_embedding=None
     ):
         super(TorchMD_GN, self).__init__()

@@ -136,10 +139,15 @@ def __init__(
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
         self.aggr = aggr
+        self.extra_embedding = extra_embedding

         act_class = act_class_mapping[activation]

         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None

         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -184,6 +192,8 @@ def __init__(

     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -198,8 +208,12 @@ def forward(
         box: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None:
+            x = torch.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            x = self.reshape_embedding(x)

         edge_index, edge_weight, _ = self.distance(pos, batch, box)
         edge_attr = self.distance_expansion(edge_weight)
diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py
index c11efc08..89655740 100644
--- a/torchmdnet/models/torchmd_t.py
+++ b/torchmdnet/models/torchmd_t.py
@@ -76,7 +76,9 @@ class TorchMD_T(nn.Module):
             (default: :obj:`None`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the embedding
+            vector for each atom.
+            (default: :obj:`None`)
     """

     def __init__(
@@ -98,6 +100,7 @@ def __init__(
         max_num_neighbors=32,
         dtype=torch.float,
         box_vecs=None,
+        extra_embedding=None
     ):
         super(TorchMD_T, self).__init__()

@@ -124,11 +127,16 @@ def __init__(
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
+        self.extra_embedding = extra_embedding

         act_class = act_class_mapping[activation]
         attn_act_class = act_class_mapping[attn_activation]

         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None

         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -177,6 +185,8 @@ def __init__(

     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -192,8 +202,12 @@ def forward(
         box: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
+        extra_embedding_args: Optional[Tuple[Tensor]] = None
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None:
+            x = torch.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
+            x = self.reshape_embedding(x)

         edge_index, edge_weight, _ = self.distance(pos, batch, box)
         edge_attr = self.distance_expansion(edge_weight)
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index a51cfe45..26b334f9 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -80,6 +80,7 @@ def get_argparse():
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
     parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
     parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
+    parser.add_argument('--extra-embedding', type=str, default=None, help='Extra fields of the dataset to pass to the model and append to the embedding vector.', action="extend", nargs="*")
     parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
     parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
     parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
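For reference, a minimal usage sketch of the new `extra_embedding` option, mirroring `test_extra_embedding` above. `load_example_args` is the helper from the test suite, and the field names `atomic` (per-atom) and `global` (per-molecule) are illustrative. Per-atom fields must match `z` in shape; anything else is treated as per-molecule and broadcast to atoms via `batch` inside `TorchMD_Net.forward`.

```python
import torch
from torchmdnet.models.model import create_model
from utils import load_example_args  # test helper, as in tests/test_model.py

# Request two extra embedding fields; their values are concatenated to the
# learned atom embedding and projected back to hidden_channels by the new
# reshape_embedding linear layer.
args = load_example_args("tensornet", prior_model=None)
args["extra_embedding"] = ["atomic", "global"]
model = create_model(args)

z = torch.tensor([1, 6, 8, 1, 6, 8], dtype=torch.long)  # 6 atoms
pos = torch.randn(6, 3)
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # 2 molecules

y, neg_dy = model(
    z,
    pos,
    batch=batch,
    extra_args={"atomic": torch.rand(6), "global": torch.rand(2)},
)
```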