Skip to content

Commit

Permalink
Small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Oct 31, 2023
1 parent 5d16163 commit e54a9fb
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 25 deletions.
5 changes: 3 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,16 @@ def get_latest_git_tag(repo_path="."):

autoclass_content = "both"
autodoc_typehints = "none"
autodoc_inherit_docstrings = False
autodoc_inherit_docstrings = True
sphinx_autodoc_typehints = True
html_show_sourcelink = True
autodoc_default_options = {
"members": True,
"member-order": "bysource",
"undoc-members": True,
"exclude-members": "__weakref__",
"undoc-members": False,
"show-inheritance": True,
"inherited-members": False,
}
# Exclude all torchmdnet.datasets.*.rst files in source/generated/
exclude_patterns = [
Expand Down
2 changes: 1 addition & 1 deletion docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Available Datasets
------------------

.. automodule:: torchmdnet.datasets
:no-index:
:noindex:

.. include:: generated/torchmdnet.datasets.rst
:start-line: 5
52 changes: 35 additions & 17 deletions torchmdnet/datasets/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ANIBase(Dataset):
- Smith, J. S., Zubatyuk, R., Nebgen, B., Lubbers, N., Barros, K., Roitberg, A. E., Isayev, O., & Tretiak, S. (2020). The ANI-1ccx and ANI-1x data sets, coupled-cluster and density functional theory properties for molecules. Scientific Data, 7, Article 134.
"""

_HARTREE_TO_EV = 27.211386246
HARTREE_TO_EV = 27.211386246 #::meta private:

@property
def raw_url(self):
Expand All @@ -43,7 +43,7 @@ def raw_file_names(self):
def compute_reference_energy(self, atomic_numbers):
atomic_numbers = np.array(atomic_numbers)
energy = sum(self._ELEMENT_ENERGIES[z] for z in atomic_numbers)
return energy * ANIBase._HARTREE_TO_EV
return energy * ANIBase.HARTREE_TO_EV

def sample_iter(self, mol_ids=False):
raise NotImplementedError()
Expand Down Expand Up @@ -169,7 +169,21 @@ def len(self):
return len(self.y_mm)

def get(self, idx):
"""Get a single sample from the dataset.
Data object contains the following attributes by default:
- :obj:`z` (:class:`torch.LongTensor`): Atomic numbers of shape :obj:`[num_nodes]`.
- :obj:`pos` (:class:`torch.FloatTensor`): Atomic positions of shape :obj:`[num_nodes, 3]`.
- :obj:`y` (:class:`torch.FloatTensor`): Energies of shape :obj:`[1, 1]`.
- :obj:`neg_dy` (:class:`torch.FloatTensor`, *optional*): Negative gradients of shape :obj:`[num_nodes, 3]`.
Args:
idx (int): Index of the sample.
Returns:
:class:`torch_geometric.data.Data`: The data object.
"""
atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1])
z = pt.tensor(self.z_mm[atoms], dtype=pt.long)
pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32)
Expand All @@ -187,12 +201,13 @@ def get(self, idx):

class ANI1(ANIBase):
__doc__ = ANIBase.__doc__
_ELEMENT_ENERGIES = {
# Avoid sphinx from documenting this
ELEMENT_ENERGIES = {
1: -0.500607632585,
6: -37.8302333826,
7: -54.5680045287,
8: -75.0362229210,
}
} #::meta private:

@property
def raw_url(self):
Expand Down Expand Up @@ -222,7 +237,7 @@ def sample_iter(self, mol_ids=False):
)
all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32)
all_y = pt.tensor(
mol["energies"][:] * self._HARTREE_TO_EV, dtype=pt.float64
mol["energies"][:] * self.HARTREE_TO_EV, dtype=pt.float64
)

assert all_pos.shape[0] == all_y.shape[0]
Expand All @@ -244,10 +259,10 @@ def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
"""
refs = pt.zeros(max_z)
refs[1] = -0.500607632585 * self._HARTREE_TO_EV # H
refs[6] = -37.8302333826 * self._HARTREE_TO_EV # C
refs[7] = -54.5680045287 * self._HARTREE_TO_EV # N
refs[8] = -75.0362229210 * self._HARTREE_TO_EV # O
refs[1] = -0.500607632585 * self.HARTREE_TO_EV # H
refs[6] = -37.8302333826 * self.HARTREE_TO_EV # C
refs[7] = -54.5680045287 * self.HARTREE_TO_EV # N
refs[8] = -75.0362229210 * self.HARTREE_TO_EV # O

return refs.view(-1, 1)

Expand Down Expand Up @@ -278,23 +293,26 @@ def get_atomref(self, max_z=100):
warnings.warn("Atomic references from the ANI-1 dataset are used!")

refs = pt.zeros(max_z)
refs[1] = -0.500607632585 * self._HARTREE_TO_EV # H
refs[6] = -37.8302333826 * self._HARTREE_TO_EV # C
refs[7] = -54.5680045287 * self._HARTREE_TO_EV # N
refs[8] = -75.0362229210 * self._HARTREE_TO_EV # O
refs[1] = -0.500607632585 * self.HARTREE_TO_EV # H
refs[6] = -37.8302333826 * self.HARTREE_TO_EV # C
refs[7] = -54.5680045287 * self.HARTREE_TO_EV # N
refs[8] = -75.0362229210 * self.HARTREE_TO_EV # O

return refs.view(-1, 1)


class ANI1X(ANI1XBase):
__doc__ = ANIBase.__doc__
_ELEMENT_ENERGIES = {
ELEMENT_ENERGIES = {
1: -0.500607632585,
6: -37.8302333826,
7: -54.5680045287,
8: -75.0362229210,
}
"""
:meta private:
"""
def sample_iter(self, mol_ids=False):

assert len(self.raw_paths) == 1
Expand All @@ -305,10 +323,10 @@ def sample_iter(self, mol_ids=False):
z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long)
all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32)
all_y = pt.tensor(
mol["wb97x_dz.energy"][:] * self._HARTREE_TO_EV, dtype=pt.float64
mol["wb97x_dz.energy"][:] * self.HARTREE_TO_EV, dtype=pt.float64
)
all_neg_dy = pt.tensor(
mol["wb97x_dz.forces"][:] * self._HARTREE_TO_EV, dtype=pt.float32
mol["wb97x_dz.forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32
)

assert all_pos.shape[0] == all_y.shape[0]
Expand Down Expand Up @@ -356,7 +374,7 @@ def sample_iter(self, mol_ids=False):
z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long)
all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32)
all_y = pt.tensor(
mol["ccsd(t)_cbs.energy"][:] * self._HARTREE_TO_EV, dtype=pt.float64
mol["ccsd(t)_cbs.energy"][:] * self.HARTREE_TO_EV, dtype=pt.float64
)

