Skip to content

Commit

Permalink
update pairfinders with memory to send trimmed lists
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed May 9, 2024
1 parent a0389ac commit 602b5d9
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 206 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
energy_node = model.node_from_name("energy")
force_node = physics.GradientNode("force", (energy_node, positions_node), sign=-1)

# Replace pair-finder with more efficient one (the HippynnCalculator also does this)
# Replace pair-finder with more efficient one so that system can fit on GPU
old_pairs_node = model.node_from_name("PairIndexer")
species_node = model.node_from_name("species")
cell_node = model.node_from_name("cell")
Expand Down
194 changes: 0 additions & 194 deletions examples/molecular_dynamics_custom.py

This file was deleted.

23 changes: 16 additions & 7 deletions hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cell, mol_i
# print("Pairs found",pair_first.shape)
coordflat = coordinates.reshape(n_molecules * n_atoms_max, 3)[real_atoms]
paircoord = coordflat[pair_first] - coordflat[pair_second] + pair_offsets
distflat2 = paircoord.norm(dim=1)
distflat = paircoord.norm(dim=1)

return distflat2, pair_first, pair_second, paircoord, offsets, offset_index
return distflat, pair_first, pair_second, paircoord, offsets, offset_index


class NPNeighbors(_DispatchNeighbors):
Expand Down Expand Up @@ -336,7 +336,7 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_

inputs = (coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_index, n_molecules, n_atoms_max)
outputs = self._pair_indexer(*inputs)
distflat2, pair_first, pair_second, paircoord, offsets, offset_index = outputs
distflat, pair_first, pair_second, paircoord, offsets, offset_index = outputs

with torch.no_grad():
pair_mol = mol_index[pair_first]
Expand All @@ -359,7 +359,16 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_

coordflat = coordinates.reshape(n_molecules * n_atoms_max, 3)[real_atoms]
paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + self.pair_offsets
distflat2 = paircoord.norm(dim=1)

return distflat2, self.pair_first, self.pair_second, paircoord, self.offsets, self.offset_index

distflat = paircoord.norm(dim=1)

# We will trim the lists to only send forward relevant atoms, improving performance.
within_cutoff_pairs = distflat < self.hard_dist_cutoff

return (
distflat[within_cutoff_pairs],
self.pair_first[within_cutoff_pairs],
self.pair_second[within_cutoff_pairs],
paircoord[within_cutoff_pairs],
self.offsets[within_cutoff_pairs],
self.offset_index[within_cutoff_pairs],
)
18 changes: 14 additions & 4 deletions hippynn/layers/pairs/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells):

inputs = (coordinates, nonblank, real_atoms, inv_real_atoms, cells)
outputs = self._pair_indexer(*inputs)
distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num, pair_mol = outputs
distflat, pair_first, pair_second, paircoord, cell_offsets, offset_num, pair_mol = outputs

for name, var in [
("cell_offsets", cell_offsets),
Expand All @@ -305,6 +305,16 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells):
pair_shifts = torch.matmul(self.cell_offsets.unsqueeze(1).to(cells.dtype), cells[self.pair_mol]).squeeze(1)
coordflat = coordinates.reshape(self.n_molecules * self.n_atoms, 3)[real_atoms]
paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + pair_shifts
distflat2 = paircoord.norm(dim=1)

return distflat2, self.pair_first, self.pair_second, paircoord, self.cell_offsets, self.offset_num
distflat = paircoord.norm(dim=1)

# We will trim the lists to only send forward relevant atoms, improving performance.
within_cutoff_pairs = distflat < self.hard_dist_cutoff

return (
distflat[within_cutoff_pairs],
self.pair_first[within_cutoff_pairs],
self.pair_second[within_cutoff_pairs],
paircoord[within_cutoff_pairs],
self.cell_offsets[within_cutoff_pairs],
self.offset_num[within_cutoff_pairs],
)

0 comments on commit 602b5d9

Please sign in to comment.