-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Excited state nodes, docs, and an example (#42)
* Add nodes for NACR and phase-less loss * NACR node now does NACR_ij * ΔE_ij * dE should be E_j - E_i * remove parent expansion * Update NACR implementation 1. Fix a bug where the charges tensor is incorrectly sliced. 2. Add corresponding tests to NACR layers. * Fix a bug on calculating the phase-less loss When the predicted vector is not in the same quadrant as the true value, the calculated loss would be smaller than the correct one. This might cause training process slower or stuck. * Fix a bug in new MSE implementation * Fix a bug in setting up custom kernels Now "auto" will use pytorch when numba or cupy is not installed * Add nodes and example for excited states * Update changelog and rename example file * update example for excited states * move excited states into subdirectories * fix import order and make excited states import by default * update documentation * documentation update --------- Co-authored-by: Nicholas Lubbers <[email protected]>
- Loading branch information
Showing
14 changed files
with
622 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,4 @@ | |
__pycache__/ | ||
*.pyc | ||
build/ | ||
hippynn.egg-info/* | ||
hippynn.egg-info/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
Non-Adiabiatic Excited States | ||
============================= | ||
|
||
`hippynn` has features for training to excited-state energies, transition dipoles, and | ||
the non-adiabatic coupling vectors (NACR). These features can be found in | ||
:mod:`~hippynn.graphs.nodes.excited`. | ||
|
||
For a more detailed description, please see the paper [Li2023]_ | ||
|
||
Multi-targets nodes are recommended over the usage of one node per target. | ||
|
||
For energies, the node can be constructed just like the ground-state | ||
counterpart:: | ||
|
||
energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) | ||
mol_energy = energy.mol_energy | ||
mol_energy.db_name = "E" | ||
|
||
Note that a `multi-target node` is used here, defined by the keyword | ||
``module_kwargs={"n_target": n_states + 1}``. Here, `n_states` is the number of | ||
*excited* states in consideration. The extra state is for the ground state, which is often | ||
useful. The database name is simply `E` with a shape of ``(n_molecules, | ||
n_states+1)``. | ||
|
||
Predicting the transition dipoles is also similar to the ground-state permanent | ||
dipole:: | ||
|
||
charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) | ||
dipole = physics.DipoleNode("D", (charge, positions), db_name="D") | ||
|
||
The database name is `D` with a shape of ``(n_molecules, n_states, 3)``. | ||
|
||
For NACR, to avoid singularity problems, we enforcing the training of NACR*ΔE | ||
instead:: | ||
|
||
nacr = excited.NACRMultiStateNode( | ||
"ScaledNACR", | ||
(charge, positions, energy), | ||
db_name="ScaledNACR", | ||
module_kwargs={"n_target": n_states}, | ||
) | ||
|
||
For NACR between state `i` and `j`, :math:`\boldsymbol{d}_{ij}`, it is expressed | ||
in the following way | ||
|
||
.. math:: | ||
\boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}} | ||
:math:`E_{ij}` is energy difference between state `i` and `j`, which is | ||
calculated internally in the NACR node based on the input of the ``energy`` | ||
node. :math:`\boldsymbol{R}` corresponding the ``positions`` node in the code. | ||
:math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the transition | ||
atomic charges for state `i` and `j` contained in the ``charge`` node. This | ||
charge node can be constructed from scratch or reused from the dipole | ||
predictions. The database name is `ScaledNACR` with a shape of ``(n_molecules, | ||
n_states*(n_states-1)/2, 3*n_atoms)``. | ||
|
||
Due to the phase problem, when the loss function is constructed, the | ||
`phase-less` version of MAE or RMSE should be used:: | ||
|
||
energy_mae = loss.MAELoss.of_node(energy) | ||
dipole_mae = excited.MAEPhaseLoss.of_node(dipole) | ||
nacr_mae = excited.MAEPhaseLoss.of_node(nacr) | ||
|
||
:class:`~hippynn.graphs.nodes.excited.MAEPhaseLoss` and | ||
:class:`~hippynn.graphs.nodes.excited.MSEPhaseLoss` are the `phase-less` version MAE | ||
and MSE, which take the minimum error over the possible signs of the output. | ||
|
||
For a complete script, please take a look at ``examples/excited_states_azomethane.py``. | ||
|
||
.. [Li2023] | Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials. | ||
| Li et. al, 2023. https://arxiv.org/abs/2306.02523 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
""" | ||
Example training script to predicted excited-states energies, transition dipoles, and | ||
non-adiabatic coupling vectors (NACR) | ||
The dataset used in this example can be found at https://doi.org/10.5281/zenodo.7076420. | ||
This script is set up to assume the "release" folder from the zenodo record | ||
is placed in ../../datasets/azomethane/ relative to this script. | ||
For more information on the modeling techniques, please see the paper: | ||
Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials | ||
Li, et al. (2023) | ||
https://arxiv.org/abs/2306.02523 | ||
""" | ||
import json | ||
|
||
import matplotlib | ||
import numpy as np | ||
import torch | ||
|
||
import hippynn | ||
from hippynn import plotting | ||
from hippynn.experiment import setup_training, train_model | ||
from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau | ||
from hippynn.graphs import inputs, loss, networks, physics, targets, excited | ||
|
||
matplotlib.use("Agg") | ||
# default types for torch | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
torch.backends.cudnn.allow_tf32 = True | ||
torch.set_default_dtype(torch.float32) | ||
|
||
hippynn.settings.WARN_LOW_DISTANCES = False | ||
hippynn.settings.TRANSPARENT_PLOT = True | ||
|
||
n_atoms = 10 | ||
n_states = 3 | ||
plot_frequency = 100 | ||
dipole_weight = 4 | ||
nacr_weight = 2 | ||
l2_weight = 2e-5 | ||
|
||
# Hyperparameters for the network | ||
# Note: These hyperparameters were generated via | ||
# a tuning algorithm, hence their somewhat arbitrary nature. | ||
network_params = { | ||
"possible_species": [0, 1, 6, 7], | ||
"n_features": 30, | ||
"n_sensitivities": 28, | ||
"dist_soft_min": 0.7665723566179274, | ||
"dist_soft_max": 3.4134447177301515, | ||
"dist_hard_max": 4.6860240434651805, | ||
"n_interaction_layers": 3, | ||
"n_atom_layers": 3, | ||
} | ||
# dump parameters to the log file | ||
print("Network parameters\n\n", json.dumps(network_params, indent=4)) | ||
|
||
with hippynn.tools.active_directory("TEST_AZOMETHANE_MODEL"): | ||
with hippynn.tools.log_terminal("training_log.txt", "wt"): | ||
# build network | ||
species = inputs.SpeciesNode(db_name="Z") | ||
positions = inputs.PositionsNode(db_name="R") | ||
network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params) | ||
# add energy | ||
energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) | ||
mol_energy = energy.mol_energy | ||
mol_energy.db_name = "E" | ||
# add dipole | ||
charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) | ||
dipole = physics.DipoleNode("D", (charge, positions), db_name="D") | ||
# add NACR | ||
nacr = excited.NACRMultiStateNode( | ||
"ScaledNACR", | ||
(charge, positions, energy), | ||
db_name="ScaledNACR", | ||
module_kwargs={"n_target": n_states}, | ||
) | ||
# set up plotter | ||
plotter = [] | ||
for node in [mol_energy, dipole, nacr]: | ||
plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False)) | ||
for i in range(network_params["n_interaction_layers"]): | ||
plotter.append( | ||
plotting.SensitivityPlot( | ||
network.torch_module.sensitivity_layers[i], | ||
saved=f"Sensitivity_{i}.pdf", | ||
shown=False, | ||
) | ||
) | ||
plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency) | ||
# build the loss function | ||
validation_losses = {} | ||
# energy | ||
energy_rmse = loss.MSELoss.of_node(energy) ** 0.5 | ||
validation_losses["E-RMSE"] = energy_rmse | ||
energy_mae = loss.MAELoss.of_node(energy) | ||
validation_losses["E-MAE"] = energy_mae | ||
energy_loss = energy_rmse + energy_mae | ||
validation_losses["E-Loss"] = energy_loss | ||
total_loss = energy_loss | ||
# dipole | ||
dipole_rmse = excited.MSEPhaseLoss.of_node(dipole) ** 0.5 | ||
validation_losses["D-RMSE"] = dipole_rmse | ||
dipole_mae = excited.MAEPhaseLoss.of_node(dipole) | ||
validation_losses["D-MAE"] = dipole_mae | ||
dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae | ||
validation_losses["D-Loss"] = dipole_loss | ||
total_loss += dipole_weight * dipole_loss | ||
# nacr | ||
nacr_rmse = excited.MSEPhaseLoss.of_node(nacr) ** 0.5 | ||
validation_losses["NACR-RMSE"] = nacr_rmse | ||
nacr_mae = excited.MAEPhaseLoss.of_node(nacr) | ||
validation_losses["NACR-MAE"] = nacr_mae | ||
nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae | ||
validation_losses["NACR-Loss"] = nacr_loss | ||
total_loss += nacr_weight * nacr_loss | ||
# l2 regularization | ||
l2_reg = loss.l2reg(network) | ||
validation_losses["L2"] = l2_reg | ||
loss_regularization = l2_weight * l2_reg | ||
# add total loss to the dictionary | ||
validation_losses["Loss_wo_L2"] = total_loss | ||
validation_losses["Loss"] = total_loss + loss_regularization | ||
|
||
# set up experiment | ||
training_modules, db_info = hippynn.experiment.assemble_for_training( | ||
validation_losses["Loss"], | ||
validation_losses, | ||
plot_maker=plotter, | ||
) | ||
# set up the optimizer | ||
optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3) | ||
# use higher patience for production runs | ||
scheduler = RaiseBatchSizeOnPlateau(optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5) | ||
controller = PatienceController( | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
batch_size=32, | ||
eval_batch_size=2048, | ||
# use higher max_epochs for production runs | ||
max_epochs=100, | ||
stopping_key="Loss", | ||
fraction_train_eval=0.1, | ||
# use higher termination_patience for production runs | ||
termination_patience=10, | ||
) | ||
experiment_params = hippynn.experiment.SetupParams(controller=controller) | ||
|
||
# load database | ||
database = hippynn.databases.DirectoryDatabase( | ||
name="azo_", # Prefix for arrays in the directory | ||
directory="../../../datasets/azomethane/release/training/", | ||
seed=114514, # Random seed for splitting data | ||
**db_info, # Adds the inputs and targets db_names from the model as things to load | ||
) | ||
# use 10% of the dataset just for quick testing purpose | ||
database.make_random_split("train", 0.07) | ||
database.make_random_split("valid", 0.02) | ||
database.make_random_split("test", 0.01) | ||
database.splitting_completed = True | ||
# split the whole dataset into train, valid, test in the ratio of 7:2:1 | ||
# database.make_trainvalidtest_split(0.1, 0.2) | ||
|
||
# set up training | ||
training_modules, controller, metric_tracker = setup_training( | ||
training_modules=training_modules, | ||
setup_params=experiment_params, | ||
) | ||
# train model | ||
metric_tracker = train_model( | ||
training_modules, | ||
database, | ||
controller, | ||
metric_tracker, | ||
callbacks=None, | ||
batch_callbacks=None, | ||
) | ||
|
||
del network_params["possible_species"] | ||
network_params["metric"] = metric_tracker.best_metric_values | ||
network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times) | ||
network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"] | ||
|
||
with open("training_summary.json", "w") as out: | ||
json.dump(network_params, out, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.