diff --git a/tests/test_model.py b/tests/test_model.py index 1dd5e354..f606559e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -116,26 +116,28 @@ def test_cuda_graph_compatible(model_name): if not torch.cuda.is_available(): pytest.skip("CUDA not available") z, pos, batch = create_example_batch() - args = {"model": model_name, - "embedding_dimension": 128, - "num_layers": 2, - "num_rbf": 32, - "rbf_type": "expnorm", - "trainable_rbf": False, - "activation": "silu", - "cutoff_lower": 0.0, - "cutoff_upper": 5.0, - "max_z": 100, - "max_num_neighbors": 128, - "equivariance_invariance_group": "O(3)", - "prior_model": None, - "atom_filter": -1, - "derivative": True, - "check_errors": False, - "static_shapes": True, - "output_model": "Scalar", - "reduce_op": "sum", - "precision": 32 } + args = { + "model": model_name, + "embedding_dimension": 128, + "num_layers": 2, + "num_rbf": 32, + "rbf_type": "expnorm", + "trainable_rbf": False, + "activation": "silu", + "cutoff_lower": 0.0, + "cutoff_upper": 5.0, + "max_z": 100, + "max_num_neighbors": 128, + "equivariance_invariance_group": "O(3)", + "prior_model": None, + "atom_filter": -1, + "derivative": True, + "check_errors": False, + "static_shapes": True, + "output_model": "Scalar", + "reduce_op": "sum", + "precision": 32, + } model = create_model(args).to(device="cuda") model.eval() z = z.to("cuda") @@ -260,3 +262,21 @@ def test_ensemble(): assert neg_dy_std.shape == deriv.shape assert (y_std == 0).all() assert (neg_dy_std == 0).all() + + import zipfile + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ensemble_zip = join(tmpdir, "ensemble.zip") + with zipfile.ZipFile(ensemble_zip, "w") as zipf: + for i, ckpt in enumerate(ckpts): + zipf.write(ckpt, f"model_{i}.ckpt") + ensemble_model = load_model(ensemble_zip, return_std=True) + pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) + + torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5) + assert y_std.shape == pred.shape + assert neg_dy_std.shape == deriv.shape + assert (y_std == 0).all() + assert (neg_dy_std == 0).all() diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c090f90f..91369304 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -1,8 +1,10 @@ # Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - +from glob import glob +import os import re +import tempfile from typing import Optional, List, Tuple, Dict import torch from torch.autograd import grad @@ -13,6 +15,7 @@ from torchmdnet import priors from lightning_utilities.core.rank_zero import rank_zero_warn import warnings +import zipfile def create_model(args, prior_model=None, mean=None, std=None): @@ -139,26 +142,72 @@ def create_model(args, prior_model=None, mean=None, std=None): return model +def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs): + """Load an ensemble of models from a list of checkpoint files or a zip file. + + Args: + filepath (str or list): Can be any of the following: + + - Path to a zip file containing multiple checkpoint files. + - List of paths to checkpoint files. + + args (dict, optional): Arguments for the model. Defaults to None. + device (str, optional): Device on which the model should be loaded. Defaults to "cpu". + return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. + **kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`. + + Returns: + nn.Module: An instance of :py:mod:`Ensemble`. + """ + if isinstance(filepath, (list, tuple)): + assert all(isinstance(f, str) for f in filepath), "Invalid filepath list." + model_list = [ + load_model(f, args=args, device=device, **kwargs) for f in filepath + ] + elif filepath.endswith(".zip"): + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(filepath, "r") as z: + z.extractall(tmpdir) + ckpt_list = glob(os.path.join(tmpdir, "*.ckpt")) + assert len(ckpt_list) > 0, "No checkpoint files found in zip file." + model_list = [ + load_model(f, args=args, device=device, **kwargs) for f in ckpt_list + ] + else: + raise ValueError( + "Invalid filepath. Must be a list of paths or a path to a zip file." + ) + return Ensemble( + model_list, + return_std=return_std, + ) + + def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): """Load a model from a checkpoint file. - If a list of paths is given, an :py:mod:`Ensemble` model is returned. + If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned. Args: - filepath (str or list): Path to the checkpoint file or a list of paths. + filepath (str or list): Can be any of the following: + + - Path to a checkpoint file. In this case, a :py:mod:`TorchMD_Net` model is returned. + - Path to a zip file containing multiple checkpoint files. In this case, an :py:mod:`Ensemble` model is returned. + - List of paths to checkpoint files. In this case, an :py:mod:`Ensemble` model is returned. + args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False. **kwargs: Extra keyword arguments for the model. Returns: - nn.Module: An instance of the TorchMD_Net model. + nn.Module: An instance of the TorchMD_Net model or an Ensemble model. """ - if isinstance(filepath, (list, tuple)): - return Ensemble( - [load_model(f, args=args, device=device, **kwargs) for f in filepath], - return_std=return_std, + isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip") + if isEnsemble: + return load_ensemble( + filepath, args=args, device=device, return_std=return_std, **kwargs ) - + assert isinstance(filepath, str) ckpt = torch.load(filepath, map_location="cpu") if args is None: args = ckpt["hyper_parameters"]