Merge pull request #279 from RaulPPelaez/ref_energies
Training with relative energies, inferring with total energy.
stefdoerr authored Feb 15, 2024
2 parents c0edfed + ce008ab commit 0a2106b
Showing 16 changed files with 330 additions and 120 deletions.
45 changes: 44 additions & 1 deletion docs/source/datasets.rst
@@ -64,13 +64,42 @@ If you want to train on custom data, first have a look at :py:mod:`torchmdnet.da

To add a new dataset, you need to:

1. Write a new class inheriting from :py:mod:`torch_geometric.data.Dataset`, see `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html>`_ for a tutorial on that.
1. Write a new class inheriting from :py:mod:`torch_geometric.data.Dataset`, see `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html>`_ for a tutorial on that. You may also start from :ref:`one of the base datasets we provide <base datasets>`.
2. Add the new class to the :py:mod:`torchmdnet.datasets` module (by listing it in the `__all__` variable in the `datasets/__init__.py` file) so that the :ref:`configuration file <configuration-file>` recognizes it.

.. note::

   The dataset must return torch-geometric `Data` objects, containing at least the keys `z` (atom types) and `pos` (atomic coordinates), as well as `y` (label), `neg_dy` (negative derivative of the label w.r.t atom coordinates) or both.


Datasets and Atomref
--------------------

Compatibility with the :py:mod:`torchmdnet.priors.Atomref` prior (and thus :ref:`delta learning <delta-learning>`) requires the dataset to define a method called :code:`get_atomref` as follows:

.. code:: python

    class MyDataset(Dataset):
        # ...
        def get_atomref(self, max_z=100):
            """Atomic energy reference values.

            Args:
                max_z (int): Maximum atomic number

            Returns:
                torch.Tensor: Atomic energy reference values for each element in the dataset.
            """
            refs = pt.zeros(max_z)
            # Set the references for each element present in the dataset.
            refs[1] = -0.500607632585
            refs[6] = -37.8302333826
            refs[7] = -54.5680045287
            refs[8] = -75.0362229210
            # ...
            return refs.view(-1, 1)

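As a rough illustration of what the table returned by `get_atomref` represents, the sketch below builds the same lookup in plain Python and sums it over one molecule. The helper names `build_atomref` and `reference_energy` are hypothetical and not part of TorchMD-Net; the values are copied from the documentation example above and are used without any unit conversion.

```python
# Per-element reference energies, copied from the documentation example above.
ELEMENT_ENERGIES = {
    1: -0.500607632585,
    6: -37.8302333826,
    7: -54.5680045287,
    8: -75.0362229210,
}


def build_atomref(max_z=100):
    """Return a list indexed by atomic number (hypothetical helper).

    Elements without a known reference keep a value of 0.0, mirroring
    the zero-initialised tensor in the documentation example.
    """
    refs = [0.0] * max_z
    for z, energy in ELEMENT_ENERGIES.items():
        refs[z] = energy
    return refs


def reference_energy(atomic_numbers, refs):
    """Sum the per-atom reference energies of a single molecule."""
    return sum(refs[z] for z in atomic_numbers)


refs = build_atomref()
water = [8, 1, 1]  # O, H, H
e_ref_water = reference_energy(water, refs)
print(e_ref_water)
```

The table is indexed directly by atomic number, which is why it is sized by `max_z` rather than by the number of elements present in the dataset.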
Available Datasets
------------------

@@ -79,3 +108,17 @@ Available Datasets

.. include:: generated/torchmdnet.datasets.rst
   :start-line: 5

.. _base datasets:

Base Datasets
-------------

The following datasets are used as a base for other datasets.

.. autoclass:: torchmdnet.datasets.ani.ANIBase
   :noindex:

.. autoclass:: torchmdnet.datasets.comp6.COMP6Base
   :noindex:

.. autoclass:: torchmdnet.datasets.memdataset.MemmappedDataset
   :noindex:
39 changes: 39 additions & 0 deletions docs/source/models.rst
@@ -37,6 +37,45 @@ Once you have trained a model you should have a checkpoint that you can load for

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.

.. _delta-learning:

Training on relative energies
-----------------------------

It might be useful to train the model on relative energies but then make the model produce total energies when running inference.
TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model.

.. note:: Delta learning requires a :ref:`dataset <Datasets>` that is compatible with :py:mod:`torchmdnet.priors.Atomref`.

If :code:`remove_ref_energy` is turned on, the reference energy is stored in the checkpoint file and is added back to the output of the model during inference if the model is loaded with :code:`remove_ref_energy=False`.

.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.

Example
~~~~~~~

First we train a model with the :code:`remove_ref_energy` option turned on:

.. code:: shell

    torchmd-train --config /path/to/config.yaml --remove_ref_energy

Then we load the model for inference:

.. code:: python

    import torch
    from torchmdnet.models.model import load_model
    checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
    model = load_model(checkpoint, remove_ref_energy=False)
    # An arbitrary set of inputs for the model
    n_atoms = 10
    zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
    z = zs[torch.randint(0, len(zs), (n_atoms,))]
    pos = torch.randn(len(z), 3)
    batch = torch.zeros(len(z), dtype=torch.long)
    y, neg_dy = model(z, pos, batch)
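The arithmetic behind this docs hunk can be sketched in plain Python: during training the summed reference energy is subtracted from the label, and at inference the same sum is added back to the prediction. This is only a minimal sketch; the function names `to_relative` and `to_total` are hypothetical, and the total energy used for methane is a made-up number chosen purely for illustration.

```python
# Per-element reference energies, copied from the datasets.rst example in this commit.
REFS = {
    1: -0.500607632585,
    6: -37.8302333826,
    7: -54.5680045287,
    8: -75.0362229210,
}


def summed_reference(atomic_numbers, refs):
    """Total reference energy of a molecule: one term per atom."""
    return sum(refs[z] for z in atomic_numbers)


def to_relative(total_energy, atomic_numbers, refs):
    """Training side: remove the reference energy from the label."""
    return total_energy - summed_reference(atomic_numbers, refs)


def to_total(relative_energy, atomic_numbers, refs):
    """Inference side: add the reference energy back to the model output."""
    return relative_energy + summed_reference(atomic_numbers, refs)


methane = [6, 1, 1, 1, 1]
total_label = -40.5  # made-up label, for illustration only

relative_label = to_relative(total_label, methane, REFS)
recovered_total = to_total(relative_label, methane, REFS)
print(relative_label, recovered_total)
```

Because the same per-element table is used in both directions, the round trip recovers the original total energy up to floating-point error, which is why the table has to travel with the checkpoint.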
21 changes: 21 additions & 0 deletions tests/test_datamodule.py
@@ -28,6 +28,27 @@ def test_datamodule_create(tmpdir):
    dl2 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
    assert dl1 is not dl2

def test_dataloader_get(tmpdir):
    args = load_example_args("graph-network")
    args["train_size"] = 800
    args["val_size"] = 100
    args["test_size"] = 100
    args["log_dir"] = tmpdir

    dataset = DummyDataset()
    data = DataModule(args, dataset=dataset)
    data.prepare_data()
    data.setup("fit")
    # Get the first element of the training set
    assert data.train_dataloader().dataset[0] is not None
    # Assert the elements are in there (z, pos, y, neg_dy)
    item = data.train_dataloader().dataset[0]
    assert "z" in item
    assert "pos" in item
    assert "y" in item
    assert "neg_dy" in item
    # Assert that the dataloader is not empty
    assert len(data.train_dataloader()) > 0

@mark.parametrize("energy,forces", [(True, True), (True, False), (False, True)])
@mark.parametrize("has_atomref", [True, False])
2 changes: 1 addition & 1 deletion tests/test_dataset_comp6.py
@@ -39,7 +39,7 @@ def test_dataset_s66x8():
),
atol=1e-4,
)
assert pt.allclose(sample.y, pt.tensor([[-47.5919]]))
assert pt.allclose(sample.y, pt.tensor([[-5755.7288331]],dtype=pt.float64))
assert pt.allclose(
sample.neg_dy,
-pt.tensor(
16 changes: 13 additions & 3 deletions tests/test_priors.py
@@ -17,9 +17,10 @@


@mark.parametrize("model_name", models.__all_models__)
def test_atomref(model_name):
@mark.parametrize("enable_atomref", [True, False])
def test_atomref(model_name, enable_atomref):
dataset = DummyDataset(has_atomref=True)
atomref = Atomref(max_z=100, dataset=dataset)
atomref = Atomref(max_z=100, dataset=dataset, enable=enable_atomref)
z, pos, batch = create_example_batch()

# create model with atomref
@@ -36,9 +37,18 @@ def test_atomref(model_name):
x_no_atomref, _ = model_no_atomref(z, pos, batch)

# check if the output of both models differs by the expected atomref contribution
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
if enable_atomref:
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
else:
expected_offset = 0
torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)

@mark.parametrize("trainable", [True, False])
def test_atomref_trainable(trainable):
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable)
    assert atomref.atomref.weight.requires_grad == trainable

def test_zbl():
    pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr
    types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
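The `expected_offset` in `test_atomref` above is a per-molecule sum of per-atom reference energies, grouped by the `batch` index. The sketch below reproduces that grouping in plain Python, assuming the `scatter` call in the test performs a sum reduction. `scatter_sum` is a hypothetical stand-in rather than the TorchMD-Net function, and the reference values are the ones from the docs hunk of this commit.

```python
# Per-element reference energies, copied from the datasets.rst example in this commit.
REFS = {
    1: -0.500607632585,
    6: -37.8302333826,
    7: -54.5680045287,
    8: -75.0362229210,
}


def scatter_sum(values, index):
    """Sum values that share the same index (hypothetical stand-in for scatter).

    values[i] is added to out[index[i]]; the output has max(index) + 1 entries.
    """
    size = max(index) + 1 if index else 0
    out = [0.0] * size
    for value, i in zip(values, index):
        out[i] += value
    return out


# Two molecules in one batch: water (O, H, H) and methane (C, H, H, H, H).
z = [8, 1, 1, 6, 1, 1, 1, 1]
batch = [0, 0, 0, 1, 1, 1, 1, 1]

per_atom_refs = [REFS[zi] for zi in z]
expected_offset = scatter_sum(per_atom_refs, batch)
print(expected_offset)
```

With the prior disabled (`enable=False` in the test), this offset is simply not added, which corresponds to the `expected_offset = 0` branch.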
2 changes: 2 additions & 0 deletions tests/utils.py
@@ -24,6 +24,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
args["prior_model"] = None
if "box_vecs" not in args:
args["box_vecs"] = None
if "remove_ref_energy" not in args:
args["remove_ref_energy"] = False
for key, val in kwargs.items():
assert key in args, f"Broken test! Unknown key '{key}'."
args[key] = val
31 changes: 0 additions & 31 deletions torchmdnet/data.py
@@ -13,33 +13,6 @@
from torch_geometric.data import Dataset
from torchmdnet.utils import make_splits, MissingEnergyException
from torchmdnet.models.utils import scatter
from torchmdnet.models.utils import dtype_mapping


class FloatCastDatasetWrapper(Dataset):
"""A wrapper around a torch_geometric dataset that casts all floating point
tensors to a given dtype.
"""

def __init__(self, dataset, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__(
dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
)
self._dataset = dataset
self._dtype = dtype

def len(self):
return len(self._dataset)

def get(self, idx):
data = self._dataset.get(idx)
for key, value in data:
if torch.is_tensor(value) and torch.is_floating_point(value):
setattr(data, key, value.to(self._dtype))
return data

def __getattr__(self, __name):
return getattr(self.__getattribute__("_dataset"), __name)


class DataModule(LightningDataModule):
@@ -82,10 +55,6 @@ def setup(self, stage):
self.hparams["dataset_root"], **dataset_arg
)

self.dataset = FloatCastDatasetWrapper(
self.dataset, dtype_mapping[self.hparams["precision"]]
)

self.idx_train, self.idx_val, self.idx_test = make_splits(
len(self.dataset),
self.hparams["train_size"],
9 changes: 8 additions & 1 deletion torchmdnet/datasets/ani.py
@@ -46,7 +46,14 @@ def raw_file_names(self):
raise NotImplementedError

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV
32 changes: 23 additions & 9 deletions torchmdnet/datasets/comp6.py
@@ -23,7 +23,7 @@


class COMP6Base(MemmappedDataset):
ELEMENT_ENERGIES = {
_ELEMENT_ENERGIES = {
1: -0.500607632585,
6: -37.8302333826,
7: -54.5680045287,
@@ -45,7 +45,6 @@ def __init__(
transform,
pre_transform,
pre_filter,
remove_ref_energy=False,
properties=("y", "neg_dy"),
)

@@ -60,11 +59,20 @@ def raw_url(self):
f"{url_prefix}/{self.raw_url_name}/{name}" for name in self.raw_file_names
]

@staticmethod
def compute_reference_energy(atomic_numbers):
atomic_numbers = np.array(atomic_numbers)
energy = sum(COMP6Base.ELEMENT_ENERGIES[z] for z in atomic_numbers)
return energy * COMP6Base.HARTREE_TO_EV
def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV

return refs.view(-1, 1)

def download(self):
for url in self.raw_url:
@@ -89,7 +97,6 @@ def sample_iter(self, mol_ids=False):
all_neg_dy = (
-all_neg_dy
) # The COMP6 datasets accidentally have gradients as forces
all_y -= self.compute_reference_energy(z)

assert all_pos.shape[0] == all_y.shape[0]
assert all_pos.shape[1] == z.shape[0]
@@ -338,7 +345,14 @@ def sample_iter(self, mol_ids=False):
yield data

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = pt.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV
13 changes: 3 additions & 10 deletions torchmdnet/datasets/custom.py
@@ -24,16 +24,9 @@ class Custom(Dataset):
forceglob (string, optional): Glob path for force files. Stored as "neg_dy".
(default: :obj:`None`)
preload_memory_limit (int, optional): If the dataset is smaller than this limit (in MB), preload it into CPU memory.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
pre_transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before being saved to disk.
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean value,
indicating whether the data object should be included in the final
dataset.
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access.
pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk.
pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset.
Example: