Document ACE Dataset
RaulPPelaez committed Oct 30, 2023
1 parent 45e9d21 commit a113068
Showing 1 changed file with 91 additions and 1 deletion.
92 changes: 91 additions & 1 deletion torchmdnet/datasets/ace.py
@@ -8,6 +8,81 @@


class Ace(Dataset):
"""The ACE dataset.
This dataset is sourced from HDF5 files.
Mandatory HDF5 file attributes:
- `layout`: Must be set to `Ace`.
- `layout_version`: Can be `1.0` or `2.0`.
- `name`: Name of the dataset.
For `layout_version` 1.0:
- Files can contain multiple molecule groups directly under the root.
- Each molecule group contains:
  - `atomic_numbers`: Atomic numbers of the atoms.
  - `formal_charges`: Formal charges of the atoms. The sum is the molecule's total charge. Units: electron charges.
  - `conformations` subgroup: This subgroup has individual conformation groups, each with datasets for different properties of the conformation.
For `layout_version` 2.0:
- Files contain a single root group (e.g., a 'master molecule group').
- Within this root group, there can be multiple molecule groups.
- Each molecule group contains:
  - `atomic_numbers`: Atomic numbers of the atoms.
  - `formal_charges`: Formal charges of the atoms.
  - Datasets for multiple conformations directly, without individual conformation groups.
Each conformation group (version 1.0) or molecule group (version 2.0) should have the following datasets:
- `positions`: Atomic positions. Units: Å.
- `forces`: Forces on the atoms. Units: eV/Å.
- `partial_charges`: Atomic partial charges. Units: electron charges.
- `dipole_moment`: Molecule's dipole moment. Units: e*Å.
- `formation_energy` (version 1.0) or `formation_energies` (version 2.0): Formation energy. Units: eV.
Each dataset should also have a `units` attribute specifying its units (e.g., `Å`, `eV`, `e*Å`).
Args:
root (string, optional): Root directory where the dataset should be saved.
transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object is transformed before every access.
pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object is transformed before being saved to disk.
pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset.
paths (string or list): Path to the HDF5 files or directory containing the HDF5 files.
max_gradient (float, optional): Maximum gradient norm. Samples with larger gradients are discarded.
subsample_molecules (int, optional): Subsampling stride for molecules. Only every `subsample_molecules`-th molecule is used.
Examples::
>>> import numpy as np
>>> from torchmdnet.datasets import Ace
>>> import h5py
>>>
>>> with h5py.File("molecule.h5", 'w') as f:
...     f.attrs["layout"] = "Ace"
...     f.attrs["layout_version"] = "1.0"
...     f.attrs["name"] = "sample_molecule_data"
...     for m in range(3):  # Three molecules
...         mol = f.create_group(f"mol_{m+1}")
...         mol["atomic_numbers"] = [1, 6, 8]  # H, C, O
...         mol["formal_charges"] = [0, 0, 0]  # Neutral charges
...         confs = mol.create_group("conformations")
...         for i in range(2):  # Two conformations
...             conf = confs.create_group(f"conf_{i+1}")
...             conf["positions"] = np.random.random((3, 3))
...             conf["positions"].attrs["units"] = "Å"
...             conf["formation_energy"] = np.random.random()
...             conf["formation_energy"].attrs["units"] = "eV"
...             conf["forces"] = np.random.random((3, 3))
...             conf["forces"].attrs["units"] = "eV/Å"
...             conf["partial_charges"] = np.random.random(3)
...             conf["partial_charges"].attrs["units"] = "e"
...             conf["dipole_moment"] = np.random.random(3)
...             conf["dipole_moment"].attrs["units"] = "e*Å"
>>> dataset = Ace(root=".", paths="molecule.h5")
>>> len(dataset)
6
>>> dataset = Ace(root=".", paths=["molecule.h5", "molecule.h5"])
>>> len(dataset)
12
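
The `layout_version` 2.0 structure described above could be written in a similar way. The following is only a sketch: the root group name `master_mol`, the helper `write_ace_v2`, and the array shapes (conformations stacked along a leading axis) are assumptions for illustration, not taken from the loader code.

```python
import os
import tempfile

import h5py
import numpy as np


def write_ace_v2(path, n_mols=2, n_confs=4):
    """Write a toy file following the documented 2.0 layout (hypothetical helper)."""
    n_atoms = 3
    rng = np.random.default_rng(0)
    with h5py.File(path, "w") as f:
        f.attrs["layout"] = "Ace"
        f.attrs["layout_version"] = "2.0"
        f.attrs["name"] = "sample_molecule_data"
        # Single root group; the name "master_mol" is an assumption.
        master = f.create_group("master_mol")
        for m in range(n_mols):
            mol = master.create_group(f"mol_{m + 1}")
            mol["atomic_numbers"] = [1, 6, 8]  # H, C, O
            mol["formal_charges"] = [0, 0, 0]
            # Per-conformation data stored directly in the molecule group.
            # Assumed shapes: one leading axis of length n_confs.
            datasets = {
                "positions": (rng.random((n_confs, n_atoms, 3)), "Å"),
                "forces": (rng.random((n_confs, n_atoms, 3)), "eV/Å"),
                "partial_charges": (rng.random((n_confs, n_atoms)), "e"),
                "dipole_moment": (rng.random((n_confs, 3)), "e*Å"),
                "formation_energies": (rng.random(n_confs), "eV"),
            }
            for name, (values, units) in datasets.items():
                ds = mol.create_dataset(name, data=values)
                ds.attrs["units"] = units


path = os.path.join(tempfile.mkdtemp(), "molecule_v2.h5")
write_ace_v2(path)
```

Whether `Ace` accepts this exact file depends on details of the loader that the docstring does not spell out, so treat the shapes as a starting point to check against the implementation.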
"""

def __init__(
self,
root=None,
@@ -305,7 +380,22 @@ def len(self):
return len(self.y_mm)

def get(self, idx):

"""Gets the data object at index :obj:`idx`.
The data object contains the following attributes:
- :obj:`z`: Atomic numbers of the atoms.
- :obj:`pos`: Positions of the atoms.
- :obj:`y`: Formation energy of the molecule.
- :obj:`neg_dy`: Forces on the atoms (negative gradient of the energy with respect to the positions).
- :obj:`q`: Total charge of the molecule.
- :obj:`pq`: Partial charges of the atoms.
- :obj:`dp`: Dipole moment of the molecule.
Args:
idx (int): Index of the data object.
Returns:
:obj:`torch_geometric.data.Data`: The data object.
"""
atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1])
z = pt.tensor(self.z_mm[atoms], dtype=pt.long)
pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32)
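
The slice built from `idx_mm[idx]` and `idx_mm[idx + 1]` suggests that per-atom data for all samples is stored in flat, concatenated arrays, with `idx_mm` holding the start offset of each sample. How `idx_mm` is constructed is not shown in this hunk, so the following is a minimal sketch under that assumption, using plain lists and hypothetical names (`offsets`, `z_flat`, `atoms_of`) in place of the memory-mapped arrays.

```python
from itertools import accumulate

# Hypothetical toy data: three samples with 3, 2 and 4 atoms.
atom_counts = [3, 2, 4]

# Cumulative offsets with a leading 0; plays the role of idx_mm.
offsets = [0, *accumulate(atom_counts)]

# Atomic numbers of all samples, concatenated; plays the role of z_mm.
z_flat = [1, 6, 8, 1, 1, 6, 1, 1, 8]


def atoms_of(sample):
    """Return the slice of the flat arrays that belongs to one sample."""
    return slice(offsets[sample], offsets[sample + 1])


for sample in range(len(atom_counts)):
    print(sample, z_flat[atoms_of(sample)])
```

With this layout, samples with different atom counts share one contiguous array, and the offsets array needs one more entry than there are samples so that `idx + 1` is always valid.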