Skip to content

Commit

Permalink
Merge pull request #311 from torchmd/ensemble_zips
Browse files Browse the repository at this point in the history
Support for ensemble model zip files
  • Loading branch information
stefdoerr authored Mar 28, 2024
2 parents 8a1be71 + 1a7b274 commit 8b47246
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 29 deletions.
60 changes: 40 additions & 20 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,28 @@ def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
args = {"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_errors": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
args = {
"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_errors": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
model = create_model(args).to(device="cuda")
model.eval()
z = z.to("cuda")
Expand Down Expand Up @@ -260,3 +262,21 @@ def test_ensemble():
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()

import zipfile
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
ensemble_zip = join(tmpdir, "ensemble.zip")
with zipfile.ZipFile(ensemble_zip, "w") as zipf:
for i, ckpt in enumerate(ckpts):
zipf.write(ckpt, f"model_{i}.ckpt")
ensemble_model = load_model(ensemble_zip, return_std=True)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()
67 changes: 58 additions & 9 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from glob import glob
import os
import re
import tempfile
from typing import Optional, List, Tuple, Dict
import torch
from torch.autograd import grad
Expand All @@ -13,6 +15,7 @@
from torchmdnet import priors
from lightning_utilities.core.rank_zero import rank_zero_warn
import warnings
import zipfile


def create_model(args, prior_model=None, mean=None, std=None):
Expand Down Expand Up @@ -139,26 +142,72 @@ def create_model(args, prior_model=None, mean=None, std=None):
return model


def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs):
"""Load an ensemble of models from a list of checkpoint files or a zip file.
Args:
filepath (str or list): Can be any of the following:
- Path to a zip file containing multiple checkpoint files.
- List of paths to checkpoint files.
args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False.
**kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`.
Returns:
nn.Module: An instance of :py:mod:`Ensemble`.
"""
if isinstance(filepath, (list, tuple)):
assert all(isinstance(f, str) for f in filepath), "Invalid filepath list."
model_list = [
load_model(f, args=args, device=device, **kwargs) for f in filepath
]
elif filepath.endswith(".zip"):
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(filepath, "r") as z:
z.extractall(tmpdir)
ckpt_list = glob(os.path.join(tmpdir, "*.ckpt"))
assert len(ckpt_list) > 0, "No checkpoint files found in zip file."
model_list = [
load_model(f, args=args, device=device, **kwargs) for f in ckpt_list
]
else:
raise ValueError(
"Invalid filepath. Must be a list of paths or a path to a zip file."
)
return Ensemble(
model_list,
return_std=return_std,
)


def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
"""Load a model from a checkpoint file.
If a list of paths is given, an :py:mod:`Ensemble` model is returned.
If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned.
Args:
filepath (str or list): Path to the checkpoint file or a list of paths.
filepath (str or list): Can be any of the following:
- Path to a checkpoint file. In this case, a :py:mod:`TorchMD_Net` model is returned.
- Path to a zip file containing multiple checkpoint files. In this case, an :py:mod:`Ensemble` model is returned.
- List of paths to checkpoint files. In this case, an :py:mod:`Ensemble` model is returned.
args (dict, optional): Arguments for the model. Defaults to None.
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
**kwargs: Extra keyword arguments for the model.
Returns:
nn.Module: An instance of the TorchMD_Net model.
nn.Module: An instance of the TorchMD_Net model or an Ensemble model.
"""
if isinstance(filepath, (list, tuple)):
return Ensemble(
[load_model(f, args=args, device=device, **kwargs) for f in filepath],
return_std=return_std,
isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip")
if isEnsemble:
return load_ensemble(
filepath, args=args, device=device, return_std=return_std, **kwargs
)

assert isinstance(filepath, str)
ckpt = torch.load(filepath, map_location="cpu")
if args is None:
args = ckpt["hyper_parameters"]
Expand Down

0 comments on commit 8b47246

Please sign in to comment.