Skip to content

Commit

Permalink
Add LearnableAtomref
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Feb 16, 2024
1 parent 93fe195 commit 10fafea
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
6 changes: 5 additions & 1 deletion tests/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchmdnet import models
from torchmdnet.models.model import create_model, create_prior_models
from torchmdnet.module import LNNP
from torchmdnet.priors import Atomref, D2, ZBL, Coulomb
from torchmdnet.priors import Atomref, LearnableAtomref, D2, ZBL, Coulomb
from torchmdnet.models.utils import scatter
from utils import load_example_args, create_example_batch, DummyDataset
from os.path import dirname, join
Expand Down Expand Up @@ -49,6 +49,10 @@ def test_atomref_trainable(trainable):
atomref = Atomref(max_z=100, dataset=dataset, trainable=trainable)
assert atomref.atomref.weight.requires_grad == trainable

def test_learnableatomref():
atomref = LearnableAtomref(max_z=100)
assert atomref.atomref.weight.requires_grad == True

def test_zbl():
pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
Expand Down
4 changes: 2 additions & 2 deletions torchmdnet/priors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

from torchmdnet.priors.atomref import Atomref
from torchmdnet.priors.atomref import Atomref, LearnableAtomref
from torchmdnet.priors.d2 import D2
from torchmdnet.priors.zbl import ZBL
from torchmdnet.priors.coulomb import Coulomb

__all__ = ["Atomref", "D2", "ZBL", "Coulomb"]
__all__ = ["Atomref", "LearnableAtomref", "D2", "ZBL", "Coulomb"]
21 changes: 21 additions & 0 deletions torchmdnet/priors/atomref.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,24 @@ def pre_reduce(
return x + self.atomref(z)
else:
return x



class LearnableAtomref(Atomref):
r"""LearnableAtomref prior model.
This prior model is used to add learned atomic reference values to the input features. The atomic reference values are learned as an embedding layer and are added to the input features as:
.. math::
x' = x + \\textrm{atomref}(z)
where :math:`x` is the input feature tensor, :math:`z` is the atomic number tensor, and :math:`\\textrm{atomref}` is the embedding layer.
Args:
max_z (int, optional): Maximum atomic number to consider.
"""

def __init__(self, max_z=None):
super().__init__(max_z, trainable=True, enable=True)

0 comments on commit 10fafea

Please sign in to comment.