Skip to content

Commit

Permalink
refactored memmaped datasets into a separate class to reduce code dup…
Browse files Browse the repository at this point in the history
…lication
  • Loading branch information
stefdoerr committed Jan 24, 2024
1 parent d4778e1 commit 7b4922d
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 676 deletions.
183 changes: 8 additions & 175 deletions torchmdnet/datasets/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import hashlib
import h5py
import numpy as np
import os
import torch as pt
from torch_geometric.data import Dataset, Data
from torchmdnet.datasets.memdataset import MemmappedDataset
from torch_geometric.data import Data
from tqdm import tqdm


class Ace(Dataset):
class Ace(MemmappedDataset):
"""The ACE dataset.
This dataset is sourced from HDF5 files.
Expand Down Expand Up @@ -141,44 +141,15 @@ def __init__(
self.paths = paths
self.max_gradient = max_gradient
self.subsample_molecules = int(subsample_molecules)
super().__init__(root, transform, pre_transform, pre_filter)

(
idx_name,
z_name,
pos_name,
y_name,
neg_dy_name,
q_name,
pq_name,
dp_name,
) = self.processed_paths
self.idx_mm = np.memmap(idx_name, mode="r", dtype=np.int64)
self.z_mm = np.memmap(z_name, mode="r", dtype=np.int8)
self.pos_mm = np.memmap(
pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
super().__init__(
root, transform, pre_transform, pre_filter, remove_ref_energy=False
)
self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64)
self.neg_dy_mm = np.memmap(
neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
)
self.q_mm = np.memmap(q_name, mode="r", dtype=np.int8)
self.pq_mm = np.memmap(pq_name, mode="r", dtype=np.float32)
self.dp_mm = np.memmap(
dp_name, mode="r", dtype=np.float32, shape=(self.y_mm.shape[0], 3)
)

assert self.idx_mm[0] == 0
assert self.idx_mm[-1] == len(self.z_mm)
assert len(self.idx_mm) == len(self.y_mm) + 1

@property
def raw_paths(self):

paths_init = self.paths if isinstance(self.paths, list) else [self.paths]
paths = []
for path in paths_init:

if os.path.isfile(path):
paths.append(path)
continue
Expand All @@ -195,9 +166,7 @@ def raw_paths(self):

@staticmethod
def _load_confs_1_0(mol, n_atoms):

for conf in mol["conformations"].values():

# Skip failed calculations
if "formation_energy" not in conf:
continue
Expand Down Expand Up @@ -226,7 +195,6 @@ def _load_confs_1_0(mol, n_atoms):

@staticmethod
def _load_confs_2_0(mol, n_atoms):

assert mol["positions"].attrs["units"] == "Å"
all_pos = pt.tensor(mol["positions"][...], dtype=pt.float32)
n_confs = all_pos.shape[0]
Expand All @@ -249,19 +217,16 @@ def _load_confs_2_0(mol, n_atoms):
assert all_dp.shape == (n_confs, 3)

for pos, y, neg_dy, pq, dp in zip(all_pos, all_y, all_neg_dy, all_pq, all_dp):

# Skip failed calculations
if y.isnan():
continue

yield pos, y, neg_dy, pq, dp

def sample_iter(self, mol_ids=False):

assert self.subsample_molecules > 0

for path in tqdm(self.raw_paths, desc="Files"):

h5 = h5py.File(path)
assert h5.attrs["layout"] == "Ace"
version = h5.attrs["layout_version"]
Expand All @@ -286,7 +251,6 @@ def sample_iter(self, mol_ids=False):
total=len(mols),
leave=False,
):

# Subsample molecules
if i_mol % self.subsample_molecules != 0:
continue
Expand All @@ -295,8 +259,9 @@ def sample_iter(self, mol_ids=False):
fq = pt.tensor(mol["formal_charges"], dtype=pt.long)
q = fq.sum()

for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(load_confs(mol, n_atoms=len(z))):

for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(
load_confs(mol, n_atoms=len(z))
):
# Skip samples with large forces
if self.max_gradient:
if neg_dy.norm(dim=1).max() > float(self.max_gradient):
Expand All @@ -318,135 +283,3 @@ def sample_iter(self, mol_ids=False):
data = self.pre_transform(data)

yield data

@property
def processed_file_names(self):
    """List the file names of the memory-mapped arrays of this dataset.

    Each name has the form ``<self.name>.<field>.mmap``. The order of the
    fields is significant: other code unpacks the resulting paths
    positionally.
    """
    fields = ("idx", "z", "pos", "y", "neg_dy", "q", "pq", "dp")
    return [f"{self.name}.{field}.mmap" for field in fields]

def process(self):
    """Write all samples from :meth:`sample_iter` into memory-mapped files.

    The samples are iterated twice: the first pass counts conformers and
    atoms so that every array can be allocated with its final size, the
    second pass fills the arrays. Each array is written to
    ``<final path>.tmp`` and is renamed to its entry in
    ``self.processed_paths`` only after all arrays have been flushed.

    Stored arrays (in the order of ``self.processed_paths``):

    - ``idx``: int64, start offset of every conformer in the per-atom
      arrays, plus a final entry equal to the total number of atoms.
    - ``z``: int8, per atom.
    - ``pos``: float32, per atom, 3 columns.
    - ``y``: float64, per conformer.
    - ``neg_dy``: float32, per atom, 3 columns.
    - ``q``: int8, per conformer.
    - ``pq``: float32, per atom.
    - ``dp``: float32, per conformer, 3 columns.

    Raises:
        ValueError: If ``self.processed_paths`` does not hold exactly one
            path per stored array.
    """
    print("Arguments")
    print(f" max_gradient: {self.max_gradient} eV/A")
    print(f" subsample_molecules: {self.subsample_molecules}\n")

    # First pass: the memmaps need their final size at creation time.
    print("Gathering statistics...")
    num_all_confs = 0
    num_all_atoms = 0
    for data in self.sample_iter():
        num_all_confs += 1
        num_all_atoms += data.z.shape[0]

    print(f" Total number of conformers: {num_all_confs}")
    print(f" Total number of atoms: {num_all_atoms}")

    # Name, dtype and shape of every array, in the order of processed_paths.
    layout = (
        ("idx", np.int64, num_all_confs + 1),
        ("z", np.int8, num_all_atoms),
        ("pos", np.float32, (num_all_atoms, 3)),
        ("y", np.float64, num_all_confs),
        ("neg_dy", np.float32, (num_all_atoms, 3)),
        ("q", np.int8, num_all_confs),
        ("pq", np.float32, num_all_atoms),
        ("dp", np.float32, (num_all_confs, 3)),
    )

    final_paths = list(self.processed_paths)
    if len(final_paths) != len(layout):
        raise ValueError(
            f"expected {len(layout)} processed paths, got {len(final_paths)}"
        )

    mms = {
        name: np.memmap(path + ".tmp", mode="w+", dtype=dtype, shape=shape)
        for (name, dtype, shape), path in zip(layout, final_paths)
    }

    # Second pass: copy every sample into its slot.
    print("Storing data...")
    i_atom = 0
    for i_conf, data in enumerate(self.sample_iter()):
        i_next_atom = i_atom + data.z.shape[0]
        atoms = slice(i_atom, i_next_atom)

        mms["idx"][i_conf] = i_atom
        mms["z"][atoms] = data.z.to(pt.int8)
        mms["pos"][atoms] = data.pos
        mms["y"][i_conf] = data.y
        mms["neg_dy"][atoms] = data.neg_dy
        mms["q"][i_conf] = data.q.to(pt.int8)
        mms["pq"][atoms] = data.pq
        mms["dp"][i_conf] = data.dp

        i_atom = i_next_atom

    # Closing offset, so that idx[i]:idx[i + 1] is valid for the last sample.
    mms["idx"][-1] = num_all_atoms
    assert i_atom == num_all_atoms

    # Flush every array and drop the last reference to it before renaming:
    # np.memmap has no close() and some platforms (e.g. Windows) refuse to
    # rename a file that is still memory-mapped.
    tmp_paths = []
    for name in list(mms):
        mm = mms.pop(name)
        mm.flush()
        tmp_paths.append(mm.filename)
        del mm

    for tmp_path, final_path in zip(tmp_paths, final_paths):
        os.rename(tmp_path, final_path)

def len(self):
    """Return the number of samples, i.e. the number of entries in ``y_mm``."""
    # y_mm is a one-dimensional memmap holding one value per sample.
    return self.y_mm.shape[0]

def get(self, idx):
    """Assemble the data object stored at position :obj:`idx`.

    The returned object carries the following attributes:

    - :obj:`z`: Atomic numbers of the atoms.
    - :obj:`pos`: Positions of the atoms.
    - :obj:`y`: Formation energy of the molecule.
    - :obj:`neg_dy`: Forces on the atoms.
    - :obj:`q`: Total charge of the molecule.
    - :obj:`pq`: Partial charges of the atoms.
    - :obj:`dp`: Dipole moment of the molecule.

    Args:
        idx (int): Index of the data object.

    Returns:
        :obj:`torch_geometric.data.Data`: The data object.
    """
    # idx_mm holds the start offset of every sample in the per-atom arrays;
    # the following entry marks where the sample ends.
    first_atom = self.idx_mm[idx]
    end_atom = self.idx_mm[idx + 1]
    span = slice(first_atom, end_atom)

    fields = {
        "z": pt.tensor(self.z_mm[span], dtype=pt.long),
        "pos": pt.tensor(self.pos_mm[span], dtype=pt.float32),
        # float64 would be preferable here, but the trainer complains.
        "y": pt.tensor(self.y_mm[idx], dtype=pt.float32).view(1, 1),
        "neg_dy": pt.tensor(self.neg_dy_mm[span], dtype=pt.float32),
        "q": pt.tensor(self.q_mm[idx], dtype=pt.long),
        "pq": pt.tensor(self.pq_mm[span], dtype=pt.float32),
        "dp": pt.tensor(self.dp_mm[idx], dtype=pt.float32),
    }
    return Data(**fields)
Loading

0 comments on commit 7b4922d

Please sign in to comment.