diff --git a/tests/test_priors.py b/tests/test_priors.py index 6def5bbfe..1c00af60d 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -9,7 +9,7 @@ from torchmdnet import models from torchmdnet.models.model import create_model, create_prior_models from torchmdnet.module import LNNP -from torchmdnet.priors import Atomref, D2, ZBL, Coulomb +from torchmdnet.priors import Atomref, LearnableAtomref, D2, ZBL, Coulomb from torchmdnet.models.utils import scatter from utils import load_example_args, create_example_batch, DummyDataset from os.path import dirname, join @@ -49,6 +49,10 @@ def test_atomref_trainable(trainable): atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable) assert atomref.atomref.weight.requires_grad == trainable +def test_learnableatomref(): + atomref = LearnableAtomref(max_z=100) + assert atomref.atomref.weight.requires_grad == True + def test_zbl(): pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types diff --git a/torchmdnet/priors/__init__.py b/torchmdnet/priors/__init__.py index 399588524..7b6450566 100644 --- a/torchmdnet/priors/__init__.py +++ b/torchmdnet/priors/__init__.py @@ -2,9 +2,9 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from torchmdnet.priors.atomref import Atomref +from torchmdnet.priors.atomref import Atomref, LearnableAtomref from torchmdnet.priors.d2 import D2 from torchmdnet.priors.zbl import ZBL from torchmdnet.priors.coulomb import Coulomb -__all__ = ["Atomref", "D2", "ZBL", "Coulomb"] +__all__ = ["Atomref", "LearnableAtomref", "D2", "ZBL", "Coulomb"] diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py index 7613be43e..f2114bfaa 100644 --- a/torchmdnet/priors/atomref.py +++ b/torchmdnet/priors/atomref.py @@ -94,3 +94,24 @@ def pre_reduce( return x + self.atomref(z) else: return x + + + +class LearnableAtomref(Atomref): + r"""LearnableAtomref prior model. + + This prior model is used to add learned atomic reference values to the input features. The atomic reference values are learned as an embedding layer and are added to the input features as: + + .. math:: + + x' = x + \\textrm{atomref}(z) + + where :math:`x` is the input feature tensor, :math:`z` is the atomic number tensor, and :math:`\\textrm{atomref}` is the embedding layer. + + + Args: + max_z (int, optional): Maximum atomic number to consider. + """ + + def __init__(self, max_z=None): + super().__init__(max_z, trainable=True, enable=True)