Merge pull request #283 from RaulPPelaez/returntype
Make TorchMD_Net always return two tensors
RaulPPelaez authored Feb 16, 2024
2 parents e0ecc48 + f4b6827 commit 94cc325
Showing 2 changed files with 28 additions and 11 deletions.
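This commit changes the return type of TorchMD_Net.forward from Tuple[Tensor, Optional[Tensor]] to Tuple[Tensor, Tensor], so scripted code can always unpack and manipulate both outputs. A minimal sketch of the resulting calling convention, assuming the public create_model entry point and the load_example_args helper used in tests/test_model.py (the helper's import path and the input values are illustrative assumptions, not part of this diff):

    import torch
    from torchmdnet.models.model import create_model
    from utils import load_example_args  # test helper; import path is an assumption

    # Build a small TensorNet model with force (derivative) prediction enabled.
    model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))

    z = torch.tensor([1, 6, 8], dtype=torch.long)  # atomic numbers
    pos = torch.randn(3, 3)                        # atom positions
    batch = torch.zeros(3, dtype=torch.long)       # all atoms in one molecule

    # After this commit the model always returns two tensors:
    # the scalar output y and the negative gradient neg_dy.
    y, neg_dy = model(z, pos, batch=batch)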
tests/test_model.py (12 additions, 0 deletions)
@@ -61,6 +61,18 @@ def test_torchscript(model_name, device):
         grad_outputs=grad_outputs,
     )[0]
 
+def test_torchscript_output_modification():
+    model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
+    class MyModel(torch.nn.Module):
+        def __init__(self):
+            super(MyModel, self).__init__()
+            self.model = model
+        def forward(self, z, pos, batch):
+            y, neg_dy = self.model(z, pos, batch=batch)
+            # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
+            return y, 2*neg_dy
+    torch.jit.script(MyModel())
+
 @mark.parametrize("model_name", models.__all_models__)
 @mark.parametrize("device", ["cpu", "cuda"])
 def test_torchscript_dynamic_shapes(model_name, device):
torchmdnet/models/model.py (16 additions, 11 deletions)
@@ -51,9 +51,11 @@ def create_model(args, prior_model=None, mean=None, std=None):
         max_z=args["max_z"],
         check_errors=bool(args["check_errors"]),
         max_num_neighbors=args["max_num_neighbors"],
-        box_vecs=torch.tensor(args["box_vecs"], dtype=dtype)
-        if args["box_vecs"] is not None
-        else None,
+        box_vecs=(
+            torch.tensor(args["box_vecs"], dtype=dtype)
+            if args["box_vecs"] is not None
+            else None
+        ),
         dtype=dtype,
     )
 
@@ -164,8 +166,12 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
     model = create_model(args)
     if delta_learning and "remove_ref_energy" in kwargs:
         if not kwargs["remove_ref_energy"]:
-            assert len(model.prior_model) > 0, "Atomref prior must be added during training (with enable=False) for total energy prediction."
-            assert isinstance(model.prior_model[-1], priors.Atomref), "I expected the last prior to be Atomref."
+            assert (
+                len(model.prior_model) > 0
+            ), "Atomref prior must be added during training (with enable=False) for total energy prediction."
+            assert isinstance(
+                model.prior_model[-1], priors.Atomref
+            ), "I expected the last prior to be Atomref."
             # Set the Atomref prior to enabled
             model.prior_model[-1].enable = True
 
@@ -333,7 +339,7 @@ def forward(
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         extra_args: Optional[Dict[str, Tensor]] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Tensor]:
         """
         Compute the output of the model.
@@ -415,9 +421,8 @@ def forward(
                 create_graph=self.training,
                 retain_graph=self.training,
             )[0]
-            if dy is None:
-                raise RuntimeError("Autograd returned None for the force prediction.")
-
+            assert dy is not None, "Autograd returned None for the force prediction."
             return y, -dy
-        # TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
-        return y, None
+        # Returning an empty tensor allows to decorate this method as always returning two tensors.
+        # This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135
+        return y, torch.empty(0)
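
Because the second element of the returned tuple is now always a Tensor, a model created with derivative=False returns an empty placeholder instead of None. A hedged sketch of how downstream code might tell the two cases apart (variable names are illustrative and reuse the sketch above):

    y, neg_dy = model(z, pos, batch=batch)
    if neg_dy.numel() == 0:
        # derivative=False: only the scalar output y is meaningful, no forces were computed.
        forces = None
    else:
        # derivative=True: neg_dy holds the negative gradient of y, i.e. the forces.
        forces = neg_dy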
