Skip to content

Commit

Permalink
Use __call__ instead of forward for compatibility with previous
Browse files Browse the repository at this point in the history
versions of geometric
  • Loading branch information
RaulPPelaez committed Jun 10, 2024
1 parent dbc1681 commit 84d6e6a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__()
self._dtype = dtype

def forward(self, data):
def __call__(self, data):
for key, value in data:
if torch.is_tensor(value) and torch.is_floating_point(value):
setattr(data, key, value.to(self._dtype))
Expand All @@ -41,7 +41,7 @@ def __init__(self, atomref):
super(EnergyRefRemover, self).__init__()
self._atomref = atomref

def forward(self, data):
def __call__(self, data):
self._atomref = self._atomref.to(data.z.device).type(data.y.dtype)
if "y" in data:
data.y.index_add_(0, data.batch, -self._atomref[data.z])
Expand Down

0 comments on commit 84d6e6a

Please sign in to comment.