Skip to content

Commit

Permalink
fix the EnergyRefRemover which was accidentally removing the total su…
Browse files Browse the repository at this point in the history
…m of all atomref energies of the whole batch from the energies
  • Loading branch information
stefdoerr committed Feb 16, 2024
1 parent 94cc325 commit 500b04c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __init__(self, atomref):
self._atomref = atomref

def forward(self, data):
self._atomref = self._atomref.to(data.z.device)
self._atomref = self._atomref.to(data.z.device).type(data.y.dtype)
if "y" in data:
data.y -= self._atomref[data.z].sum()
data.y.index_add_(0, data.batch, -self._atomref[data.z])
return data


Expand Down

0 comments on commit 500b04c

Please sign in to comment.