Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiGradient node and implement minor improvements and bug fixes for KDTree and Memory nodes #53

Merged
merged 18 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,28 @@ New Features:
- Add nodes for non-adiabatic coupling vectors (NACR) and phase-less loss.
See /examples/excited_states_azomethane.py.

Improvements
------------
- New MultiGradient node for computing multiple partial derivatives of
the same node simultaneously.

Improvements:
-------------

- Multi-target dipole node now has a shape of (n_molecules, n_targets, 3).

- Add out-of-range warning to FuzzyHistogrammer.

- Create Memory parent class to remove redundancy.

Bug Fixes:
----------

- Fix KDTreePairs issue caused by floating point precision limitations.

- Fix KDTreePairs issue with not moving tensors off GPU.

- Enable PairMemory nodes to handle batch size > 1.


0.0.2a2
=======

Expand All @@ -27,8 +44,8 @@ New Features:
- New KDTreePairs and KDTreePairsMemory nodes for computing pairs using linearly-
scaling KD Tree algorithm.

Improvements
------------
Improvements:
-------------

- ASE database loader added to read any ASE file or list of ASE files.

Expand Down
24 changes: 20 additions & 4 deletions hippynn/graphs/nodes/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,23 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No
parents = self.expand_parents(parents)
super().__init__(name, parents, module=module, **kwargs)

class PeriodicPairIndexerMemory(PeriodicPairIndexer):
class Memory:
@property
def skin(self):
return self.torch_module.skin

@skin.setter
def skin(self, skin):
self.torch_module.skin = skin

@property
def reuse_percentage(self):
return self.torch_module.reuse_percentage

def reset_reuse_percentage(self):
self.torch_module.reset_reuse_percentage()

class PeriodicPairIndexerMemory(PeriodicPairIndexer, Memory):
'''
Implementation of PeriodicPairIndexer with additional memory component.

Expand All @@ -86,9 +102,9 @@ class PeriodicPairIndexerMemory(PeriodicPairIndexer):
def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs):
if module_kwargs is None:
module_kwargs = {}
module_kwargs = {"skin": skin, **module_kwargs}
self.module_kwargs = {"skin": skin, **module_kwargs}

super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs)
super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=self.module_kwargs, **kwargs)


class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode):
Expand Down Expand Up @@ -379,7 +395,7 @@ class KDTreePairs(_DispatchNeighbors):
'''
_auto_module_class = pairs_modules.dispatch.KDTreeNeighbors

