Skip to content

Commit

Permalink
Add fuzzy histogram, pair-finder memory, and KD Tree pair-finder node…
Browse files Browse the repository at this point in the history
…s; fix bugs (#50)

* add pre-interaction layers option to hippynn

* revert pre-interaction layers, add fuzzy histogram feature

* fix minor typos, add number pairs to warn_if_under function

* add pair-finder with memory, fix minor bugs

* fix bug in pair-finder with memory

* * New KD Tree pair-finder node
* Modularized pair-finder memory component
* Typos corrected

* revert my change to .gitignore, revise docstring

* Updae change log and docs. Revert unneeded changes.

---------

Co-authored-by: Emily Suzanne Shinkle <[email protected]>
  • Loading branch information
shinkle-lanl and Emily Suzanne Shinkle authored Dec 11, 2023
1 parent ec91cf8 commit 9b24b21
Show file tree
Hide file tree
Showing 16 changed files with 417 additions and 22 deletions.
20 changes: 20 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,31 @@ Improvements
0.0.2a2
=======

New Features:
-------------

- New FuzzyHistogrammer node for transforming scalar feature into a fuzzy/soft
histogram array

- New PeriodicPairIndexerMemory node which removes the need to recompute
pairs for each model evaluation in some instances, leading to speed improvements

- New KDTreePairs and KDTreePairsMemory nodes for computing pairs using linearly-
scaling KD Tree algorithm.

Improvements
------------

- ASE database loader added to read any ASE file or list of ASE files.

Bug Fixes:
----------
- Function 'gemerate_database_info' renamed to 'generate_database_info.'

- Fixed issue with class Predictor arising when multiple names for the same output node are provided.

- Fixed issue with MolPairSummer when the batch size and the feature size are both one.

0.0.2a1
=======

Expand Down
20 changes: 18 additions & 2 deletions docs/source/examples/periodic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ to within the unit cell. Because the nearest images (27 replicates of the cell a
search radius 1) are numerous, periodic pair finding is noticeably more costly in terms of
memory and time than open boundary conditions. The less skewed your cells are, as well as
are the larger cells are compared to the cutoff distance required,
the fewer images needed to be searched in finding pairs.
the fewer images needed to be searched in finding pairs.


Dynamic Pair Finder
Expand All @@ -42,7 +42,23 @@ the systems one by one. The upshot of this is that less memory is required.
However, the cost is that each system is evaluated independently in serial,
and as such the pair finding can be a rather slow operation. This algorithm is
more likely to show benefits when the number of atoms in a training system is highly
variable.
variable.

For systems with orthorhombic cells and an interaction radius not greater than any of the
cell side lengths, the :class:`~hippynn.graphs.nodes.pairs.KDTreePairs` can be used
alternatively. It should exhibit reduced computation times, especially for large systems.

Pair Finder Memory
------------------
When using a trained model to run MD or for any application where atom positions
change only slightly between subsquent model calls,
:class:`~hippynn.graphs.nodes.pairs.PeriodicPairIndexerMemory` and
:class:`~hippynn.graphs.nodes.pairs.KDTreePairsMemory` can be used to reduce run
time by reusing pair information. Current pair indices are stored in memory and
reused so long as no atom has moved more than `skin`/2, where `skin` is an additional
parameter set by the user. Increasing the value of `skin` will increase the number of
pair distances computed at each step, but decrease the number of times new pairs must
be computed. Skin should be set to zero while training for fastest results.

Caching Pre-computed Pairs
--------------------------
Expand Down
2 changes: 1 addition & 1 deletion examples/allegro_ag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def fit_model(training_modules,database):
with hippynn.tools.log_terminal("model_results.txt",'wt'):
test_model(database, training_modules.evaluator, 128, "Final Training")

## Possible to export lammps MLIPInterface for model if lammmps with MLIP Installed!
## Possible to export lammps MLIPInterface for model if Lammps with MLIP Installed!
# print("Exporting lammps interface")
# first_frame = ase.io.read(dbname) # Reads in first frame only for saving box
# ase.io.write('ag_box.data', first_frame, format='lammps-data')
Expand Down
4 changes: 2 additions & 2 deletions hippynn/databases/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(
:param inputs: list of strings for input db_names
:param targets: list of strings for output db_namees
:param seed: int, for random splitting
:param test_size: fraction of data to use in test spli
:param valid_size: fraction oof data to use in train split
:param test_size: fraction of data to use in test split
:param valid_size: fraction of data to use in train split
:param num_workers: passed to pytorch dataloaders
:param pin_memory: passed to pytorch dataloaders
:param allow_unfound: If true, skip checking if the needed inputs and targets are found.
Expand Down
4 changes: 2 additions & 2 deletions hippynn/experiment/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""


def gemerate_database_info(inputs, targets, allow_unfound=False):
def generate_database_info(inputs, targets, allow_unfound=False):
"""
Construct db info from input nodes and target nodes.
:param inputs: list of input nodes
Expand Down Expand Up @@ -157,7 +157,7 @@ def assemble_for_training(train_loss, validation_losses, validation_names=None,
if plot_maker is not None:
plot_maker.assemble_module(outputs, targets)

db_info = gemerate_database_info(inputs, targets)
db_info = generate_database_info(inputs, targets)

evaluator = Evaluator(model, validation_lossfns, validation_names, plot_maker=plot_maker, db_info=db_info)

Expand Down
2 changes: 1 addition & 1 deletion hippynn/graphs/gops.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def search_by_name(nodes, name_or_dbname):
:return: node that matches criterion
Raises NodeAmbiguityError if more than one node found
Raises NotNotFoundError if no nodes found
Raises NodeNotFoundError if no nodes found
"""
try:
Expand Down
2 changes: 1 addition & 1 deletion hippynn/graphs/nodes/base/definition_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _assert_tupleform(input_tuple, type_tuple):
# If not, it must at least have the same length
if not len(input_tuple) == len(type_tuple):
raise TupleTypeMismatch(
"Wrong length.{}!={}".format(len(input_tuple), len(type_tuple))
"Wrong length. {}!={}".format(len(input_tuple), len(type_tuple))
+ " \nInput: {} \nExpected: {}".format(input_tuple, type_tuple)
)

Expand Down
21 changes: 21 additions & 0 deletions hippynn/graphs/nodes/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,24 @@ def acquire_encoding_padding(search_nodes, species_set, purpose=None):
pidxer = PaddingIndexer("PaddingIndexer", (encoder.encoding, encoder.nonblank))

return encoder, pidxer

class FuzzyHistogrammer(AutoKw, SingleNode):
"""
Node for transforming a scalar feature into a vectorized feature via
the fuzzy/soft histogram method.
:param length: length of vectorized feature
"""

_input_names = "values"
_auto_module_class = index_modules.FuzzyHistogram

def __init__(self, name, parents, length, vmin, vmax, module="auto", **kwargs):

if isinstance(parents, _BaseNode):
parents = (parents,)

self._output_index_state = parents[0]._index_state
self.module_kwargs = {"length": length, "vmin": vmin, "vmax": vmax}

super().__init__(name, parents, module=module, **kwargs)
61 changes: 56 additions & 5 deletions hippynn/graphs/nodes/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,28 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No
parents = self.expand_parents(parents)
super().__init__(name, parents, module=module, **kwargs)

class PeriodicPairIndexerMemory(PeriodicPairIndexer):
'''
Implementation of PeriodicPairIndexer with additional memory component.
Stores current pair indices in memory and reuses them to compute the pair distances if no
particle has moved more than skin/2 since last pair calculation. Otherwise uses the
_pair_indexer_class to recompute the pairs.
Increasing the value of 'skin' will increase the number of pair distances computed at
each step, but decrease the number of times new pairs must be computed. Skin should be
set to zero while training for fastest results.
'''

_auto_module_class = pairs_modules.periodic.PeriodicPairIndexerMemory

def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs):
if module_kwargs is None:
module_kwargs = {}
module_kwargs = {"skin": skin, **module_kwargs}

super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs)


