diff --git a/.gitignore b/.gitignore index 2818d631..299d9aea 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ __pycache__/ *.pyc build/ -hippynn.egg-info/* \ No newline at end of file +hippynn.egg-info/* diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0637f9ba..c5c874bb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,17 @@ +0.0.2a3 +======= + +New Features: +------------- + +- Add nodes for non-adiabatic coupling vectors (NACR) and phase-less loss. + See /examples/excited_states_azomethane.py. + +Improvements +------------ + +- Multi-target dipole node now has a shape of (n_molecules, n_targets, 3). + 0.0.2a2 ======= diff --git a/docs/source/examples/excited_states.rst b/docs/source/examples/excited_states.rst new file mode 100644 index 00000000..7a1e7a58 --- /dev/null +++ b/docs/source/examples/excited_states.rst @@ -0,0 +1,72 @@ +Non-Adiabatic Excited States +============================= + +`hippynn` has features for training to excited-state energies, transition dipoles, and +the non-adiabatic coupling vectors (NACR). These features can be found in +:mod:`~hippynn.graphs.nodes.excited`. + +For a more detailed description, please see the paper [Li2023]_. + +Multi-target nodes are recommended over using one node per target. + +For energies, the node can be constructed just like the ground-state +counterpart:: + + energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) + mol_energy = energy.mol_energy + mol_energy.db_name = "E" + +Note that a `multi-target node` is used here, defined by the keyword +``module_kwargs={"n_target": n_states + 1}``. Here, `n_states` is the number of +*excited* states in consideration. The extra state is for the ground state, which is often +useful. The database name is simply `E` with a shape of ``(n_molecules, +n_states+1)``. + +Predicting the transition dipoles is also similar to the ground-state permanent +dipole:: + + charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) + dipole = physics.DipoleNode("D", (charge, positions), db_name="D") + +The database name is `D` with a shape of ``(n_molecules, n_states, 3)``. + +For NACR, to avoid singularity problems, we enforce training on NACR*ΔE +instead:: + + nacr = excited.NACRMultiStateNode( + "ScaledNACR", + (charge, positions, energy), + db_name="ScaledNACR", + module_kwargs={"n_target": n_states}, + ) + +The NACR between states `i` and `j`, :math:`\boldsymbol{d}_{ij}`, is expressed +as + +.. math:: + \boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}} + +:math:`\Delta E_{ij}` is the energy difference between states `i` and `j`, which is +calculated internally in the NACR node based on the input of the ``energy`` +node. :math:`\boldsymbol{R}` corresponds to the ``positions`` node in the code. +:math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the transition +atomic charges for states `i` and `j` contained in the ``charge`` node. This +charge node can be constructed from scratch or reused from the dipole +predictions. The database name is `ScaledNACR` with a shape of ``(n_molecules, +n_states*(n_states-1)/2, 3*n_atoms)``.
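+ +Because ``ScaledNACR`` is the product :math:`\boldsymbol{d}_{ij}\Delta E_{ij}`, the +physical NACR can be recovered after inference by dividing out the corresponding energy +gaps. The snippet below is a minimal sketch, not part of the library: ``energies`` and +``scaled_nacr`` are hypothetical names for prediction arrays obtained from the two nodes +above, and the pair ordering mirrors the upper-triangular indexing the node uses +internally:: + + import numpy as np + + # energies: (n_molecules, n_states + 1); scaled_nacr: (n_molecules, n_pairs, 3*n_atoms) + i_idx, j_idx = np.triu_indices(n_states, k=1) + dE = energies[:, j_idx] - energies[:, i_idx] # gaps per pair, shape (n_molecules, n_pairs) + nacr = scaled_nacr / dE[:, :, None] # note: near-degenerate gaps amplify noise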
+ +Due to the phase problem, when the loss function is constructed, the +`phase-less` version of MAE or RMSE should be used:: + + energy_mae = loss.MAELoss.of_node(energy) + dipole_mae = excited.MAEPhaseLoss.of_node(dipole) + nacr_mae = excited.MAEPhaseLoss.of_node(nacr) + +:class:`~hippynn.graphs.nodes.excited.MAEPhaseLoss` and +:class:`~hippynn.graphs.nodes.excited.MSEPhaseLoss` are the `phase-less` versions of MAE +and MSE, which take the minimum error over the possible signs of the output. For +example, predicting either :math:`\boldsymbol{d}_{ij}` or :math:`-\boldsymbol{d}_{ij}` +for a NACR target incurs the same loss. + +For a complete script, please take a look at ``examples/excited_states_azomethane.py``. + +.. [Li2023] | Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials. + | Li et al., 2023. https://arxiv.org/abs/2306.02523 diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index f7b4ba50..548b884b 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -18,4 +18,5 @@ the examples are just snippets. For fully-fledged examples see the restarting ase_calculator mliap_unified + excited_states diff --git a/examples/excited_states_azomethane.py b/examples/excited_states_azomethane.py new file mode 100644 index 00000000..44b3daa7 --- /dev/null +++ b/examples/excited_states_azomethane.py @@ -0,0 +1,188 @@ +""" + +Example training script to predict excited-state energies, transition dipoles, and +non-adiabatic coupling vectors (NACR). + +The dataset used in this example can be found at https://doi.org/10.5281/zenodo.7076420. + +This script is set up to assume the "release" folder from the zenodo record + is placed in ../../datasets/azomethane/ relative to this script. + +For more information on the modeling techniques, please see the paper: +Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials +Li et al. (2023) +https://arxiv.org/abs/2306.02523 + +""" +import json + +import matplotlib +import numpy as np +import torch + +import hippynn +from hippynn import plotting +from hippynn.experiment import setup_training, train_model +from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau +from hippynn.graphs import inputs, loss, networks, physics, targets, excited + +matplotlib.use("Agg") +# default types for torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_default_dtype(torch.float32) + +hippynn.settings.WARN_LOW_DISTANCES = False +hippynn.settings.TRANSPARENT_PLOT = True + +n_atoms = 10 +n_states = 3 +plot_frequency = 100 +dipole_weight = 4 +nacr_weight = 2 +l2_weight = 2e-5 + +# Hyperparameters for the network +# Note: These hyperparameters were generated via +# a tuning algorithm, hence their somewhat arbitrary nature.
+network_params = { + "possible_species": [0, 1, 6, 7], + "n_features": 30, + "n_sensitivities": 28, + "dist_soft_min": 0.7665723566179274, + "dist_soft_max": 3.4134447177301515, + "dist_hard_max": 4.6860240434651805, + "n_interaction_layers": 3, + "n_atom_layers": 3, +} +# dump parameters to the log file +print("Network parameters\n\n", json.dumps(network_params, indent=4)) + +with hippynn.tools.active_directory("TEST_AZOMETHANE_MODEL"): + with hippynn.tools.log_terminal("training_log.txt", "wt"): + # build network + species = inputs.SpeciesNode(db_name="Z") + positions = inputs.PositionsNode(db_name="R") + network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params) + # add energy + energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) + mol_energy = energy.mol_energy + mol_energy.db_name = "E" + # add dipole + charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) + dipole = physics.DipoleNode("D", (charge, positions), db_name="D") + # add NACR + nacr = excited.NACRMultiStateNode( + "ScaledNACR", + (charge, positions, energy), + db_name="ScaledNACR", + module_kwargs={"n_target": n_states}, + ) + # set up plotter + plotter = [] + for node in [mol_energy, dipole, nacr]: + plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False)) + for i in range(network_params["n_interaction_layers"]): + plotter.append( + plotting.SensitivityPlot( + network.torch_module.sensitivity_layers[i], + saved=f"Sensitivity_{i}.pdf", + shown=False, + ) + ) + plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency) + # build the loss function + validation_losses = {} + # energy + energy_rmse = loss.MSELoss.of_node(energy) ** 0.5 + validation_losses["E-RMSE"] = energy_rmse + energy_mae = loss.MAELoss.of_node(energy) + validation_losses["E-MAE"] = energy_mae + energy_loss = energy_rmse + energy_mae + validation_losses["E-Loss"] = energy_loss + total_loss = energy_loss + # dipole + dipole_rmse = excited.MSEPhaseLoss.of_node(dipole) ** 0.5 + validation_losses["D-RMSE"] = dipole_rmse + dipole_mae = excited.MAEPhaseLoss.of_node(dipole) + validation_losses["D-MAE"] = dipole_mae + dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae + validation_losses["D-Loss"] = dipole_loss + total_loss += dipole_weight * dipole_loss + # nacr + nacr_rmse = excited.MSEPhaseLoss.of_node(nacr) ** 0.5 + validation_losses["NACR-RMSE"] = nacr_rmse + nacr_mae = excited.MAEPhaseLoss.of_node(nacr) + validation_losses["NACR-MAE"] = nacr_mae + nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae + validation_losses["NACR-Loss"] = nacr_loss + total_loss += nacr_weight * nacr_loss + # l2 regularization + l2_reg = loss.l2reg(network) + validation_losses["L2"] = l2_reg + loss_regularization = l2_weight * l2_reg + # add total loss to the dictionary + validation_losses["Loss_wo_L2"] = total_loss + validation_losses["Loss"] = total_loss + loss_regularization + + # set up experiment + training_modules, db_info = hippynn.experiment.assemble_for_training( + validation_losses["Loss"], + validation_losses, + plot_maker=plotter, + ) + # set up the optimizer + optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3) + # use higher patience for production runs + scheduler = RaiseBatchSizeOnPlateau(optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5) + controller = PatienceController( + optimizer=optimizer, + scheduler=scheduler, + batch_size=32, + eval_batch_size=2048, + # use higher max_epochs for production runs + 
max_epochs=100, + stopping_key="Loss", + fraction_train_eval=0.1, + # use higher termination_patience for production runs + termination_patience=10, + ) + experiment_params = hippynn.experiment.SetupParams(controller=controller) + + # load database + database = hippynn.databases.DirectoryDatabase( + name="azo_", # Prefix for arrays in the directory + directory="../../../datasets/azomethane/release/training/", + seed=114514, # Random seed for splitting data + **db_info, # Adds the inputs and targets db_names from the model as things to load + ) + # use 10% of the dataset just for quick testing purposes + database.make_random_split("train", 0.07) + database.make_random_split("valid", 0.02) + database.make_random_split("test", 0.01) + database.splitting_completed = True + # alternatively, split the whole dataset into train, valid, test in the ratio of 7:2:1 + # database.make_trainvalidtest_split(0.1, 0.2) + + # set up training + training_modules, controller, metric_tracker = setup_training( + training_modules=training_modules, + setup_params=experiment_params, + ) + # train model + metric_tracker = train_model( + training_modules, + database, + controller, + metric_tracker, + callbacks=None, + batch_callbacks=None, + ) + + del network_params["possible_species"] + network_params["metric"] = metric_tracker.best_metric_values + network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times) + network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"] + + with open("training_summary.json", "w") as out: + json.dump(network_params, out, indent=4) diff --git a/hippynn/graphs/__init__.py b/hippynn/graphs/__init__.py index eed24245..4fc3bd66 100644 --- a/hippynn/graphs/__init__.py +++ b/hippynn/graphs/__init__.py @@ -11,11 +11,13 @@ from . import indextypes from .indextypes import clear_index_cache, IdxType -from .nodes import base, inputs, networks, targets, physics, loss +from .nodes import base, inputs from .nodes.base import find_unique_relative, find_relatives, get_connected_nodes from .gops import get_subgraph, copy_subgraph, replace_node, compute_evaluation_order +from .nodes import networks, targets, physics, loss, excited + # Needed to populate the registry of index transformers. # This has to happen before the indextypes package can work, # however, we don't want the indextypes package to depend on actual diff --git a/hippynn/graphs/nodes/excited.py b/hippynn/graphs/nodes/excited.py new file mode 100644 index 00000000..142ba6fc --- /dev/null +++ b/hippynn/graphs/nodes/excited.py @@ -0,0 +1,169 @@ +from typing import Tuple +import torch + +from ...layers import excited as excited_layers +from .. import IdxType, find_unique_relative +from .base import AutoKw, SingleNode, ExpandParents, MultiNode +from .loss import _BaseCompareLoss +from .tags import Energies, HAtomRegressor, Network, AtomIndexer +from ...layers import physics as physics_layers + + +class NACRNode(AutoKw, SingleNode): + """ + Compute the non-adiabatic coupling vector multiplied by the energy difference + between two states. + """ + + _input_names = "charges i", "charges j", "coordinates", "energy i", "energy j" + _auto_module_class = excited_layers.NACR + + def __init__(self, name: str, parents: Tuple, module="auto", module_kwargs=None, **kwargs): + """Automatically build the node for calculating NACR * ΔE between two states i + and j.
+ + :param name: name of the node + :type name: str + :param parents: parents of the NACR node in the sequence of (charges i, \ + charges j, positions, energy i, energy j) + :type parents: Tuple + :param module: the node's module, or "auto" to construct it automatically; + defaults to "auto" + :type module: str, optional + :param module_kwargs: keyword arguments passed to the corresponding layer, + defaults to None + :type module_kwargs: dict, optional + """ + + self.module_kwargs = {} + if module_kwargs is not None: + self.module_kwargs.update(module_kwargs) + charges1, charges2, positions, energy1, energy2 = parents + positions.requires_grad = True + self._index_state = IdxType.Molecules + # self._index_state = positions._index_state + parents = ( + charges1.main_output, + charges2.main_output, + positions, + energy1.main_output, + energy2.main_output, + ) + super().__init__(name, parents, module=module, **kwargs) + + +class NACRMultiStateNode(AutoKw, SingleNode): + """ + Compute the non-adiabatic coupling vector multiplied by the energy difference + between all pairs of states. + """ + + _input_names = "charges", "coordinates", "energies" + _auto_module_class = excited_layers.NACRMultiState + + def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): + """Automatically build the node for calculating NACR * ΔE between all pairs of + states. + + :param name: name of the node + :type name: str + :param parents: parents of the NACR node in the sequence of (charges, \ + positions, energies) + :type parents: Tuple + :param module: the node's module, or "auto" to construct it automatically; + defaults to "auto" + :type module: str, optional + :param module_kwargs: keyword arguments passed to the corresponding layer, + defaults to None + :type module_kwargs: dict, optional + """ + + self.module_kwargs = {} + if module_kwargs is not None: + self.module_kwargs.update(module_kwargs) + charges, positions, energies = parents + positions.requires_grad = True + self._index_state = IdxType.Molecules + # self._index_state = positions._index_state + parents = ( + charges.main_output, + positions, + energies.main_output, + ) + super().__init__(name, parents, module=module, **kwargs) + + +class LocalEnergyNode(Energies, ExpandParents, HAtomRegressor, MultiNode): + """ + Predict a localized energy, with contributions from implicitly computed atoms.
+ """ + + _input_names = "hier_features", "mol_index", "atom index", "n_molecules", "n_atoms_max" + _output_names = "mol_energy", "atom_energy", "atom_preenergy", "atom_probabilities", "atom_propensities" + _main_output = "mol_energy" + _output_index_states = IdxType.Molecules, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms + _auto_module_class = excited_layers.LocalEnergy + + @_parent_expander.match(Network) + def expansion0(self, net, *, purpose, **kwargs): + pdindexer = find_unique_relative(net, AtomIndexer, why_desc=purpose) + return net, pdindexer + + @_parent_expander.match(Network, AtomIndexer) + def expansion1(self, net, pdindexer, **kwargs): + return net, pdindexer.mol_index, pdindexer.atom_index, pdindexer.n_molecules, pdindexer.n_atoms_max + + _parent_expander.assertlen(5) + + def __init__(self, name, parents, first_is_interacting=False, module="auto", **kwargs): + parents = self.expand_parents(parents) + self.module_kwargs = {"first_is_interacting": first_is_interacting} + super().__init__(name, parents, module=module, **kwargs) + + def auto_module(self): + network = find_unique_relative(self, Network).torch_module + return self._auto_module_class(network.feature_sizes, **self.module_kwargs) + + +def _mae_with_phases(predict: torch.Tensor, true: torch.Tensor): + """MAE with phases + + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: MAE with phases + :rtype: torch.Tensor + """ + + errors = torch.minimum( + torch.linalg.norm(true - predict, ord=1, dim=-1), + torch.linalg.norm(true + predict, ord=1, dim=-1), + ) + # errors = absolute_errors(predict, true) + return torch.sum(errors) / predict.numel() + + +def _mse_with_phases(predict: torch.Tensor, true: torch.Tensor): + """MSE with phases + + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: MSE with phases + :rtype: torch.Tensor + """ + + errors = torch.minimum( + torch.linalg.norm(true - predict, dim=-1), + torch.linalg.norm(true + predict, dim=-1), + ) + # errors = absolute_errors(predict, true) ** 2 + return torch.sum(errors**2) / predict.numel() + + +class MAEPhaseLoss(_BaseCompareLoss, op=_mae_with_phases): + pass + + +class MSEPhaseLoss(_BaseCompareLoss, op=_mse_with_phases): + pass diff --git a/hippynn/graphs/nodes/loss.py b/hippynn/graphs/nodes/loss.py index 2320df1d..e11ade2b 100644 --- a/hippynn/graphs/nodes/loss.py +++ b/hippynn/graphs/nodes/loss.py @@ -134,3 +134,21 @@ def l2reg(network): def l1reg(network): return lpreg(network, p=1) + +# For loss functions with phases +def absolute_errors(predict: torch.Tensor, true: torch.Tensor): + """Compute the absolute errors with phases between predicted and true values. In + other words, prediction should be close to the absolute value of true, and the sign + does not matter. 
+ + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: errors + :rtype: torch.Tensor + """ + + return torch.minimum(torch.abs(true - predict), torch.abs(true + predict)) + + diff --git a/hippynn/graphs/nodes/physics.py b/hippynn/graphs/nodes/physics.py index 28cfae6e..55ba9bba 100644 --- a/hippynn/graphs/nodes/physics.py +++ b/hippynn/graphs/nodes/physics.py @@ -3,17 +3,24 @@ """ import warnings -from .base import SingleNode, MultiNode, AutoNoKw, AutoKw, ExpandParents, find_unique_relative, _BaseNode +from ...layers import indexers as index_layers +from ...layers import pairs as pair_layers +from ...layers import physics as physics_layers +from ..indextypes import IdxType, elementwise_compare_reduce, index_type_coercion +from .base import ( + AutoKw, + AutoNoKw, + ExpandParents, + MultiNode, + SingleNode, + _BaseNode, + find_unique_relative, +) from .base.node_functions import NodeNotFound from .indexers import AtomIndexer, PaddingIndexer, acquire_encoding_padding -from .pairs import OpenPairIndexer -from .tags import Encoder, PairIndexer, Charges, Energies from .inputs import PositionsNode, SpeciesNode - -from ..indextypes import IdxType, index_type_coercion, elementwise_compare_reduce -from ...layers import indexers as index_layers -from ...layers import physics as physics_layers -from ...layers import pairs as pair_layers +from .pairs import OpenPairIndexer +from .tags import Charges, Encoder, Energies, PairIndexer class GradientNode(AutoKw, SingleNode): @@ -264,7 +271,6 @@ def __init__(self, name, parents, module="auto", **kwargs): # TODO: This seems broken for parent expanders, check the signature of the layer. class BondToMolSummmer(ExpandParents, AutoNoKw, SingleNode): - _input_names = "pairfeatures", "mol_index", "n_molecules", "pair_first" _auto_module_class = pair_layers.MolPairSummer _index_state = IdxType.Molecules @@ -310,17 +316,20 @@ def __init__(self, name, parents, module="auto", **kwargs): super().__init__(name, parents, module=module, **kwargs) - class CombineEnergyNode(Energies, AutoKw, ExpandParents, MultiNode): """ - Combines Local atom energies from different Energy Nodes. + Combines Local atom energies from different Energy Nodes. """ + _input_names = "input_atom_energy_1", "input_atom_energy_2", "mol_index", "n_molecules" _output_names = "mol_energy", "atom_energies" _main_output = "mol_energy" - _output_index_states = IdxType.Molecules, IdxType.Atoms, + _output_index_states = ( + IdxType.Molecules, + IdxType.Atoms, + ) _auto_module_class = physics_layers.CombineEnergy - + @_parent_expander.match(_BaseNode, Energies) def expansion0(self, energy_1, energy_2, **kwargs): return energy_1, energy_2.atom_energies @@ -345,4 +354,3 @@ def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): self.module_kwargs = {} if module_kwargs is None else module_kwargs parents = self.expand_parents(parents, **kwargs) super().__init__(name, parents=parents, module=module, **kwargs) - diff --git a/hippynn/graphs/nodes/targets.py b/hippynn/graphs/nodes/targets.py index b666e97d..adb86a50 100644 --- a/hippynn/graphs/nodes/targets.py +++ b/hippynn/graphs/nodes/targets.py @@ -1,8 +1,8 @@ """ Nodes for prediction of variables from network features. 
""" + from .base import MultiNode, AutoKw, ExpandParents, find_unique_relative, _BaseNode -from .indexers import PaddingIndexer from .tags import AtomIndexer, Network, PairIndexer, HAtomRegressor, Charges, Energies from .indexers import PaddingIndexer from ..indextypes import IdxType, index_type_coercion @@ -102,33 +102,3 @@ def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): super().__init__(name, parents, module=module, **kwargs) -class LocalEnergyNode(Energies, ExpandParents, HAtomRegressor, MultiNode): - """ - Predict a localized energy, with contributions from implicitly computed atoms. - """ - - _input_names = "hier_features", "mol_index", "atom index", "n_molecules", "n_atoms_max" - _output_names = "mol_energy", "atom_energy", "atom_preenergy", "atom_probabilities", "atom_propensities" - _main_output = "mol_energy" - _output_index_states = IdxType.Molecules, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms - _auto_module_class = target_modules.LocalEnergy - - @_parent_expander.match(Network) - def expansion0(self, net, *, purpose, **kwargs): - pdindexer = find_unique_relative(net, AtomIndexer, why_desc=purpose) - return net, pdindexer - - @_parent_expander.match(Network, AtomIndexer) - def expansion1(self, net, pdindexer, **kwargs): - return net, pdindexer.mol_index, pdindexer.atom_index, pdindexer.n_molecules, pdindexer.n_atoms_max - - _parent_expander.assertlen(5) - - def __init__(self, name, parents, first_is_interacting=False, module="auto", **kwargs): - parents = self.expand_parents(parents) - self.module_kwargs = {"first_is_interacting": first_is_interacting} - super().__init__(name, parents, module=module, **kwargs) - - def auto_module(self): - network = find_unique_relative(self, Network).torch_module - return self._auto_module_class(network.feature_sizes, **self.module_kwargs) diff --git a/hippynn/layers/__init__.py b/hippynn/layers/__init__.py index 58d549e0..c2180c26 100644 --- a/hippynn/layers/__init__.py +++ b/hippynn/layers/__init__.py @@ -6,3 +6,4 @@ from . import targets from . import transform from . import physics +from . import excited \ No newline at end of file diff --git a/hippynn/layers/excited.py b/hippynn/layers/excited.py new file mode 100644 index 00000000..8096e743 --- /dev/null +++ b/hippynn/layers/excited.py @@ -0,0 +1,128 @@ +import torch +from . import indexers +from torch import Tensor + + +class NACR(torch.nn.Module): + """ + Compute NAC vector * ΔE. Originally in hippynn.layers.physics. + """ + + def __init__(self): + super().__init__() + + def forward( + self, + charges1: Tensor, + charges2: Tensor, + positions: Tensor, + energy1: Tensor, + energy2: Tensor, + ): + dE = energy2 - energy1 + nacr = torch.autograd.grad( + charges2, [positions], grad_outputs=[charges1], create_graph=True + )[0].reshape(len(dE), -1) + return nacr * dE + + +class NACRMultiState(torch.nn.Module): + """ + Compute NAC vector * ΔE for all paris of states. Originally in hippynn.layers.physics. 
+ """ + + def __init__(self, n_target=1): + self.n_target = n_target + super().__init__() + + def forward(self, charges: Tensor, positions: Tensor, energies: Tensor): + # charges shape: n_molecules, n_atoms, n_targets + # positions shape: n_molecules, n_atoms, 3 + # energies shape: n_molecules, n_targets + # dE shape: n_molecules, n_targets, n_targets + dE = energies.unsqueeze(1) - energies.unsqueeze(2) + # take the upper triangle excluding the diagonal + indices = torch.triu_indices( + self.n_target, self.n_target, offset=1, device=dE.device + ) + # dE shape: n_molecules, n_pairs + # n_pairs = n_targets * (n_targets - 1) / 2 + dE = dE[..., indices[0], indices[1]] + # compute q1 * dq2/dR + nacr_ij = [] + for i, j in zip(*indices): + nacr = torch.autograd.grad( + charges[..., j], + positions, + grad_outputs=charges[..., i], + create_graph=True, + )[0] + nacr_ij.append(nacr) + # nacr shape: n_molecules, n_atoms, 3, n_pairs + nacr = torch.stack(nacr_ij, dim=1) + n_molecule, n_pairs, n_atoms, n_dims = nacr.shape + nacr = nacr.reshape(n_molecule, n_pairs, n_atoms * n_dims) + # multiply dE + return nacr * dE.unsqueeze(2) + + +class LocalEnergy(torch.nn.Module): + def __init__(self, feature_sizes, first_is_interacting=False): + + super().__init__() + self.first_is_interacting = first_is_interacting + if first_is_interacting: + feature_sizes = feature_sizes[1:] + + self.feature_sizes = feature_sizes + + self.summer = indexers.MolSummer() + self.n_terms = len(feature_sizes) + biases = (first_is_interacting, *(True for _ in range(self.n_terms - 1))) + + self.layers = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=bias) for nf, bias in zip(feature_sizes, biases)) + self.players = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=False) for nf in feature_sizes) + self.ninf = float("-inf") + + def forward(self, all_features, mol_index, atom_index, n_molecules, n_atoms_max): + """ + :param all_features: list of feature tensors + :param mol_index: which molecule is the atom + :param atom_index: which atom in the molecule is that atom + :param n_molecules: total number of molecules in the batch + :param n_atoms_max: maximum number of atoms in the batch + :return: contributed_energy, atom_energy, atom_preenergy, prob, propensity + """ + + if self.first_is_interacting: + all_features = all_features[1:] + + partial_preenergy = [lay(x) for x, lay in zip(all_features, self.layers)] + atom_preenergy = sum(partial_preenergy) + partial_potentials = [lay(x) for x, lay in zip(all_features, self.players)] + propensity = sum(partial_potentials) # Keep in mind that this has shape (natoms,1) + + # This segment does not need gradients, we are constructing the subtraction parameters for softmax + # which results in a calculation that does not under or overflow; the result is most accurate this way + # But actually invariant to the subtraction used, so it does not require a grad. + # It's a standard SoftMax technique, however, the implementation is not built into pytorch for + # the molecule/atom framework. 
+ with torch.autograd.no_grad(): + propensity_molatom = all_features[0].new_full((n_molecules, n_atoms_max, 1), self.ninf) + propensity_molatom[mol_index, atom_index] = propensity + propensity_norms = propensity_molatom.max(dim=1)[0] # first element is max vals, 2nd is max position + propensity_norm_atoms = propensity_norms[mol_index] + + propensity_normed = propensity - propensity_norm_atoms + + # Calculate probabilities with molecule version of softmax + relative_prob = torch.exp(propensity_normed) + z_factor_permol = self.summer(relative_prob, mol_index, n_molecules) + atom_zfactor = z_factor_permol[mol_index] + prob = relative_prob / atom_zfactor + + # Find molecular sum + atom_energy = prob * atom_preenergy + contributed_energy = self.summer(atom_energy, mol_index, n_molecules) + + return contributed_energy, atom_energy, atom_preenergy, prob, propensity diff --git a/hippynn/layers/physics.py b/hippynn/layers/physics.py index fc5ef09e..3debf31d 100644 --- a/hippynn/layers/physics.py +++ b/hippynn/layers/physics.py @@ -6,8 +6,7 @@ import torch from torch import Tensor -from . import pairs -from . import indexers +from . import indexers, pairs class Gradient(torch.nn.Module): @@ -43,8 +42,8 @@ def __init__(self): def forward(self, charges: Tensor, positions: Tensor, mol_index: Tensor, n_molecules: int): if charges.shape[1] > 1: # charges contain multiple targets, so set up broadcasting - charges = charges.unsqueeze(1) - positions = positions.unsqueeze(2) + charges = charges.unsqueeze(2) + positions = positions.unsqueeze(1) # shape is (n_atoms, 3, n_targets) in multi-target mode # shape is (n_atoms, 3) in single target mode @@ -248,9 +247,7 @@ def forward(self, features, species): class VecMag(torch.nn.Module): def forward(self, vector_feature): - return torch.norm(vector_feature, dim=1) - - + return torch.norm(vector_feature, dim=1).unsqueeze(1) class CombineEnergy(torch.nn.Module): diff --git a/hippynn/layers/targets.py b/hippynn/layers/targets.py index 98077c74..f8d262da 100644 --- a/hippynn/layers/targets.py +++ b/hippynn/layers/targets.py @@ -235,63 +235,3 @@ def forward(self, all_features, pair_first, pair_second, pair_dist): return total_bonds, bond_hier -class LocalEnergy(torch.nn.Module): - def __init__(self, feature_sizes, first_is_interacting=False): - - super().__init__() - self.first_is_interacting = first_is_interacting - if first_is_interacting: - feature_sizes = feature_sizes[1:] - - self.feature_sizes = feature_sizes - - self.summer = indexers.MolSummer() - self.n_terms = len(feature_sizes) - biases = (first_is_interacting, *(True for _ in range(self.n_terms - 1))) - - self.layers = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=bias) for nf, bias in zip(feature_sizes, biases)) - self.players = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=False) for nf in feature_sizes) - self.ninf = float("-inf") - - def forward(self, all_features, mol_index, atom_index, n_molecules, n_atoms_max): - """ - :param all_features: list of feature tensors - :param mol_index: which molecule is the atom - :param atom_index: which atom in the molecule is that atom - :param n_molecules: total number of molecules in the batch - :param n_atoms_max: maximum number of atoms in the batch - :return: contributed_energy, atom_energy, atom_preenergy, prob, propensity - """ - - if self.first_is_interacting: - all_features = all_features[1:] - - partial_preenergy = [lay(x) for x, lay in zip(all_features, self.layers)] - atom_preenergy = sum(partial_preenergy) - partial_potentials = [lay(x) for x, lay in 
zip(all_features, self.players)] - propensity = sum(partial_potentials) # Keep in mind that this has shape (natoms,1) - - # This segment does not need gradients, we are constructing the subtraction parameters for softmax - # which results in a calculation that does not under or overflow; the result is most accurate this way - # But actually invariant to the subtraction used, so it does not require a grad. - # It's a standard SoftMax technique, however, the implementation is not built into pytorch for - # the molecule/atom framework. - with torch.autograd.no_grad(): - propensity_molatom = all_features[0].new_full((n_molecules, n_atoms_max, 1), self.ninf) - propensity_molatom[mol_index, atom_index] = propensity - propensity_norms = propensity_molatom.max(dim=1)[0] # first element is max vals, 2nd is max position - propensity_norm_atoms = propensity_norms[mol_index] - - propensity_normed = propensity - propensity_norm_atoms - - # Calculate probabilities with molecule version of softmax - relative_prob = torch.exp(propensity_normed) - z_factor_permol = self.summer(relative_prob, mol_index, n_molecules) - atom_zfactor = z_factor_permol[mol_index] - prob = relative_prob / atom_zfactor - - # Find molecular sum - atom_energy = prob * atom_preenergy - contributed_energy = self.summer(atom_energy, mol_index, n_molecules) - - return contributed_energy, atom_energy, atom_preenergy, prob, propensity