Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TorchScript compatibility for Ensemble #312

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import re
import tempfile
from typing import Optional, List, Tuple, Dict
from typing import Optional, List, Tuple, Dict, Union
import torch
from torch.autograd import grad
from torch import nn, Tensor
Expand Down Expand Up @@ -142,7 +142,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
return model


def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs):
def load_ensemble(filepath, args=None, device="cpu", **kwargs):
"""Load an ensemble of models from a list of checkpoint files or a zip file.

Args:
Expand All @@ -153,7 +153,6 @@ def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs)

args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False.
**kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`.

Returns:
Expand All @@ -179,11 +178,10 @@ def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs)
)
return Ensemble(
model_list,
return_std=return_std,
)


def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
def load_model(filepath, args=None, device="cpu", **kwargs):
"""Load a model from a checkpoint file.

If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned.
Expand All @@ -196,7 +194,6 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):

args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
**kwargs: Extra keyword arguments for the model.

Returns:
Expand All @@ -205,7 +202,7 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip")
if isEnsemble:
return load_ensemble(
filepath, args=args, device=device, return_std=return_std, **kwargs
filepath, args=args, device=device, **kwargs
)
assert isinstance(filepath, str)
ckpt = torch.load(filepath, map_location="cpu")
Expand Down Expand Up @@ -494,43 +491,73 @@ class Ensemble(torch.nn.ModuleList):

Args:
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
"""

def __init__(self, modules: List[nn.Module], return_std: bool = False):
def __init__(self, modules: List[nn.Module]):
for module in modules:
assert isinstance(module, TorchMD_Net)
super().__init__(modules)
self.return_std = return_std

def forward(
self,
*args,
**kwargs,
):
"""Average predictions over all models in the ensemble.
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
z: Tensor,
pos: Tensor,
batch: Optional[Tensor] = None,
box: Optional[Tensor] = None,
q: Optional[Tensor] = None,
s: Optional[Tensor] = None,
extra_args: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

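The return annotation should be `Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]`, matching the returned order (y_mean, neg_dy_mean, y_std, neg_dy_std), so that the derivative entries are the optional ones.
Note that this function will fail if derivative=False.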

"""
Compute the output of the model.

This function optionally supports periodic boundary conditions with
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
certain requirements:

.. code:: python

a[1] = a[2] = b[2] = 0
a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff
a[0] >= 2*b[0]
a[0] >= 2*c[0]
b[1] >= 2*c[1]


These requirements correspond to a particular rotation of the system and
reduced form of the vectors, as well as the requirement that the cutoff be
no larger than half the box width.

Args:
*args: Positional arguments to forward to the models.
**kwargs: Keyword arguments to forward to the models.
Returns:
Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy).
z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,).
pos (Tensor): Atomic positions in the molecule. Shape: (N, 3).
batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,).
box (Tensor, optional): Box vectors. Shape (3, 3).
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.

Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: The mean output of the models, the mean derivatives, the std of the outputs, and the std of the derivatives.

"""


y = []
neg_dy = []
for model in self:
res = model(*args, **kwargs)
res = model(z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)
y.append(res[0])
neg_dy.append(res[1])
y = torch.stack(y)
neg_dy = torch.stack(neg_dy)
y_mean = torch.mean(y, axis=0)
neg_dy_mean = torch.mean(neg_dy, axis=0)
y_std = torch.std(y, axis=0)
neg_dy_std = torch.std(neg_dy, axis=0)

if self.return_std:
return y_mean, neg_dy_mean, y_std, neg_dy_std
else:
return y_mean, neg_dy_mean
y_mean = torch.mean(y, dim=0)
neg_dy_mean = torch.mean(neg_dy, dim=0)
y_std = torch.std(y, dim=0)
neg_dy_std = torch.std(neg_dy, dim=0)

return y_mean, neg_dy_mean, y_std, neg_dy_std

Loading