diff --git a/torchmdnet/datasets/maceoff.py b/torchmdnet/datasets/maceoff.py index a4f795e1..0326a01d 100644 --- a/torchmdnet/datasets/maceoff.py +++ b/torchmdnet/datasets/maceoff.py @@ -101,6 +101,9 @@ def sample_iter(self, mol_ids=False): neg_dy=pt.tensor(forces, dtype=pt.float32), ) ) + assert data.y.shape == (1, 1) + assert data.z.shape[0] == data.pos.shape[0] + assert data.neg_dy.shape[0] == data.pos.shape[0] # Skip samples with large forces if self.max_gradient: if data.neg_dy.norm(dim=1).max() > float(self.max_gradient):