Skip to content

Commit

Permalink
Add error checking
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Jun 28, 2024
1 parent 3dc2b90 commit 603e795
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torchmdnet/datasets/maceoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def sample_iter(self, mol_ids=False):
neg_dy=pt.tensor(forces, dtype=pt.float32),
)
)
assert data.y.shape == (1, 1)
assert data.z.shape[0] == data.pos.shape[0]
assert data.neg_dy.shape[0] == data.pos.shape[0]
# Skip samples with large forces
if self.max_gradient:
if data.neg_dy.norm(dim=1).max() > float(self.max_gradient):
Expand Down

0 comments on commit 603e795

Please sign in to comment.