diff --git a/tests/test_model.py b/tests/test_model.py index f606559e..dfc078ea 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -246,11 +246,15 @@ def test_gradients(model_name): model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3 ) - -def test_ensemble(): +@mark.parametrize("check_script", [True, False]) +def test_ensemble(check_script): ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3 model = load_model(ckpts[0]) - ensemble_model = load_model(ckpts, return_std=True) + + if check_script: + ensemble_model = torch.jit.script(load_model(ckpts)) + else: + ensemble_model = load_model(ckpts) z, pos, batch = create_example_batch(n_atoms=5) pred, deriv = model(z, pos, batch) @@ -271,7 +275,11 @@ def test_ensemble(): with zipfile.ZipFile(ensemble_zip, "w") as zipf: for i, ckpt in enumerate(ckpts): zipf.write(ckpt, f"model_{i}.ckpt") - ensemble_model = load_model(ensemble_zip, return_std=True) + + if check_script: + ensemble_model = torch.jit.script(load_model(ckpts)) + else: + ensemble_model = load_model(ensemble_zip) pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 91369304..42a03775 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -5,7 +5,7 @@ import os import re import tempfile -from typing import Optional, List, Tuple, Dict +from typing import Optional, List, Tuple, Dict, Union import torch from torch.autograd import grad from torch import nn, Tensor @@ -142,7 +142,7 @@ def create_model(args, prior_model=None, mean=None, std=None): return model -def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs): +def load_ensemble(filepath, args=None, device="cpu", **kwargs): """Load an ensemble of models from a list of checkpoint files or a zip file. Args: @@ -153,7 +153,6 @@ def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs) args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". - return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. **kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`. Returns: @@ -179,11 +178,10 @@ def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs) ) return Ensemble( model_list, - return_std=return_std, ) -def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): +def load_model(filepath, args=None, device="cpu", **kwargs): """Load a model from a checkpoint file. If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned. @@ -196,7 +194,6 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". - return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False. **kwargs: Extra keyword arguments for the model. Returns: @@ -205,7 +202,7 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip") if isEnsemble: return load_ensemble( - filepath, args=args, device=device, return_std=return_std, **kwargs + filepath, args=args, device=device, **kwargs ) assert isinstance(filepath, str) ckpt = torch.load(filepath, map_location="cpu") @@ -490,47 +487,78 @@ def forward( class Ensemble(torch.nn.ModuleList): """Average predictions over an ensemble of TorchMD-Net models. - This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with. + This module behaves similarly to a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with. Args: modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over. - return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy). """ - def __init__(self, modules: List[nn.Module], return_std: bool = False): + def __init__(self, modules: List[nn.Module]): for module in modules: assert isinstance(module, TorchMD_Net) super().__init__(modules) - self.return_std = return_std def forward( self, - *args, - **kwargs, - ): - """Average predictions over all models in the ensemble. - The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble. + z: Tensor, + pos: Tensor, + batch: Optional[Tensor] = None, + box: Optional[Tensor] = None, + q: Optional[Tensor] = None, + s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: + """ + Compute the output of the ensemble of models. + + The predictions are the average over all models in the ensemble. + + This function optionally supports periodic boundary conditions with + arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy + certain requirements: + + .. code:: python + + a[1] = a[2] = b[2] = 0 + a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff + a[0] >= 2*b[0] + a[0] >= 2*c[0] + b[1] >= 2*c[1] + + + These requirements correspond to a particular rotation of the system and + reduced form of the vectors, as well as the requirement that the cutoff be + no larger than half the box width. + Args: - *args: Positional arguments to forward to the models. - **kwargs: Keyword arguments to forward to the models. - Returns: - Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy). + z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,). + pos (Tensor): Atomic positions in the molecule. Shape: (N, 3). + batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,). + box (Tensor, optional): Box vectors. Shape (3, 3). + The vectors defining the periodic box. This must have shape `(3, 3)`, + where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. + If this is omitted, periodic boundary conditions are not applied. + q (Tensor, optional): Atomic charges in the molecule. Shape: (N,). + s (Tensor, optional): Atomic spins in the molecule. Shape: (N,). + extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model. + Returns: + Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The mean output of the models, the mean negative derivatives, the std of the outputs, and the std of the negative derivatives. + """ + y = [] neg_dy = [] for model in self: - res = model(*args, **kwargs) + res = model(z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args) y.append(res[0]) neg_dy.append(res[1]) y = torch.stack(y) neg_dy = torch.stack(neg_dy) - y_mean = torch.mean(y, axis=0) - neg_dy_mean = torch.mean(neg_dy, axis=0) - y_std = torch.std(y, axis=0) - neg_dy_std = torch.std(neg_dy, axis=0) - - if self.return_std: - return y_mean, neg_dy_mean, y_std, neg_dy_std - else: - return y_mean, neg_dy_mean + y_mean = torch.mean(y, dim=0) + neg_dy_mean = torch.mean(neg_dy, dim=0) + y_std = torch.std(y, dim=0) + neg_dy_std = torch.std(neg_dy, dim=0) + + return y_mean, neg_dy_mean, y_std, neg_dy_std +