diff --git a/openmmml/models/macepotential.py b/openmmml/models/macepotential.py index 0cf1397..03b2077 100644 --- a/openmmml/models/macepotential.py +++ b/openmmml/models/macepotential.py @@ -283,13 +283,6 @@ def __init__( self.register_buffer("batch", torch.zeros(nodeAttrs.shape[0], dtype=torch.long, requires_grad=False)) self.register_buffer("pbc", torch.tensor([periodic, periodic, periodic], dtype=torch.bool, requires_grad=False)) - self.inputDict = { - "ptr": self.ptr, - "node_attrs": self.node_attrs, - "batch": self.batch, - "pbc": self.pbc, - } - def _getNeighborPairs( self, positions: torch.Tensor, cell: Optional[torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -369,12 +362,19 @@ def forward( edgeIndex, shifts = self._getNeighborPairs(positions, cell) # Update input dictionary. - self.inputDict["positions"] = positions - self.inputDict["edge_index"] = edgeIndex - self.inputDict["shifts"] = shifts + inputDict = { + "ptr": self.ptr, + "node_attrs": self.node_attrs, + "batch": self.batch, + "pbc": self.pbc, + "positions": positions, + "edge_index": edgeIndex, + "shifts": shifts, + "cell": cell if cell is not None else torch.zeros(3, 3, dtype=self.dtype), + } # Predict the energy. - energy = self.model(self.inputDict, compute_force=False)[ + energy = self.model(inputDict, compute_force=False)[ self.returnEnergyType ]