Skip to content

Commit

Permalink
Use @stefdoerr xyz parser
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Jul 9, 2024
1 parent 33cd988 commit 191f454
Showing 1 changed file with 62 additions and 42 deletions.
104 changes: 62 additions & 42 deletions torchmdnet/datasets/maceoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,54 @@
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import hashlib
import ase
import h5py
from ase.data import atomic_numbers
import numpy as np
import os
import torch as pt
from torchmdnet.datasets.memdataset import MemmappedDataset
from torch_geometric.data import Data, download_url
from tqdm import tqdm
import tarfile
import tempfile
import re
import ase.io
import logging
import re
from tqdm import tqdm


def parse_maceoff_tar(tar_file):
energy_re = re.compile("energy=(\S+)")
with tarfile.open(tar_file, "r:gz") as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
if f is None:
continue
n_atoms = None
counter = 0
positions = []
numbers = []
forces = []
energy = None
for line in f:
line = line.decode("utf-8").strip()
if n_atoms is None:
n_atoms = int(line)
positions = []
numbers = []
forces = []
energy = None
counter = 1
continue
if counter == 1:
props = line
energy = float(energy_re.search(props).group(1))
counter = 2
continue
el, x, y, z, fx, fy, fz, _, _, _ = line.split()
numbers.append(atomic_numbers[el])
positions.append([float(x), float(y), float(z)])
forces.append([float(fx), float(fy), float(fz)])
counter += 1
if counter == n_atoms + 2:
n_atoms = None
yield energy, numbers, positions, forces


class MACEOFF(MemmappedDataset):
Expand Down Expand Up @@ -75,44 +110,29 @@ def __init__(
def sample_iter(self, mol_ids=False):
assert len(self.raw_paths) == 1
logging.info(f"Processing dataset {self.raw_file_names}")
with tempfile.TemporaryDirectory() as tmp_dir:
tar_path = os.path.join(self.raw_dir, self.raw_file_names)
xyz_path = os.path.join(
tmp_dir, re.sub(r"\.tar\.gz$", ".xyz", self.raw_file_names)
)
logging.info(f"Extracting {tar_path} to {tmp_dir}")
with tarfile.open(tar_path, "r:gz") as tar:
tar.extractall(tmp_dir)
assert os.path.exists(xyz_path)
for mol in tqdm(ase.io.iread(xyz_path), desc="Processing conformations"):
energy = (
mol.info["energy"]
if "energy" in mol.info
else mol.get_potential_energy()
for energy, numbers, positions, forces in tqdm(
parse_maceoff_tar(self.raw_paths[0]), desc="Processing conformations"
):
data = Data(
**dict(
z=pt.tensor(np.array(numbers), dtype=pt.long),
pos=pt.tensor(positions, dtype=pt.float32),
y=pt.tensor(energy, dtype=pt.float64).view(1, 1),
neg_dy=pt.tensor(forces, dtype=pt.float32),
)
forces = (
mol.arrays["forces"] if "forces" in mol.arrays else mol.get_forces()
)
data = Data(
**dict(
z=pt.tensor(np.array(mol.numbers), dtype=pt.long),
pos=pt.tensor(mol.positions, dtype=pt.float32),
y=pt.tensor(energy, dtype=pt.float64).view(1, 1),
neg_dy=pt.tensor(forces, dtype=pt.float32),
)
)
assert data.y.shape == (1, 1)
assert data.z.shape[0] == data.pos.shape[0]
assert data.neg_dy.shape[0] == data.pos.shape[0]
# Skip samples with large forces
if self.max_gradient:
if data.neg_dy.norm(dim=1).max() > float(self.max_gradient):
continue
if self.pre_filter is not None and not self.pre_filter(data):
)
assert data.y.shape == (1, 1)
assert data.z.shape[0] == data.pos.shape[0]
assert data.neg_dy.shape[0] == data.pos.shape[0]
# Skip samples with large forces
if self.max_gradient:
if data.neg_dy.norm(dim=1).max() > float(self.max_gradient):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
yield data
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
yield data

def download(self):
download_url(self.raw_url, self.raw_dir, filename=self.raw_file_names)

0 comments on commit 191f454

Please sign in to comment.