class KDTreePairsMemory(_DispatchNeighbors):
class KDTreePairsMemory(_DispatchNeighbors, Memory):
'''
Implementation of KDTreePairs with an added memory component.

Expand Down
25 changes: 25 additions & 0 deletions hippynn/graphs/nodes/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,31 @@ def __init__(self, name, parents, sign, module="auto", **kwargs):
self.sign = sign
self._index_state = position._index_state
super().__init__(name, parents, module=module, **kwargs)

class MultiGradientNode(AutoKw, MultiNode):
"""
Compute the gradient of a quantity.
"""

_auto_module_class = physics_layers.MultiGradient

def __init__(self, name: str, molecular_energies_parent: _BaseNode, generalized_coordinates_parents: tuple[_BaseNode], signs: tuple[int], module="auto", **kwargs):
if isinstance(signs, int):
signs = (signs,)

self.signs = signs
self.module_kwargs = {"signs": signs}

parents = molecular_energies_parent, *generalized_coordinates_parents

for parent in generalized_coordinates_parents:
parent.requires_grad = True

self._input_names = tuple((parent.name for parent in parents))
self._output_names = tuple((parent.name + "_grad" for parent in generalized_coordinates_parents))
self._output_index_states = tuple(parent._index_state for parent in generalized_coordinates_parents)

super().__init__(name, parents, module=module, **kwargs)


class StressForceNode(AutoNoKw, MultiNode):
Expand Down
17 changes: 17 additions & 0 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Layers for encoding, decoding, index states, besides pairs
"""

import warnings
import torch


Expand Down Expand Up @@ -249,7 +250,23 @@ def __init__(self, length, vmin, vmax):
self.bins = torch.nn.Parameter(torch.linspace(vmin, vmax, length), requires_grad=False)
self.sigma = (vmax - vmin) / length

self.vmin = vmin
self.vmax = vmax

def forward(self, values):
# Warn user if provided values lie outside the range of the histogram bins
values_out_of_range = (values < self.vmin) + (values > self.vmax)

if values_out_of_range.sum() > 0:
perc_out_of_range = values_out_of_range.float().mean()
warnings.warn(
"Values out of range for FuzzyHistogrammer\n"
f"Number of values out of range: {values_out_of_range.sum()}\n"
f"Percentage of values out of range: {perc_out_of_range * 100:.2f}%\n"
f"Set range for FuzzyHistogrammer: ({self.vmin:.2f}, {self.vmax:.2f})\n"
f"Range of values: ({values.min().item():.2f}, {values.max().item():.2f})"
)

if values.shape[-1] != 1:
values = values[...,None]
x = values - self.bins
Expand Down
19 changes: 16 additions & 3 deletions hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from scipy.spatial import KDTree
import torch
import os
from datetime import datetime

from .open import PairMemory

Expand Down Expand Up @@ -159,10 +161,21 @@ def neighbor_list_kdtree(cutoff, coords, cell):
new_cell = cell.clone()
new_coords = coords.clone()

# Find pair indices
new_coords = new_coords % torch.diag(new_cell)

# The following three lines are included to prevent an extremely rare but not unseen edge
# case where the modulo operation returns a particle coordinate that is exactly equal to
# the corresponding cell length, causing KDTree to throw an error
n_particles = new_coords.shape[0]
tiled_cell = torch.tile(torch.diag(new_cell), (n_particles,)).reshape(n_particles, -1)
new_coords = torch.where(new_coords == tiled_cell, 0, new_coords)

new_coords = new_coords.detach().cpu().numpy()
new_cell = torch.diag(new_cell).detach().cpu().numpy()

tree = KDTree(
data=new_coords.detach().cpu().numpy(),
boxsize=torch.diag(new_cell).detach().cpu().numpy()
data=new_coords,
boxsize=new_cell,
)

pairs = tree.query_pairs(r=cutoff, output_type='ndarray')
Expand Down
6 changes: 5 additions & 1 deletion hippynn/layers/pairs/open.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def initialize_buffers(self):
self.register_buffer(name=name, tensor=None, persistent=False)

def recalculation_needed(self, coordinates, cells):
if self.positions is None: # ie. forward function has not been called
if coordinates.shape[0] != 1: # does not support batch size larger than 1
return True
if self.positions is None: # ie. forward function has not been called
return True
if self.skin == 0:
return True
if (self.cells != cells).any() or (((self.positions - coordinates)**2).sum(1).max() > (self._skin/2)**2):
return True
Expand Down
17 changes: 16 additions & 1 deletion hippynn/layers/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,22 @@ def __init__(self, sign):

def forward(self, molecular_energies, positions):
return self.sign * torch.autograd.grad(molecular_energies.sum(), positions, create_graph=True)[0]


class MultiGradient(torch.nn.Module):
def __init__(self, signs):
super().__init__()
if isinstance(signs, int):
signs = (signs,)
for sign in signs:
assert sign in (-1,1), "Sign of gradient must be -1 or +1"
self.signs = signs

def forward(self, molecular_energies: Tensor, *generalized_coordinates: Tensor):
if isinstance(generalized_coordinates, Tensor):
generalized_coordinates = (generalized_coordinates,)
assert len(generalized_coordinates) == len(self.signs), f"Number of items to take derivative w.r.t ({len(generalized_coordinates)}) must match number of provided signs ({len(self.signs)})."
grads = torch.autograd.grad(molecular_energies.sum(), generalized_coordinates, create_graph=True)
return tuple((sign * grad for sign, grad in zip(self.signs, grads)))

class StressForce(torch.nn.Module):
def __init__(self, *args, **kwargs):
Expand Down
Loading