diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 86fa22eee..dd71e0367 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -64,13 +64,42 @@ If you want to train on custom data, first have a look at :py:mod:`torchmdnet.da
 
 To add a new dataset, you need to:
 
-1. Write a new class inheriting from :py:mod:`torch_geometric.data.Dataset`, see `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html>`_ for a tutorial on that.
+1. Write a new class inheriting from :py:mod:`torch_geometric.data.Dataset`, see `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html>`_ for a tutorial on that. You may also start from :ref:`one of the base datasets we provide <base datasets>`.
 2. Add the new class to the :py:mod:`torchmdnet.datasets` module (by listing it in the `__all__` variable in the `datasets/__init__.py` file) so that the :ref:`configuration file <configuration-file>` recognizes it.
 
 .. note:: The dataset must return torch-geometric `Data` objects, containing at least the keys `z` (atom types) and `pos` (atomic coordinates), as well as `y` (label), `neg_dy` (negative derivative of the label w.r.t. atom coordinates), or both.
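+
+For illustration, a minimal dataset satisfying these requirements could look like this (a sketch only; :code:`MyInMemoryDataset` and the tensors it stores are placeholders, not part of the package):
+
+.. code:: python
+
+    import torch
+    from torch_geometric.data import Data, Dataset
+
+    class MyInMemoryDataset(Dataset):
+        def __init__(self, z_list, pos_list, y_list):
+            super().__init__()
+            self.z_list, self.pos_list, self.y_list = z_list, pos_list, y_list
+
+        def len(self):
+            return len(self.z_list)
+
+        def get(self, idx):
+            # z: (n_atoms,) torch.long, pos: (n_atoms, 3) float, y: scalar label
+            return Data(
+                z=self.z_list[idx],
+                pos=self.pos_list[idx],
+                y=self.y_list[idx].view(1, 1),
+            )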
+
+Datasets and Atomref
+--------------------
+
+Compatibility with the :py:mod:`torchmdnet.priors.Atomref` prior (and thus :ref:`delta learning <delta-learning>`) requires the dataset to define a method called :code:`get_atomref` as follows:
+
+.. code:: python
+
+    import torch as pt
+    from torch_geometric.data import Dataset
+
+    class MyDataset(Dataset):
+        # ...
+        def get_atomref(self, max_z=100):
+            """Atomic energy reference values.
+
+            Args:
+                max_z (int): Maximum atomic number
+
+            Returns:
+                torch.Tensor: Atomic energy reference values for each element in the dataset.
+            """
+            refs = pt.zeros(max_z)
+            # Set the references for each element present in the dataset.
+            refs[1] = -0.500607632585
+            refs[6] = -37.8302333826
+            refs[7] = -54.5680045287
+            refs[8] = -75.0362229210
+            # ...
+            return refs.view(-1, 1)
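+
+The returned tensor must contain one reference value per atom type, i.e. have shape :code:`(max_z, 1)`; :py:mod:`torchmdnet.priors.Atomref` copies it into an embedding layer. A quick sanity check for the hypothetical :code:`MyDataset` above:
+
+.. code:: python
+
+    refs = MyDataset().get_atomref()  # hypothetical instance
+    assert refs.shape == (100, 1)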
+
+
+
 Available Datasets
 ------------------
 
@@ -79,3 +108,17 @@
 
 .. include:: generated/torchmdnet.datasets.rst
    :start-line: 5
+
+.. _base datasets:
+
+Base Datasets
+-------------
+
+The following datasets are used as a base for other datasets.
+
+.. autoclass:: torchmdnet.datasets.ani.ANIBase
+    :noindex:
+
+.. autoclass:: torchmdnet.datasets.comp6.COMP6Base
+    :noindex:
+
+.. autoclass:: torchmdnet.datasets.memdataset.MemmappedDataset
+    :noindex:
diff --git a/docs/source/models.rst b/docs/source/models.rst
index b5e8be5cd..595f43dd1 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -37,6 +37,45 @@ Once you have trained a model you should have a checkpoint that you can load for
 
 .. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
 
+.. _delta-learning:
+
+Training on relative energies
+-----------------------------
+
+It can be useful to train the model on relative energies while still having it produce total energies at inference time.
+TorchMD-Net supports this delta-learning setup via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or the :ref:`torchmd-train` command line interface) subtracts the sum of the atomic reference energies from the label of each sample before it is passed to the model.
+
+.. note:: Delta learning requires a :ref:`dataset <datasets>` that is compatible with :py:mod:`torchmdnet.priors.Atomref`.
+
+If :code:`remove_ref_energy` is turned on, the reference energies are stored in the checkpoint file and are added back to the output of the model during inference if the model is loaded with :code:`remove_ref_energy=False`.
+
+.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.
+
+Example
+~~~~~~~
+
+First we train a model with the :code:`remove_ref_energy` option turned on:
+
+.. code:: shell
+
+    torchmd-train --config /path/to/config.yaml --remove-ref-energy
+
+Then we load the model for inference:
+
+.. code:: python
+
+    import torch
+    from torchmdnet.models.model import load_model
+    checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
+    model = load_model(checkpoint, remove_ref_energy=False)
+
+    # An arbitrary set of inputs for the model
+    n_atoms = 10
+    zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
+    z = zs[torch.randint(0, len(zs), (n_atoms,))]
+    pos = torch.randn(len(z), 3)
+    batch = torch.zeros(len(z), dtype=torch.long)
+
+    y, neg_dy = model(z, pos, batch)
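+
+Since the checkpoint stores the reference energies as a disabled :py:mod:`torchmdnet.priors.Atomref` prior, one can check after loading that the prior was re-enabled (an optional sanity check, not required for inference):
+
+.. code:: python
+
+    from torchmdnet.priors import Atomref
+    assert isinstance(model.prior_model[-1], Atomref)
+    assert model.prior_model[-1].enable  # re-enabled by load_model(..., remove_ref_energy=False)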
diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py
index 4ea9ea509..1a40d9c74 100644
--- a/tests/test_datamodule.py
+++ b/tests/test_datamodule.py
@@ -28,6 +28,27 @@ def test_datamodule_create(tmpdir):
     dl2 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
     assert dl1 is not dl2
 
+def test_dataloader_get(tmpdir):
+    args = load_example_args("graph-network")
+    args["train_size"] = 800
+    args["val_size"] = 100
+    args["test_size"] = 100
+    args["log_dir"] = tmpdir
+
+    dataset = DummyDataset()
+    data = DataModule(args, dataset=dataset)
+    data.prepare_data()
+    data.setup("fit")
+    # Get the first element of the training set
+    assert data.train_dataloader().dataset[0] is not None
+    # Assert that the expected keys (z, pos, y, neg_dy) are present
+    item = data.train_dataloader().dataset[0]
+    assert "z" in item
+    assert "pos" in item
+    assert "y" in item
+    assert "neg_dy" in item
+    # Assert that the dataloader is not empty
+    assert len(data.train_dataloader()) > 0
 
 @mark.parametrize("energy,forces", [(True, True), (True, False), (False, True)])
 @mark.parametrize("has_atomref", [True, False])
diff --git a/tests/test_dataset_comp6.py b/tests/test_dataset_comp6.py
index 6e8bed1d0..d22bcd578 100644
--- a/tests/test_dataset_comp6.py
+++ b/tests/test_dataset_comp6.py
@@ -39,7 +39,7 @@ def test_dataset_s66x8():
         ),
         atol=1e-4,
     )
-    assert pt.allclose(sample.y, pt.tensor([[-47.5919]]))
+    assert pt.allclose(sample.y, pt.tensor([[-5755.7288331]], dtype=pt.float64))
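+    # The expected value is now the total energy (in float64): the dataset no
+    # longer subtracts the per-element reference energies from the labels.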
     assert pt.allclose(
         sample.neg_dy,
         -pt.tensor(
diff --git a/tests/test_priors.py b/tests/test_priors.py
index 8f1f20156..6def5bbfe 100644
--- a/tests/test_priors.py
+++ b/tests/test_priors.py
@@ -17,9 +17,10 @@
 
 
 @mark.parametrize("model_name", models.__all_models__)
-def test_atomref(model_name):
+@mark.parametrize("enable_atomref", [True, False])
+def test_atomref(model_name, enable_atomref):
     dataset = DummyDataset(has_atomref=True)
-    atomref = Atomref(max_z=100, dataset=dataset)
+    atomref = Atomref(max_z=100, dataset=dataset, enable=enable_atomref)
     z, pos, batch = create_example_batch()
 
     # create model with atomref
@@ -36,9 +37,18 @@ def test_atomref(model_name):
     x_no_atomref, _ = model_no_atomref(z, pos, batch)
 
     # check if the output of both models differs by the expected atomref contribution
-    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
+    if enable_atomref:
+        expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
+    else:
+        expected_offset = 0
     torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)
 
+
+@mark.parametrize("trainable", [True, False])
+def test_atomref_trainable(trainable):
+    dataset = DummyDataset(has_atomref=True)
+    atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable)
+    assert atomref.atomref.weight.requires_grad == trainable
+
 
 def test_zbl():
     pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32)  # Atom positions in Bohr
     types = torch.tensor([0, 1, 2, 1], dtype=torch.long)  # Atom types
diff --git a/tests/utils.py b/tests/utils.py
index afaab7418..ef8bcddb9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,6 +24,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
         args["prior_model"] = None
     if "box_vecs" not in args:
         args["box_vecs"] = None
+    if "remove_ref_energy" not in args:
+        args["remove_ref_energy"] = False
     for key, val in kwargs.items():
         assert key in args, f"Broken test! Unknown key '{key}'."
         args[key] = val
diff --git a/torchmdnet/data.py b/torchmdnet/data.py
index a7d6bf161..f12e4591c 100644
--- a/torchmdnet/data.py
+++ b/torchmdnet/data.py
@@ -13,33 +13,6 @@
 from torch_geometric.data import Dataset
 from torchmdnet.utils import make_splits, MissingEnergyException
 from torchmdnet.models.utils import scatter
-from torchmdnet.models.utils import dtype_mapping
-
-
-class FloatCastDatasetWrapper(Dataset):
-    """A wrapper around a torch_geometric dataset that casts all floating point
-    tensors to a given dtype.
-    """
-
-    def __init__(self, dataset, dtype=torch.float64):
-        super(FloatCastDatasetWrapper, self).__init__(
-            dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
-        )
-        self._dataset = dataset
-        self._dtype = dtype
-
-    def len(self):
-        return len(self._dataset)
-
-    def get(self, idx):
-        data = self._dataset.get(idx)
-        for key, value in data:
-            if torch.is_tensor(value) and torch.is_floating_point(value):
-                setattr(data, key, value.to(self._dtype))
-        return data
-
-    def __getattr__(self, __name):
-        return getattr(self.__getattribute__("_dataset"), __name)
 
 
 class DataModule(LightningDataModule):
@@ -82,10 +55,6 @@ def setup(self, stage):
                 self.hparams["dataset_root"], **dataset_arg
             )
 
-        self.dataset = FloatCastDatasetWrapper(
-            self.dataset, dtype_mapping[self.hparams["precision"]]
-        )
-
         self.idx_train, self.idx_val, self.idx_test = make_splits(
             len(self.dataset),
             self.hparams["train_size"],
diff --git a/torchmdnet/datasets/ani.py b/torchmdnet/datasets/ani.py
index 79c92298b..e7ca1add0 100644
--- a/torchmdnet/datasets/ani.py
+++ b/torchmdnet/datasets/ani.py
@@ -46,7 +46,14 @@ def raw_file_names(self):
         raise NotImplementedError
 
     def get_atomref(self, max_z=100):
-        """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior."""
+        """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
+
+        Args:
+            max_z (int): Maximum atomic number
+
+        Returns:
+            torch.Tensor: Atomic energy reference values for each element in the dataset.
+        """
         refs = pt.zeros(max_z)
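+        # The tabulated per-element self-energies are in Hartree; convert them to eV.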
""" def __init__( @@ -58,11 +55,9 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, - remove_ref_energy=False, properties=("y", "neg_dy", "q", "pq", "dp"), ): self.name = self.__class__.__name__ - self.remove_ref_energy = remove_ref_energy self.properties = properties super().__init__(root, transform, pre_transform, pre_filter) @@ -110,9 +105,6 @@ def processed_paths_dict(self): ) } - def compute_reference_energy(self, atomic_numbers): - return self.get_atomref()[atomic_numbers].sum() - def sample_iter(self, mol_ids=False): raise NotImplementedError() @@ -253,20 +245,17 @@ def get(self, idx): """ atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) - pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32) + pos = pt.tensor(self.pos_mm[atoms]) props = {} if "y" in self.properties: - y = self.y_mm[idx] - if self.remove_ref_energy: - y -= self.compute_reference_energy(z) - props["y"] = pt.tensor(y, dtype=pt.float32).view(1, 1) + props["y"] = pt.tensor(self.y_mm[idx]).view(1, 1) if "neg_dy" in self.properties: - props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) + props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms]) if "q" in self.properties: props["q"] = pt.tensor(self.q_mm[idx], dtype=pt.long) if "pq" in self.properties: - props["pq"] = pt.tensor(self.pq_mm[atoms], dtype=pt.float32) + props["pq"] = pt.tensor(self.pq_mm[atoms]) if "dp" in self.properties: - props["dp"] = pt.tensor(self.dp_mm[idx], dtype=pt.float32) + props["dp"] = pt.tensor(self.dp_mm[idx]) return Data(z=z, pos=pos, **props) diff --git a/torchmdnet/datasets/qm9.py b/torchmdnet/datasets/qm9.py index 79b78109f..6b97787da 100644 --- a/torchmdnet/datasets/qm9.py +++ b/torchmdnet/datasets/qm9.py @@ -28,6 +28,14 @@ def __init__(self, root, transform=None, label=None): super(QM9, self).__init__(root, transform=transform) def get_atomref(self, max_z=100): + """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. + + Args: + max_z (int): Maximum atomic number + + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ atomref = self.atomref(self.label_idx) if atomref is None: return None diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c301fac28..938c26686 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -154,42 +154,65 @@ def load_model(filepath, args=None, device="cpu", **kwargs): if args is None: args = ckpt["hyper_parameters"] + delta_learning = args["remove_ref_energy"] if "remove_ref_energy" in args else False + for key, value in kwargs.items(): if not key in args: warnings.warn(f"Unknown hyperparameter: {key}={value}") args[key] = value model = create_model(args) + if delta_learning and "remove_ref_energy" in kwargs: + if not kwargs["remove_ref_energy"]: + assert len(model.prior_model) > 0, "Atomref prior must be added during training (with enable=False) for total energy prediction." + assert isinstance(model.prior_model[-1], priors.Atomref), "I expected the last prior to be Atomref." + # Set the Atomref prior to enabled + model.prior_model[-1].enable = True state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()} - # The following are for backward compatibility with models created when atomref was - # the only supported prior. - if "prior_model.initial_atomref" in state_dict: - warnings.warn( - "prior_model.initial_atomref is deprecated and will be removed in a future version. 
+
+    Args:
+        args (dict): Arguments for the model.
+        dataset (torch_geometric.data.Dataset, optional): A dataset from which to extract the atomref values. Defaults to None.
+
+    Returns:
+        list: A list of prior models.
+    """
     prior_models = []
     if args["prior_model"]:
         prior_model = args["prior_model"]
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index b3e7a01a6..373870dc8 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -12,11 +12,48 @@
 from lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
+from torchmdnet.models.utils import dtype_mapping
+import torch_geometric.transforms as T
+
+
+class FloatCastDatasetWrapper(T.BaseTransform):
+    """A transform that casts all floating point tensors to a given dtype."""
+
+    def __init__(self, dtype=torch.float64):
+        super(FloatCastDatasetWrapper, self).__init__()
+        self._dtype = dtype
+
+    def forward(self, data):
+        for key, value in data:
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                setattr(data, key, value.to(self._dtype))
+        return data
+
+
+class EnergyRefRemover(T.BaseTransform):
+    """A transform that subtracts the atomic reference energies from the energy label of a sample."""
+
+    def __init__(self, atomref):
+        super(EnergyRefRemover, self).__init__()
+        self._atomref = atomref
+
+    def forward(self, data):
+        if "y" in data:
+            data.y -= self._atomref[data.z].sum()
+        return data
 
 
 class LNNP(LightningModule):
     """
     Lightning wrapper for the Neural Network Potentials in TorchMD-Net.
+
+    Args:
+        hparams (dict): A dictionary containing the hyperparameters of the model.
+        prior_model (torchmdnet.priors.BasePrior, optional): A prior model to use in the model.
+        mean (torch.Tensor, optional): The mean of the dataset used to normalize the input.
+        std (torch.Tensor, optional): The standard deviation of the dataset used to normalize the input.
     """
 
     def __init__(self, hparams, prior_model=None, mean=None, std=None):
@@ -41,6 +78,12 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         self.losses = None
         self._reset_losses_dict()
 
+        self.data_transform = FloatCastDatasetWrapper(dtype_mapping[self.hparams.precision])
+        if self.hparams.remove_ref_energy:
+            self.data_transform = T.Compose(
+                [EnergyRefRemover(self.model.prior_model[-1].initial_atomref), self.data_transform]
+            )
+
     def configure_optimizers(self):
         optimizer = AdamW(
             self.model.parameters(),
@@ -145,6 +188,7 @@ def step(self, batch, loss_fn_list, stage):
         # total_loss: sum of all losses (weighted by the loss weights) for the last loss function in the provided list
         assert len(loss_fn_list) > 0
         assert self.losses is not None
+        batch = self.data_transform(batch)
         with torch.set_grad_enabled(stage == "train" or self.hparams.derivative):
             extra_args = batch.to_dict()
             for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"):
diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py
index d7ce9dc14..677bd9562 100644
--- a/torchmdnet/priors/atomref.py
+++ b/torchmdnet/priors/atomref.py
@@ -11,11 +11,25 @@
 class Atomref(BasePrior):
     r"""Atomref prior model.
-    When using this in combination with some dataset, the dataset class must implement
-    the function `get_atomref`, which returns the atomic reference values as a tensor.
+
+    This prior adds atomic reference values to the input features as:
+
+    .. math::
+
+        x' = x + \textrm{atomref}(z)
+
+    where :math:`x` is the input feature tensor, :math:`z` is the atomic number tensor, and :math:`\textrm{atomref}` is an embedding layer that stores the atomic reference values and can optionally be trainable.
+
+    When using this in combination with some dataset, the dataset class must implement the function `get_atomref`, which returns the atomic reference values as a tensor.
+
+    Args:
+        max_z (int, optional): Maximum atomic number to consider. If `dataset` is not `None`, this argument is ignored.
+        dataset (torch_geometric.data.Dataset, optional): A dataset from which to extract the atomref values.
+        trainable (bool, optional): If `False`, the atomref values are not trainable. (default: `False`)
+        enable (bool, optional): If `False`, the prior is disabled; this is useful when the reference energies should only be added back at inference time. (default: `True`)
     """
 
-    def __init__(self, max_z=None, dataset=None):
+    def __init__(self, max_z=None, dataset=None, trainable=False, enable=True):
         super().__init__()
         if max_z is None and dataset is None:
             raise ValueError("Can't instantiate Atomref prior, all arguments are None.")
@@ -33,21 +47,49 @@ def __init__(self, max_z=None, dataset=None):
         if atomref.ndim == 1:
             atomref = atomref.view(-1, 1)
         self.register_buffer("initial_atomref", atomref)
-        self.atomref = nn.Embedding(len(atomref), 1)
-        self.atomref.weight.data.copy_(atomref)
+        self.atomref = nn.Embedding(
+            len(atomref), 1, _freeze=not trainable, _weight=atomref
+        )
+        self.enable = enable
 
     def reset_parameters(self):
         self.atomref.weight.data.copy_(self.initial_atomref)
 
     def get_init_args(self):
-        return dict(max_z=self.initial_atomref.size(0))
+        return dict(
+            max_z=self.initial_atomref.size(0),
+            trainable=self.atomref.weight.requires_grad,
+            enable=self.enable,
+        )
 
-    def pre_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Optional[Tensor] = None, extra_args: Optional[Dict[str, Tensor]] = None):
-        """Adds the stored atomref to the input as:
+    def pre_reduce(
+        self,
+        x: Tensor,
+        z: Tensor,
+        pos: Tensor,
+        batch: Optional[Tensor] = None,
+        extra_args: Optional[Dict[str, Tensor]] = None,
+    ):
+        """Applies the stored atomref to the input as:
 
         .. math::
 
             x' = x + \\textrm{atomref}(z)
 
+        .. note:: The atomref operation is an embedding lookup whose weights are trainable if the `trainable` argument is set to `True`.
+
+        .. note:: This call becomes a no-op if the `enable` argument is set to `False`.
+
+        Args:
+            x (Tensor): Input feature tensor.
+            z (Tensor): Atomic number tensor.
+            pos (Tensor): Atomic positions tensor. Unused.
+            batch (Tensor, optional): Batch tensor. Unused. (default: `None`)
+            extra_args (Dict[str, Tensor], optional): Extra arguments. Unused. (default: `None`)
         """
-        return x + self.atomref(z)
+        if self.enable:
+            return x + self.atomref(z)
+        else:
+            return x
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 307fc91d0..46999670b 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -58,7 +58,7 @@ def get_argparse():
     parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch')
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
-
+    parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta learning. The model can still predict total energies at inference time by turning this flag off when loading. The dataset must be compatible with Atomref for this option to be used.')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
@@ -146,6 +146,13 @@ def get_args():
 
 def main():
     args = get_args()
+
+    if args.remove_ref_energy:
+        if args.prior_model is None:
+            args.prior_model = []
+        if not isinstance(args.prior_model, list):
+            args.prior_model = [args.prior_model]
+        args.prior_model.append({"Atomref": {"enable": False}})
 
     pl.seed_everything(args.seed, workers=True)
 
     # initialize data module
@@ -155,7 +162,6 @@ def main():
     prior_models = create_prior_models(vars(args), data.dataset)
     args.prior_args = [p.get_init_args() for p in prior_models]
 
-    # initialize lightning module
     model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)