From 9b24b2130b21194061a72d351a641d89af23559b Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Mon, 11 Dec 2023 15:18:16 -0700 Subject: [PATCH] Add fuzzy histogram, pair-finder memory, and KD Tree pair-finder nodes; fix bugs (#50) * add pre-interaction layers option to hippynn * revert pre-interaction layers, add fuzzy histogram feature * fix minor typos, add number pairs to warn_if_under function * add pair-finder with memory, fix minor bugs * fix bug in pair-finder with memory * * New KD Tree pair-finder node * Modularized pair-finder memory component * Typos corrected * revert my change to .gitignore, revise docstring * Updae change log and docs. Revert unneeded changes. --------- Co-authored-by: Emily Suzanne Shinkle --- CHANGELOG.rst | 20 +++ docs/source/examples/periodic.rst | 20 ++- examples/allegro_ag_example.py | 2 +- hippynn/databases/database.py | 4 +- hippynn/experiment/assembly.py | 4 +- hippynn/graphs/gops.py | 2 +- .../graphs/nodes/base/definition_helpers.py | 2 +- hippynn/graphs/nodes/indexers.py | 21 +++ hippynn/graphs/nodes/pairs.py | 61 ++++++++- hippynn/graphs/predictor.py | 2 +- hippynn/layers/hiplayers.py | 4 +- hippynn/layers/indexers.py | 34 +++++ hippynn/layers/pairs/dispatch.py | 128 +++++++++++++++++- hippynn/layers/pairs/indexing.py | 6 +- hippynn/layers/pairs/open.py | 75 ++++++++++ hippynn/layers/pairs/periodic.py | 54 +++++++- 16 files changed, 417 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c5c874bb..97d6e938 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,11 +15,31 @@ Improvements 0.0.2a2 ======= +New Features: +------------- + +- New FuzzyHistogrammer node for transforming scalar feature into a fuzzy/soft + histogram array + +- New PeriodicPairIndexerMemory node which removes the need to recompute + pairs for each model evaluation in some instances, leading to speed improvements + +- New KDTreePairs and KDTreePairsMemory nodes for computing pairs using linearly- + scaling KD Tree algorithm. + Improvements ------------ - ASE database loader added to read any ASE file or list of ASE files. +Bug Fixes: +---------- +- Function 'gemerate_database_info' renamed to 'generate_database_info.' + +- Fixed issue with class Predictor arising when multiple names for the same output node are provided. + +- Fixed issue with MolPairSummer when the batch size and the feature size are both one. + 0.0.2a1 ======= diff --git a/docs/source/examples/periodic.rst b/docs/source/examples/periodic.rst index 58fcf818..3995a0e5 100644 --- a/docs/source/examples/periodic.rst +++ b/docs/source/examples/periodic.rst @@ -20,7 +20,7 @@ to within the unit cell. Because the nearest images (27 replicates of the cell a search radius 1) are numerous, periodic pair finding is noticeably more costly in terms of memory and time than open boundary conditions. The less skewed your cells are, as well as are the larger cells are compared to the cutoff distance required, -the fewer images needed to be searched in finding pairs. +the fewer images needed to be searched in finding pairs. Dynamic Pair Finder @@ -42,7 +42,23 @@ the systems one by one. The upshot of this is that less memory is required. However, the cost is that each system is evaluated independently in serial, and as such the pair finding can be a rather slow operation. This algorithm is more likely to show benefits when the number of atoms in a training system is highly -variable. +variable. + +For systems with orthorhombic cells and an interaction radius not greater than any of the +cell side lengths, the :class:`~hippynn.graphs.nodes.pairs.KDTreePairs` can be used +alternatively. It should exhibit reduced computation times, especially for large systems. + +Pair Finder Memory +------------------ +When using a trained model to run MD or for any application where atom positions +change only slightly between subsquent model calls, +:class:`~hippynn.graphs.nodes.pairs.PeriodicPairIndexerMemory` and +:class:`~hippynn.graphs.nodes.pairs.KDTreePairsMemory` can be used to reduce run +time by reusing pair information. Current pair indices are stored in memory and +reused so long as no atom has moved more than `skin`/2, where `skin` is an additional +parameter set by the user. Increasing the value of `skin` will increase the number of +pair distances computed at each step, but decrease the number of times new pairs must +be computed. Skin should be set to zero while training for fastest results. Caching Pre-computed Pairs -------------------------- diff --git a/examples/allegro_ag_example.py b/examples/allegro_ag_example.py index 1e796a93..64f3ef7a 100644 --- a/examples/allegro_ag_example.py +++ b/examples/allegro_ag_example.py @@ -190,7 +190,7 @@ def fit_model(training_modules,database): with hippynn.tools.log_terminal("model_results.txt",'wt'): test_model(database, training_modules.evaluator, 128, "Final Training") - ## Possible to export lammps MLIPInterface for model if lammmps with MLIP Installed! + ## Possible to export lammps MLIPInterface for model if Lammps with MLIP Installed! # print("Exporting lammps interface") # first_frame = ase.io.read(dbname) # Reads in first frame only for saving box # ase.io.write('ag_box.data', first_frame, format='lammps-data') diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index e9375b49..4eada5b9 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -34,8 +34,8 @@ def __init__( :param inputs: list of strings for input db_names :param targets: list of strings for output db_namees :param seed: int, for random splitting - :param test_size: fraction of data to use in test spli - :param valid_size: fraction oof data to use in train split + :param test_size: fraction of data to use in test split + :param valid_size: fraction of data to use in train split :param num_workers: passed to pytorch dataloaders :param pin_memory: passed to pytorch dataloaders :param allow_unfound: If true, skip checking if the needed inputs and targets are found. diff --git a/hippynn/experiment/assembly.py b/hippynn/experiment/assembly.py index 4aab4308..8338ff70 100644 --- a/hippynn/experiment/assembly.py +++ b/hippynn/experiment/assembly.py @@ -18,7 +18,7 @@ """ -def gemerate_database_info(inputs, targets, allow_unfound=False): +def generate_database_info(inputs, targets, allow_unfound=False): """ Construct db info from input nodes and target nodes. :param inputs: list of input nodes @@ -157,7 +157,7 @@ def assemble_for_training(train_loss, validation_losses, validation_names=None, if plot_maker is not None: plot_maker.assemble_module(outputs, targets) - db_info = gemerate_database_info(inputs, targets) + db_info = generate_database_info(inputs, targets) evaluator = Evaluator(model, validation_lossfns, validation_names, plot_maker=plot_maker, db_info=db_info) diff --git a/hippynn/graphs/gops.py b/hippynn/graphs/gops.py index fc07126f..df4c15b8 100644 --- a/hippynn/graphs/gops.py +++ b/hippynn/graphs/gops.py @@ -298,7 +298,7 @@ def search_by_name(nodes, name_or_dbname): :return: node that matches criterion Raises NodeAmbiguityError if more than one node found - Raises NotNotFoundError if no nodes found + Raises NodeNotFoundError if no nodes found """ try: diff --git a/hippynn/graphs/nodes/base/definition_helpers.py b/hippynn/graphs/nodes/base/definition_helpers.py index d0612b85..fb39adf0 100644 --- a/hippynn/graphs/nodes/base/definition_helpers.py +++ b/hippynn/graphs/nodes/base/definition_helpers.py @@ -99,7 +99,7 @@ def _assert_tupleform(input_tuple, type_tuple): # If not, it must at least have the same length if not len(input_tuple) == len(type_tuple): raise TupleTypeMismatch( - "Wrong length.{}!={}".format(len(input_tuple), len(type_tuple)) + "Wrong length. {}!={}".format(len(input_tuple), len(type_tuple)) + " \nInput: {} \nExpected: {}".format(input_tuple, type_tuple) ) diff --git a/hippynn/graphs/nodes/indexers.py b/hippynn/graphs/nodes/indexers.py index 9e3a25a5..ffb20fed 100644 --- a/hippynn/graphs/nodes/indexers.py +++ b/hippynn/graphs/nodes/indexers.py @@ -188,3 +188,24 @@ def acquire_encoding_padding(search_nodes, species_set, purpose=None): pidxer = PaddingIndexer("PaddingIndexer", (encoder.encoding, encoder.nonblank)) return encoder, pidxer + +class FuzzyHistogrammer(AutoKw, SingleNode): + """ + Node for transforming a scalar feature into a vectorized feature via + the fuzzy/soft histogram method. + + :param length: length of vectorized feature + """ + + _input_names = "values" + _auto_module_class = index_modules.FuzzyHistogram + + def __init__(self, name, parents, length, vmin, vmax, module="auto", **kwargs): + + if isinstance(parents, _BaseNode): + parents = (parents,) + + self._output_index_state = parents[0]._index_state + self.module_kwargs = {"length": length, "vmin": vmin, "vmax": vmax} + + super().__init__(name, parents, module=module, **kwargs) \ No newline at end of file diff --git a/hippynn/graphs/nodes/pairs.py b/hippynn/graphs/nodes/pairs.py index 222953f0..ae2d6fb7 100644 --- a/hippynn/graphs/nodes/pairs.py +++ b/hippynn/graphs/nodes/pairs.py @@ -68,6 +68,28 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No parents = self.expand_parents(parents) super().__init__(name, parents, module=module, **kwargs) +class PeriodicPairIndexerMemory(PeriodicPairIndexer): + ''' + Implementation of PeriodicPairIndexer with additional memory component. + + Stores current pair indices in memory and reuses them to compute the pair distances if no + particle has moved more than skin/2 since last pair calculation. Otherwise uses the + _pair_indexer_class to recompute the pairs. + + Increasing the value of 'skin' will increase the number of pair distances computed at + each step, but decrease the number of times new pairs must be computed. Skin should be + set to zero while training for fastest results. + ''' + + _auto_module_class = pairs_modules.periodic.PeriodicPairIndexerMemory + + def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs): + if module_kwargs is None: + module_kwargs = {} + module_kwargs = {"skin": skin, **module_kwargs} + + super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs) + class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode): _input_names = "coordinates", "real_atoms", "shifts", "cell", "ext_pair_first", "ext_pair_second" @@ -279,7 +301,7 @@ def __init__(self, name, parents, module="auto", bins=None, module_kwargs=None, super().__init__(name, parents, module=module, **kwargs) -class _DispatchNeighbors(ExpandParents, PeriodicPairOutputs, PairIndexer, MultiNode): +class _DispatchNeighbors(ExpandParents, AutoKw, PeriodicPairOutputs, PairIndexer, MultiNode): """ Superclass for nodes that compute neighbors for systems one at a time. These should be capable of searching all feasible neighbors (no limit on number of images) @@ -326,13 +348,15 @@ def expand1(self, pos, encode, indexer, cell, **kwargs): _parent_expander.get_main_outputs() _parent_expander.require_idx_states(IdxType.MolAtom, None, None, None, None, None, None, None) - def __init__(self, name, parents, dist_hard_max, module="auto", **kwargs): + def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=None, **kwargs): self.dist_hard_max = dist_hard_max parents = self.expand_parents(parents) - super().__init__(name, parents, module=module, **kwargs) - def auto_module(self): - return self._auto_module_class(self.dist_hard_max) + if module_kwargs is None: + module_kwargs = {} + self.module_kwargs = {"dist_hard_max": dist_hard_max, **module_kwargs} + + super().__init__(name, parents, module=module, **kwargs) class NumpyDynamicPairs(_DispatchNeighbors): @@ -348,6 +372,33 @@ class DynamicPeriodicPairs(_DispatchNeighbors): _auto_module_class = pairs_modules.TorchNeighbors +class KDTreePairs(_DispatchNeighbors): + ''' + Node for finding pairs under periodic boundary conditions using Scipy's KD Tree algorithm. + Cell must be orthorhombic. + ''' + _auto_module_class = pairs_modules.dispatch.KDTreeNeighbors + +class KDTreePairsMemory(_DispatchNeighbors): + ''' + Implementation of KDTreePairs with an added memory component. + + Stores current pair indices in memory and reuses them to compute the pair distances if no + particle has moved more than skin/2 since last pair calculation. Otherwise uses the + _pair_indexer_class to recompute the pairs. + + Increasing the value of 'skin' will increase the number of pair distances computed at + each step, but decrease the number of times new pairs must be computed. Skin should be + set to zero while training for fastest results. + ''' + _auto_module_class = pairs_modules.dispatch.KDTreePairsMemory + + def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs): + if module_kwargs is None: + module_kwargs = {} + module_kwargs = {"skin": skin, **module_kwargs} + + super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs) class PaddedNeighborNode(ExpandParents, AutoNoKw, MultiNode): _input_names = "pair_first", "pair_second", "pair_coord" diff --git a/hippynn/graphs/predictor.py b/hippynn/graphs/predictor.py index 437e5b4e..a65b67e6 100644 --- a/hippynn/graphs/predictor.py +++ b/hippynn/graphs/predictor.py @@ -41,6 +41,7 @@ def __init__(self, inputs, outputs, return_device=torch.device("cpu"), model_dev """ outputs = [search_by_name(inputs, o) if isinstance(o, str) else o for o in outputs] + outputs = list(set(outputs)) # Remove any redundancies -- they will screw up the output name map. outputs = [o for o in outputs if o._index_state is not IdxType.Scalar] @@ -77,7 +78,6 @@ def from_graph(cls, graph, additional_outputs=None, **kwargs): outputs = graph.nodes_to_compute if additional_outputs is not None: outputs = outputs + list(additional_outputs) - outputs = list(set(outputs)) # Remove any redundancies -- they will screw up the output name map. return cls(inputs, outputs, **kwargs) diff --git a/hippynn/layers/hiplayers.py b/hippynn/layers/hiplayers.py index eacb8013..be70943e 100644 --- a/hippynn/layers/hiplayers.py +++ b/hippynn/layers/hiplayers.py @@ -16,11 +16,13 @@ def warn_if_under(distance, threshold): if dmin < threshold: d_count = distance < threshold d_frac = d_count.to(distance.dtype).mean() + d_sum = (d_count.sum()/2).to(torch.int) warnings.warn( "Provided distances are underneath sensitivity range!\n" f"Minimum distance in current batch: {dmin}\n" f"Threshold distance for warning: {threshold}.\n" - f"Fraction of pairs under the threshold: {d_frac}" + f"Fraction of pairs under the threshold: {d_frac}\n" + f"Number of pairs under the threshold: {d_sum}" ) diff --git a/hippynn/layers/indexers.py b/hippynn/layers/indexers.py index 91e7f477..c3d6d1f9 100644 --- a/hippynn/layers/indexers.py +++ b/hippynn/layers/indexers.py @@ -221,3 +221,37 @@ def forward(self, bonds, pair_first, pair_second): # in seqm, only bonds with index first < second is used cond = pair_first < pair_second return bonds[cond] + +class FuzzyHistogram(torch.nn.Module): + """ + Transforms a scalar feature into a vectorized feature via + the fuzzy/soft histogram method. + + :param length: length of vectorized feature + + :returns FuzzyHistogram + """ + + def __init__(self, length, vmin, vmax): + super().__init__() + + err_msg = "The value of 'length' must be a positive integer." + if not isinstance(length, int): + raise ValueError(err_msg) + if length <= 0: + raise ValueError(err_msg) + + if not (isinstance(vmin, (int,float)) and isinstance(vmax, (int,float))): + raise ValueError("The values of 'vmin' and 'vmax' must be floating point numbers.") + if vmin >= vmax: + raise ValueError("The value of 'vmin' must be less than the value of 'vmax.'") + + self.bins = torch.nn.Parameter(torch.linspace(vmin, vmax, length), requires_grad=False) + self.sigma = (vmax - vmin) / length + + def forward(self, values): + if values.shape[-1] != 1: + values = values[...,None] + x = values - self.bins + histo = torch.exp(-((x / self.sigma) ** 2) / 4) + return torch.flatten(histo, end_dim=1) \ No newline at end of file diff --git a/hippynn/layers/pairs/dispatch.py b/hippynn/layers/pairs/dispatch.py index 8ed27d94..f034a516 100644 --- a/hippynn/layers/pairs/dispatch.py +++ b/hippynn/layers/pairs/dispatch.py @@ -2,9 +2,12 @@ System-by-system pair finders """ +from itertools import product import numpy as np +from scipy.spatial import KDTree import torch +from .open import PairMemory def wrap_points_np(coords, cell, inv_cell): # cell is (basis,cartesian) @@ -132,6 +135,68 @@ def neighbor_list_torch(cutoff: float, coords, cell): pi = pi - wrap_offset_ij[pf, ps] return pf, ps, pi +def neighbor_list_kdtree(cutoff, coords, cell): + ''' + Use KD Tree implementation from scipy.spatial to find pairs under periodic boundary conditions + with an orthonormal cell. + ''' + + # Verify that cell is orthorhombic + cell_prod = cell @ cell.T + if torch.count_nonzero(cell_prod - torch.diag(torch.diag(cell_prod))): + raise ValueError("KD Tree search only works for orthorhombic cells.") + + # Verify that the cutoff is less than the side lengths of the cell + cell_side_lengths = torch.sqrt(torch.diag(cell_prod)) + if (cutoff >= cell_side_lengths).any(): + raise ValueError(f"Cutoff value ({cutoff}) must be less than the cell slide lengths ({cell_side_lengths}).") + + if torch.count_nonzero(cell - torch.diag(torch.diag(cell))): + # Transform via isometry to a basis where cell is a diagonal matrix if it currently is not + new_cell = torch.sqrt(cell_prod) + new_coords = coords @ torch.linalg.inv(cell) @ new_cell + else: + new_cell = cell.clone() + new_coords = coords.clone() + + # Find pair indices + tree = KDTree( + data=new_coords.detach().cpu().numpy(), + boxsize=torch.diag(new_cell).detach().cpu().numpy() + ) + + pairs = tree.query_pairs(r=cutoff, output_type='ndarray') + pairs = torch.as_tensor(pairs, device=coords.device) + pair_first, pair_second = torch.unbind(pairs, dim=1) + + # Find difference vector between pairs without considering the MIC + pair_diff = torch.sub(coords[pair_first], coords[pair_second]) + + # Possible adjacent offset directions for images of the difference vector + offset_range = torch.tensor(list(product([-1, 0, 1], repeat=3)), device=coords.device) + + # All adjacent offsets + perm_offsets = offset_range.to(cell.dtype) @ cell + + # All adjacent offset images of the difference vector + pair_diff = pair_diff.unsqueeze(1) + perm_offsets.unsqueeze(0) + + # L2 norm of offset images + pair_diff = torch.linalg.norm(pair_diff, dim=2) + + # Index of shortest offset image + pair_diff = torch.argmin(pair_diff, dim=1) + + # Offset direction corresponding to shortest offset image + pair_image = offset_range[pair_diff] + + # KDTree only returns each pair once (eg. (1,2) but not (2,1)) + doubled_pair_first = torch.concat((pair_first, pair_second)) + doubled_pair_second = torch.concat((pair_second, pair_first)) + doubled_pair_image = torch.concat((pair_image, -pair_image)) + + return doubled_pair_first, doubled_pair_second, doubled_pair_image + class _DispatchNeighbors(torch.nn.Module): def __init__(self, dist_hard_max): @@ -204,7 +269,7 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cell, mol_i pair_cell = cell[pair_mol] pair_offsets = torch.bmm(offsets.unsqueeze(1).to(pair_cell.dtype), pair_cell).squeeze(1) - # now calculate pair_dist, paircoord differentiably + # now calculate pair_dist, paircoord differentiably # print("Pairs found",pair_first.shape) coordflat = coordinates.reshape(n_molecules * n_atoms_max, 3)[real_atoms] paircoord = coordflat[pair_first] - coordflat[pair_second] + pair_offsets @@ -226,3 +291,64 @@ def compute_one(self, positions, cell): with torch.no_grad(): outputs = neighbor_list_torch(self.dist_hard_max, positions, cell) return outputs + +class KDTreeNeighbors(_DispatchNeighbors): + ''' + Node for finding pairs under periodic boundary conditions using Scipy's KD Tree algorithm. + Cell must be orthorhombic. + ''' + + def compute_one(self, positions, cell): + with torch.no_grad(): + outputs = neighbor_list_kdtree(self.dist_hard_max, positions, cell) + return outputs + + +class KDTreePairsMemory(PairMemory): + ''' + Implementation of KDTreePairs with an added memory component. + + Stores current pair indices in memory and reuses them to compute the pair distances if no + particle has moved more than skin/2 since last pair calculation. Otherwise uses the + _pair_indexer_class to recompute the pairs. + + Increasing the value of 'skin' will increase the number of pair distances computed at + each step, but decrease the number of times new pairs must be computed. Skin should be + set to zero while training for fastest results. + ''' + + _pair_indexer_class = KDTreeNeighbors + + def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_index, n_molecules, n_atoms_max): + if self.recalculation_needed(coordinates, cells): + self.recalculations += 1 + + inputs = (coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_index, n_molecules, n_atoms_max) + outputs = self._pair_indexer(*inputs) + distflat2, pair_first, pair_second, paircoord, offsets, offset_index = outputs + + with torch.no_grad(): + pair_mol = mol_index[pair_first] + pair_cell = cells[pair_mol] + pair_offsets = torch.bmm(offsets.unsqueeze(1).to(pair_cell.dtype), pair_cell).squeeze(1) + + for name, var in [ + ("pair_first", pair_first), + ("pair_second", pair_second), + ("offsets", offsets), + ("offset_index", offset_index), + ("pair_offsets", pair_offsets), + ("positions", coordinates), + ("cells", cells), + ]: + self.__setattr__(name, var) + + else: + self.reuses += 1 + + coordflat = coordinates.reshape(n_molecules * n_atoms_max, 3)[real_atoms] + paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + self.pair_offsets + distflat2 = paircoord.norm(dim=1) + + return distflat2, self.pair_first, self.pair_second, paircoord, self.offsets, self.offset_index + diff --git a/hippynn/layers/pairs/indexing.py b/hippynn/layers/pairs/indexing.py index 6f26cedb..7239e11f 100644 --- a/hippynn/layers/pairs/indexing.py +++ b/hippynn/layers/pairs/indexing.py @@ -59,7 +59,11 @@ def forward(self, features, molecule_index, atom_index, n_molecules, n_atoms_max class MolPairSummer(torch.nn.Module): def forward(self, pairfeatures, mol_index, n_molecules, pair_first): pair_mol = mol_index[pair_first] - feat_shape = (1,) if pairfeatures.ndimension() == 1 else pairfeatures.shape[1:] + if pairfeatures.shape[0] == 1: + feat_shape = (1,) + pairfeatures.unsqueeze(-1) + else: + feat_shape = pairfeatures.shape[1:] out_shape = (n_molecules, *feat_shape) result = torch.zeros(out_shape, device=pairfeatures.device, dtype=pairfeatures.dtype) result.index_add_(0, pair_mol, pairfeatures) diff --git a/hippynn/layers/pairs/open.py b/hippynn/layers/pairs/open.py index cbee674e..68acdf84 100644 --- a/hippynn/layers/pairs/open.py +++ b/hippynn/layers/pairs/open.py @@ -58,3 +58,78 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms): distflat2 = paircoord.norm(dim=1) return distflat2, pair_first, pair_second, paircoord + +class PairMemory(torch.nn.Module): + ''' + Stores current pair indices and reuses them to compute the pair distances if no + particle has moved more than skin/2 since last pair calculation. Otherwise uses the + _pair_indexer_class to recompute the pairs. + + Increasing the value of 'skin' will increase the number of pair distances computed at + each step, but decrease the number of times new pairs must be computed. Skin should be + set to zero while training for fastest results. + ''' + + # ## Subclasses should update the following ## # + _pair_indexer_class = NotImplemented + + def forward(*args, **kwargs): + return NotImplementedError + # ## End ## # + + def __init__(self, skin, dist_hard_max=None, hard_dist_cutoff=None): + super().__init__() + + if dist_hard_max is None and hard_dist_cutoff is None: + raise ValueError("One of 'dist_hard_max' and 'hard_dist_cutoff' must be specified.") + if dist_hard_max is not None and hard_dist_cutoff is not None and dist_hard_max != hard_dist_cutoff: + raise ValueError("Must only specify one of 'dist_hard_max' and 'hard_dist_cutoff.'") + + self.hard_dist_cutoff = (dist_hard_max or hard_dist_cutoff) + self.dist_hard_max = (dist_hard_max or hard_dist_cutoff) + self.set_skin(skin) + + @property + def skin(self): + return self._skin + + def set_skin(self, skin): + self._skin = skin + + try: + self._pair_indexer = self._pair_indexer_class(hard_dist_cutoff = self._skin + self.hard_dist_cutoff) + except TypeError: + self._pair_indexer = self._pair_indexer_class(dist_hard_max = self._skin + self.hard_dist_cutoff) + + self.reset_reuse_percentage() + self.initialize_buffers() + + @skin.setter + def skin(self, skin): + self.set_skin(skin) + + @property + def reuse_percentage(self): + ''' + Returns None if there are no model calls on record. + ''' + try: + return self.reuses / (self.reuses + self.recalculations) * 100 + except ZeroDivisionError: + print("No model calls on record.") + return + + def reset_reuse_percentage(self): + self.reuses = 0 + self.recalculations = 0 + + def initialize_buffers(self): + for name in ["pair_mol", "cell_offsets", "pair_first", "pair_second", "offset_num", "positions", "cells"]: + self.register_buffer(name=name, tensor=None, persistent=False) + + def recalculation_needed(self, coordinates, cells): + if self.positions is None: # ie. forward function has not been called + return True + if (self.cells != cells).any() or (((self.positions - coordinates)**2).sum(1).max() > (self._skin/2)**2): + return True + return False \ No newline at end of file diff --git a/hippynn/layers/pairs/periodic.py b/hippynn/layers/pairs/periodic.py index 4a36db2c..f3666fd0 100644 --- a/hippynn/layers/pairs/periodic.py +++ b/hippynn/layers/pairs/periodic.py @@ -1,6 +1,9 @@ import torch -from .open import _PairIndexer +from scipy.spatial import KDTree + +from .open import _PairIndexer, PairMemory +from torch.profiler import profile, record_function, ProfilerActivity # Deprecated? class StaticImagePeriodicPairIndexer(_PairIndexer): @@ -87,7 +90,7 @@ def tracebatch(mat): return torch.diagonal(mat, dim1=-2, dim2=-1).sum(dim=-1) # For some reason torch.linalg on GPU tends to -# spend a lot of time allocating memory, especialyl +# spend a lot of time allocating memory, especially # when it is given a large batch of matrices. # So we use the Cayley-Hamilton version of a matrix inverse # Without calling any linalg functions. @@ -150,7 +153,7 @@ class PeriodicPairIndexer(_PairIndexer): Finds pairs in general periodic conditions. """ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells): - + original_coordinates = coordinates with torch.no_grad(): @@ -261,4 +264,47 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells): paircoord = coordflat[pair_first] - coordflat[pair_second] + pair_shifts distflat2 = paircoord.norm(dim=1) - return distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num + return distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num, pair_mol + +class PeriodicPairIndexerMemory(PairMemory): + ''' + Implementation of PeriodicPairIndexer with additional memory component. + + Stores current pair indices in memory and reuses them to compute the pair distances if no + particle has moved more than skin/2 since last pair calculation. Otherwise uses the + _pair_indexer_class to recompute the pairs. + + Increasing the value of 'skin' will increase the number of pair distances computed at + each step, but decrease the number of times new pairs must be computed. Skin should be + set to zero while training for fastest results. + ''' + _pair_indexer_class = PeriodicPairIndexer + + def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells): + if self.recalculation_needed(coordinates, cells): + self.n_molecules, self.n_atoms, _ = coordinates.shape + self.recalculations += 1 + + inputs = (coordinates, nonblank, real_atoms, inv_real_atoms, cells) + outputs = self._pair_indexer(*inputs) + distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num, pair_mol = outputs + + for name, var in [ + ("cell_offsets", cell_offsets), + ("pair_first", pair_first), + ("pair_second", pair_second), + ("offset_num", offset_num), + ("positions", coordinates), + ("cells", cells), + ("pair_mol", pair_mol) + ]: + self.__setattr__(name, var) + + else: + self.reuses += 1 + pair_shifts = torch.matmul(self.cell_offsets.unsqueeze(1).to(cells.dtype), cells[self.pair_mol]).squeeze(1) + coordflat = coordinates.reshape(self.n_molecules * self.n_atoms, 3)[real_atoms] + paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + pair_shifts + distflat2 = paircoord.norm(dim=1) + + return distflat2, self.pair_first, self.pair_second, paircoord, self.cell_offsets, self.offset_num \ No newline at end of file