class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode):
_input_names = "coordinates", "real_atoms", "shifts", "cell", "ext_pair_first", "ext_pair_second"
Expand Down Expand Up @@ -279,7 +301,7 @@ def __init__(self, name, parents, module="auto", bins=None, module_kwargs=None,
super().__init__(name, parents, module=module, **kwargs)


class _DispatchNeighbors(ExpandParents, PeriodicPairOutputs, PairIndexer, MultiNode):
class _DispatchNeighbors(ExpandParents, AutoKw, PeriodicPairOutputs, PairIndexer, MultiNode):
"""
Superclass for nodes that compute neighbors for systems one at a time.
These should be capable of searching all feasible neighbors (no limit on number of images)
Expand Down Expand Up @@ -326,13 +348,15 @@ def expand1(self, pos, encode, indexer, cell, **kwargs):
_parent_expander.get_main_outputs()
_parent_expander.require_idx_states(IdxType.MolAtom, None, None, None, None, None, None, None)

def __init__(self, name, parents, dist_hard_max, module="auto", **kwargs):
def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=None, **kwargs):
self.dist_hard_max = dist_hard_max
parents = self.expand_parents(parents)
super().__init__(name, parents, module=module, **kwargs)

def auto_module(self):
return self._auto_module_class(self.dist_hard_max)
if module_kwargs is None:
module_kwargs = {}
self.module_kwargs = {"dist_hard_max": dist_hard_max, **module_kwargs}

super().__init__(name, parents, module=module, **kwargs)


class NumpyDynamicPairs(_DispatchNeighbors):
Expand All @@ -348,6 +372,33 @@ class DynamicPeriodicPairs(_DispatchNeighbors):

_auto_module_class = pairs_modules.TorchNeighbors

class KDTreePairs(_DispatchNeighbors):
'''
Node for finding pairs under periodic boundary conditions using Scipy's KD Tree algorithm.
Cell must be orthorhombic.
'''
_auto_module_class = pairs_modules.dispatch.KDTreeNeighbors

class KDTreePairsMemory(_DispatchNeighbors):
'''
Implementation of KDTreePairs with an added memory component.
Stores current pair indices in memory and reuses them to compute the pair distances if no
particle has moved more than skin/2 since last pair calculation. Otherwise uses the
_pair_indexer_class to recompute the pairs.
Increasing the value of 'skin' will increase the number of pair distances computed at
each step, but decrease the number of times new pairs must be computed. Skin should be
set to zero while training for fastest results.
'''
_auto_module_class = pairs_modules.dispatch.KDTreePairsMemory

def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs):
if module_kwargs is None:
module_kwargs = {}
module_kwargs = {"skin": skin, **module_kwargs}

super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs)

class PaddedNeighborNode(ExpandParents, AutoNoKw, MultiNode):
_input_names = "pair_first", "pair_second", "pair_coord"
Expand Down
2 changes: 1 addition & 1 deletion hippynn/graphs/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, inputs, outputs, return_device=torch.device("cpu"), model_dev
"""

outputs = [search_by_name(inputs, o) if isinstance(o, str) else o for o in outputs]
outputs = list(set(outputs)) # Remove any redundancies -- they will screw up the output name map.

outputs = [o for o in outputs if o._index_state is not IdxType.Scalar]

Expand Down Expand Up @@ -77,7 +78,6 @@ def from_graph(cls, graph, additional_outputs=None, **kwargs):
outputs = graph.nodes_to_compute
if additional_outputs is not None:
outputs = outputs + list(additional_outputs)
outputs = list(set(outputs)) # Remove any redundancies -- they will screw up the output name map.

return cls(inputs, outputs, **kwargs)

Expand Down
4 changes: 3 additions & 1 deletion hippynn/layers/hiplayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ def warn_if_under(distance, threshold):
if dmin < threshold:
d_count = distance < threshold
d_frac = d_count.to(distance.dtype).mean()
d_sum = (d_count.sum()/2).to(torch.int)
warnings.warn(
"Provided distances are underneath sensitivity range!\n"
f"Minimum distance in current batch: {dmin}\n"
f"Threshold distance for warning: {threshold}.\n"
f"Fraction of pairs under the threshold: {d_frac}"
f"Fraction of pairs under the threshold: {d_frac}\n"
f"Number of pairs under the threshold: {d_sum}"
)


Expand Down
34 changes: 34 additions & 0 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,37 @@ def forward(self, bonds, pair_first, pair_second):
# in seqm, only bonds with index first < second is used
cond = pair_first < pair_second
return bonds[cond]

class FuzzyHistogram(torch.nn.Module):
"""
Transforms a scalar feature into a vectorized feature via
the fuzzy/soft histogram method.
:param length: length of vectorized feature
:returns FuzzyHistogram
"""

def __init__(self, length, vmin, vmax):
super().__init__()

err_msg = "The value of 'length' must be a positive integer."
if not isinstance(length, int):
raise ValueError(err_msg)
if length <= 0:
raise ValueError(err_msg)

if not (isinstance(vmin, (int,float)) and isinstance(vmax, (int,float))):
raise ValueError("The values of 'vmin' and 'vmax' must be floating point numbers.")
if vmin >= vmax:
raise ValueError("The value of 'vmin' must be less than the value of 'vmax.'")

self.bins = torch.nn.Parameter(torch.linspace(vmin, vmax, length), requires_grad=False)
self.sigma = (vmax - vmin) / length

def forward(self, values):
if values.shape[-1] != 1:
values = values[...,None]
x = values - self.bins
histo = torch.exp(-((x / self.sigma) ** 2) / 4)
return torch.flatten(histo, end_dim=1)
Loading

0 comments on commit 9b24b21

Please sign in to comment.