diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index e2470dcb2..75869a771 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -8,6 +8,81 @@ class Ace(Dataset): + """The ACE dataset. + + This dataset is sourced from HDF5 files. + + Mandatory HDF5 file attributes: + - `layout`: Must be set to `Ace`. + - `layout_version`: Can be `1.0` or `2.0`. + - `name`: Name of the dataset. + + For `layout_version` 1.0: + - Files can contain multiple molecule groups directly under the root. + - Each molecule group contains: + - `atomic_numbers`: Atomic numbers of the atoms. + - `formal_charges`: Formal charges of the atoms. The sum is the molecule's total charge. Units: electron charges. + - `conformations` subgroup: This subgroup has individual conformation groups, each with datasets for different properties of the conformation. + + For `layout_version` 2.0: + - Files contain a single root group (e.g., a 'master molecule group'). + - Within this root group, there can be multiple molecule groups. + - Each molecule group contains: + - `atomic_numbers`: Atomic numbers of the atoms. + - `formal_charges`: Formal charges of the atoms. + - Datasets for multiple conformations directly, without individual conformation groups. + + Each conformation group (version 1.0) or molecule group (version 2.0) should have the following datasets: + - `positions`: Atomic positions. Units: Angstrom. + - `forces`: Forces on the atoms. Units: eV/Å. + - `partial_charges`: Atomic partial charges. Units: electron charges. + - `dipole_moment`: Molecule's dipole moment. Units: e*Å. + - `formation_energy` (version 1.0) or `formation_energies` (version 2.0): Formation energy. Units: eV. + Each dataset should also have a `units` attribute specifying its units (e.g., `Å`, `eV`, `e*Å`). + + Args: + root (string, optional): Root directory where the dataset should be saved. 
+ transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. + pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. + pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. + paths (string or list): Path to the HDF5 files or directory containing the HDF5 files. + max_gradient (float, optional): Maximum gradient norm. Samples with larger gradients are discarded. + subsample_molecules (int, optional): Subsample molecules. Only every `subsample_molecules`-th molecule is used. + + Examples:: + >>> import numpy as np + >>> from torchmdnet.datasets import Ace + >>> import h5py + >>> + >>> with h5py.File("molecule.h5", 'w') as f: + ... f.attrs["layout"] = "Ace" + ... f.attrs["layout_version"] = "1.0" + ... f.attrs["name"] = "sample_molecule_data" + ... for m in range(3): # Three molecules + ... mol = f.create_group(f"mol_{m+1}") + ... mol["atomic_numbers"] = [1, 6, 8] # H, C, O + ... mol["formal_charges"] = [0, 0, 0] # Neutral charges + ... confs = mol.create_group("conformations") + ... for i in range(2): # Two conformations + ... conf = confs.create_group(f"conf_{i+1}") + ... conf["positions"] = np.random.random((3, 3)) + ... conf["positions"].attrs["units"] = "Å" + ... conf["formation_energy"] = np.random.random() + ... conf["formation_energy"].attrs["units"] = "eV" + ... conf["forces"] = np.random.random((3, 3)) + ... conf["forces"].attrs["units"] = "eV/Å" + ... conf["partial_charges"] = np.random.random(3) + ... conf["partial_charges"].attrs["units"] = "e" + ... conf["dipole_moment"] = np.random.random(3) + ... 
conf["dipole_moment"].attrs["units"] = "e*Å" + >>> dataset = Ace(root=".", paths="molecule.h5") + >>> len(dataset) + 6 + >>> dataset = Ace(root=".", paths=["molecule.h5", "molecule.h5"]) + >>> len(dataset) + 12 + """ + def __init__( self, root=None, @@ -305,7 +380,22 @@ def len(self): return len(self.y_mm) def get(self, idx): - + """Gets the data object at index :obj:`idx`. + The data object contains the following attributes: + - :obj:`z`: Atomic numbers of the atoms. + - :obj:`pos`: Positions of the atoms. + - :obj:`y`: Formation energy of the molecule. + - :obj:`neg_dy`: Forces on the atoms. + - :obj:`q`: Total charge of the molecule. + - :obj:`pq`: Partial charges of the atoms. + - :obj:`dp`: Dipole moment of the molecule. + + Args: + idx (int): Index of the data object. + + Returns: + :obj:`torch_geometric.data.Data`: The data object. + """ atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32)