Allow appending extra values to embedding vector
peastman committed Feb 21, 2024
1 parent 166b7db commit 58f298f
Showing 7 changed files with 97 additions and 8 deletions.
11 changes: 11 additions & 0 deletions tests/test_model.py
@@ -227,3 +227,14 @@ def test_gradients(model_name):
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("use_batch", [True, False])
def test_extra_embedding(model_name, use_batch):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, prior_model=None)
args["extra_embedding"] = ["atomic", "global"]
model = create_model(args)
batch = batch if use_batch else None
model(z, pos, batch=batch, extra_args={'atomic':torch.rand(6), 'global':torch.rand(2)})
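For context, here is a minimal standalone sketch of the call pattern this test exercises (it reuses the test helpers as imported in this test module, "tensornet" is assumed to be one of the registered model names, and the field names are placeholders): an extra tensor whose shape matches z is used per atom, while a shorter tensor carries one value per molecule.

import torch
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch  # test helpers, as used above

z, pos, batch = create_example_batch()                 # 6 atoms split across 2 molecules
args = load_example_args("tensornet", prior_model=None)
args["extra_embedding"] = ["partial_charge", "total_charge"]  # placeholder field names
model = create_model(args)

y, neg_dy = model(
    z, pos, batch=batch,
    extra_args={
        "partial_charge": torch.rand(len(z)),  # same shape as z: used per atom as-is
        "total_charge": torch.rand(2),         # one value per molecule: expanded via batch
    },
)  # y is the model output; neg_dy is None unless derivative=True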
21 changes: 19 additions & 2 deletions torchmdnet/models/model.py
@@ -38,6 +38,12 @@ def create_model(args, prior_model=None, mean=None, std=None):
args["static_shapes"] = False
if "vector_cutoff" not in args:
args["vector_cutoff"] = False
if "extra_embedding" not in args:
extra_embedding = None
elif isinstance(args["extra_embedding"], str):
extra_embedding = [args["extra_embedding"]]
else:
extra_embedding = args["extra_embedding"]

shared_args = dict(
hidden_channels=args["embedding_dimension"],
@@ -57,6 +63,7 @@
else None
),
dtype=dtype,
extra_embedding=extra_embedding
)

# representation network
@@ -370,7 +377,7 @@ def forward(
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the model.
Returns:
Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -380,9 +387,19 @@

if self.derivative:
pos.requires_grad_(True)
if self.representation_model.extra_embedding is None:
extra_embedding_args = None
else:
extra = []
for arg in self.representation_model.extra_embedding:
t = extra_args[arg]
if t.shape != z.shape:
t = t[batch]
extra.append(t)
extra_embedding_args = tuple(extra)
# run the potentially wrapped representation model
x, v, z, pos, batch = self.representation_model(
z, pos, batch, box=box, q=q, s=s
z, pos, batch, box=box, q=q, s=s, extra_embedding_args=extra_embedding_args
)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)
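The block added to forward() applies a simple broadcasting rule: an extra tensor whose shape matches z is already per-atom and is used as-is, while anything else is treated as per-molecule and expanded with the batch index. A standalone sketch of that rule, with made-up values:

import torch

z = torch.tensor([8, 1, 1, 6, 1, 1])      # atomic numbers, 6 atoms
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # molecule index of each atom
per_atom = torch.rand(6)                   # shape matches z: passed through unchanged
per_mol = torch.tensor([0.3, -0.7])        # one value per molecule
expanded = per_mol[batch]                  # shape (6,): [0.3, 0.3, 0.3, -0.7, -0.7, -0.7]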
24 changes: 21 additions & 3 deletions torchmdnet/models/tensornet.py
@@ -120,6 +120,9 @@ class TensorNet(nn.Module):
(default: :obj:`True`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
extra_embedding (tuple, optional): the names of extra fields to append to the embedding
vector for each atom
(default: :obj:`None`)
"""

def __init__(
Expand All @@ -139,6 +142,7 @@ def __init__(
check_errors=True,
dtype=torch.float32,
box_vecs=None,
extra_embedding=None
):
super(TensorNet, self).__init__()

@@ -163,6 +167,7 @@ def __init__(
self.activation = activation
self.cutoff_lower = cutoff_lower
self.cutoff_upper = cutoff_upper
self.extra_embedding = extra_embedding
act_class = act_class_mapping[activation]
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
@@ -176,6 +181,7 @@ def __init__(
trainable_rbf,
max_z,
dtype,
extra_embedding
)

self.layers = nn.ModuleList()
@@ -228,6 +234,7 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
# Obtain graph, with distances and relative position vectors
edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -258,7 +265,7 @@ def forward(
# Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
# I avoid dividing by zero by setting the weight of self edges and self loops to 1
edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_embedding_args)
for layer in self.layers:
X = layer(X, edge_index, edge_weight, edge_attr, q)
I, A, S = decompose_tensor(X)
@@ -287,6 +294,7 @@ def __init__(
trainable_rbf=False,
max_z=128,
dtype=torch.float32,
extra_embedding=None
):
super(TensorEmbedding, self).__init__()

@@ -297,6 +305,10 @@
self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
self.max_z = max_z
self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None
self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
self.act = activation()
self.linears_tensor = nn.ModuleList()
@@ -319,15 +331,20 @@ def reset_parameters(self):
self.distance_proj2.reset_parameters()
self.distance_proj3.reset_parameters()
self.emb.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.emb2.reset_parameters()
for linear in self.linears_tensor:
linear.reset_parameters()
for linear in self.linears_scalar:
linear.reset_parameters()
self.init_norm.reset_parameters()

def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_embedding_args: Optional[Tuple[Tensor]]) -> Tensor:
Z = self.emb(z)
if self.reshape_embedding is not None:
Z = torch.cat((Z,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
Z = self.reshape_embedding(Z)
Zij = self.emb2(
Z.index_select(0, edge_index.t().reshape(-1)).view(
-1, self.hidden_channels * 2
@@ -362,8 +379,9 @@ def forward(
edge_weight: Tensor,
edge_vec_norm: Tensor,
edge_attr: Tensor,
extra_embedding_args: Optional[Tuple[Tensor]]
) -> Tensor:
Zij = self._get_atomic_number_message(z, edge_index)
Zij = self._get_atomic_number_message(z, edge_index, extra_embedding_args)
Iij, Aij, Sij = self._get_tensor_messages(
Zij, edge_weight, edge_vec_norm, edge_attr
)
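TensorNet and the three TorchMD_* models changed below all augment the embedding the same way: the learned atomic-number embedding is concatenated with one scalar per extra field and projected back to hidden_channels by the new reshape_embedding layer. An isolated sketch of that pattern, with illustrative (not default) sizes:

import torch
import torch.nn as nn

hidden_channels, max_z, num_extra = 128, 100, 2
emb = nn.Embedding(max_z, hidden_channels)
reshape_embedding = nn.Linear(hidden_channels + num_extra, hidden_channels)

z = torch.tensor([8, 1, 1])                          # atomic numbers for 3 atoms
extras = (torch.rand(3), torch.rand(3))              # per-atom values of the extra fields
x = emb(z)                                           # (3, hidden_channels)
x = torch.cat((x,) + tuple(t.unsqueeze(1) for t in extras), dim=1)  # (3, hidden_channels + 2)
x = reshape_embedding(x)                             # projected back to (3, hidden_channels)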
16 changes: 15 additions & 1 deletion torchmdnet/models/torchmd_et.py
@@ -79,7 +79,9 @@ class TorchMD_ET(nn.Module):
(default: :obj:`False`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
extra_embedding (tuple, optional): the names of extra fields to append to the embedding
vector for each atom
(default: :obj:`None`)
"""

def __init__(
@@ -102,6 +104,7 @@ def __init__(
box_vecs=None,
vector_cutoff=False,
dtype=torch.float32,
extra_embedding=None
):
super(TorchMD_ET, self).__init__()

@@ -133,10 +136,15 @@ def __init__(
self.cutoff_upper = cutoff_upper
self.max_z = max_z
self.dtype = dtype
self.extra_embedding = extra_embedding

act_class = act_class_mapping[activation]

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None

self.distance = OptimizedDistance(
cutoff_lower,
@@ -181,6 +189,8 @@ def __init__(

def reset_parameters(self):
self.embedding.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.distance_expansion.reset_parameters()
if self.neighbor_embedding is not None:
self.neighbor_embedding.reset_parameters()
@@ -196,8 +206,12 @@ def forward(
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None:
x = torch.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
x = self.reshape_embedding(x)

edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
# This assert must be here to convince TorchScript that edge_vec is not None
16 changes: 15 additions & 1 deletion torchmdnet/models/torchmd_gn.py
@@ -86,7 +86,9 @@ class TorchMD_GN(nn.Module):
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
extra_embedding (tuple, optional): the names of extra fields to append to the embedding
vector for each atom
(default: :obj:`None`)
"""

def __init__(
@@ -107,6 +109,7 @@ def __init__(
aggr="add",
dtype=torch.float32,
box_vecs=None,
extra_embedding=None
):
super(TorchMD_GN, self).__init__()

@@ -136,10 +139,15 @@ def __init__(
self.cutoff_upper = cutoff_upper
self.max_z = max_z
self.aggr = aggr
self.extra_embedding = extra_embedding

act_class = act_class_mapping[activation]

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None

self.distance = OptimizedDistance(
cutoff_lower,
@@ -184,6 +192,8 @@ def __init__(

def reset_parameters(self):
self.embedding.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.distance_expansion.reset_parameters()
if self.neighbor_embedding is not None:
self.neighbor_embedding.reset_parameters()
@@ -198,8 +208,12 @@ def forward(
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None:
x = torch.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
x = self.reshape_embedding(x)

edge_index, edge_weight, _ = self.distance(pos, batch, box)
edge_attr = self.distance_expansion(edge_weight)
16 changes: 15 additions & 1 deletion torchmdnet/models/torchmd_t.py
@@ -76,7 +76,9 @@ class TorchMD_T(nn.Module):
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
extra_embedding (tuple, optional): the names of extra fields to append to the embedding
vector for each atom
(default: :obj:`None`)
"""

def __init__(
Expand All @@ -98,6 +100,7 @@ def __init__(
max_num_neighbors=32,
dtype=torch.float,
box_vecs=None,
extra_embedding=None
):
super(TorchMD_T, self).__init__()

@@ -124,11 +127,16 @@ def __init__(
self.cutoff_lower = cutoff_lower
self.cutoff_upper = cutoff_upper
self.max_z = max_z
self.extra_embedding = extra_embedding

act_class = act_class_mapping[activation]
attn_act_class = act_class_mapping[attn_activation]

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
if extra_embedding is not None:
self.reshape_embedding = nn.Linear(hidden_channels+len(extra_embedding), hidden_channels, dtype=dtype)
else:
self.reshape_embedding = None

self.distance = OptimizedDistance(
cutoff_lower,
@@ -177,6 +185,8 @@ def __init__(

def reset_parameters(self):
self.embedding.reset_parameters()
if self.reshape_embedding is not None:
self.reshape_embedding.reset_parameters()
self.distance_expansion.reset_parameters()
if self.neighbor_embedding is not None:
self.neighbor_embedding.reset_parameters()
@@ -192,8 +202,12 @@ def forward(
box: Optional[Tensor] = None,
s: Optional[Tensor] = None,
q: Optional[Tensor] = None,
extra_embedding_args: Optional[Tuple[Tensor]] = None
) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
x = self.embedding(z)
if self.reshape_embedding is not None:
x = torch.cat((x,)+tuple(t.unsqueeze(1) for t in extra_embedding_args), dim=1)
x = self.reshape_embedding(x)

edge_index, edge_weight, _ = self.distance(pos, batch, box)
edge_attr = self.distance_expansion(edge_weight)
1 change: 1 addition & 0 deletions torchmdnet/scripts/train.py
@@ -80,6 +80,7 @@ def get_argparse():
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
parser.add_argument('--extra-embedding', type=str, default=None, help='Extra fields of the dataset to pass to the model and append to the embedding vector.', action="extend", nargs="*")
parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
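Since the new flag is declared with action="extend" and nargs="*", several field names can follow a single --extra-embedding. A minimal sketch of the resulting parsing behavior using a standalone parser (field names are placeholders; this is not the project's full argument parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--extra-embedding', type=str, default=None,
                    help='Extra fields of the dataset to pass to the model and append to the embedding vector.',
                    action="extend", nargs="*")

args = parser.parse_args(['--extra-embedding', 'partial_charge', 'total_charge'])
print(args.extra_embedding)  # ['partial_charge', 'total_charge']

Note that create_model also accepts a plain string for this option and wraps it in a single-element list, as shown in the model.py change above.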