assert all_pos.shape[0] == all_y.shape[0]
Expand Down
11 changes: 6 additions & 5 deletions torchmdnet/datasets/qm9q.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

class QM9q(Dataset):

HARTREE_TO_EV = 27.211386246
BORH_TO_ANGSTROM = 0.529177
DEBYE_TO_EANG = 0.2081943 # Debey -> e*A

HARTREE_TO_EV = 27.211386246#::meta private:
BORH_TO_ANGSTROM = 0.529177 #::meta private:
DEBYE_TO_EANG = 0.2081943 #::meta private: Debey -> e*A

# Ion energies of elements
ELEMENT_ENERGIES = {
Expand All @@ -19,13 +20,13 @@ class QM9q(Dataset):
7: {-1: -54.4626446440, 0: -54.5269367415, 1: -53.9895574739},
8: {-1: -74.9699154500, 0: -74.9812632126, 1: -74.4776884006},
9: {-1: -99.6695561536, 0: -99.6185158728},
}
} #::meta private:

# Select an ion with the lowest energy for each element
INITIAL_CHARGES = {
element: sorted(zip(charges.values(), charges.keys()))[0][1]
for element, charges in ELEMENT_ENERGIES.items()
}
} #::meta private:

def __init__(
self,
Expand Down

0 comments on commit e54a9fb

Please sign in to comment.