diff --git a/examples/InPSNAPExample.py b/examples/InPSNAPExample.py index 6910bfbd..61c04588 100644 --- a/examples/InPSNAPExample.py +++ b/examples/InPSNAPExample.py @@ -1,3 +1,17 @@ +""" +Example training to the SNAP database for Indium Phosphide. + +This script was designed for an external dataset available at +https://github.com/FitSNAP/FitSNAP + +For info on the dataset, see the following publication: +Explicit Multielement Extension of the Spectral Neighbor Analysis Potential for Chemically Complex Systems +M. A. Cusentino, M. A. Wood, and A. P. Thompson +The Journal of Physical Chemistry A 2020 124 (26), 5456-5464 +DOI: 10.1021/acs.jpca.0c02450 + +""" + import numpy as np import torch torch.set_default_dtype(torch.float32) diff --git a/examples/ani1x_training.py b/examples/ani1x_training.py index d729ad06..c8fcba68 100644 --- a/examples/ani1x_training.py +++ b/examples/ani1x_training.py @@ -21,7 +21,7 @@ import ase.units import sys -sys.path.append("../../datasets/ani-al/readers/lib/") +sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py import pyanitools diff --git a/examples/ani_aluminum_example.py b/examples/ani_aluminum_example.py index eb47745e..1d0c7d44 100644 --- a/examples/ani_aluminum_example.py +++ b/examples/ani_aluminum_example.py @@ -3,7 +3,8 @@ Example training to the ANI-aluminum dataset. This script was designed for an external dataset available at -https://github.com/atomistic-ml/ani-al +https://github.com/atomistic-ml/ani-al. One should download the +entire repository. Note: It is necessary to untar the h5 data files in ani-al/data/ before running this script. 
@@ -23,7 +24,7 @@ import sys -sys.path.append("../../datasets/ani-al/readers/lib/") +sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py import pyanitools # Check if pyanitools is found early import torch diff --git a/examples/ani_aluminum_example_multilayer.py b/examples/ani_aluminum_example_multilayer.py index 580b8b1a..b1ca647f 100644 --- a/examples/ani_aluminum_example_multilayer.py +++ b/examples/ani_aluminum_example_multilayer.py @@ -3,7 +3,8 @@ Example training to the ANI-aluminum dataset. This script was designed for an external dataset available at -https://github.com/atomistic-ml/ani-al +https://github.com/atomistic-ml/ani-al. One should download the +entire repository. Note: It is necessary to untar the h5 data files in ani-al/data/ before running this script. @@ -23,7 +24,7 @@ import sys -sys.path.append("../../datasets/ani-al/readers/lib/") +sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py import pyanitools # Check if pyanitools is found early import torch diff --git a/examples/ase_example.py b/examples/ase_example.py index 66ccbb75..99b1bf33 100644 --- a/examples/ase_example.py +++ b/examples/ase_example.py @@ -1,7 +1,9 @@ """ Script running an aluminum model with ASE. -For training see `ani_aluminum_example.py`. -This will generate the files for a model. + +Before running this script, you must run +`ani_aluminum_example.py` to train the corresponding +model. Modified from ase MD example. @@ -12,7 +14,6 @@ # Imports import numpy as np import torch -import hippynn import ase import time @@ -39,21 +40,22 @@ calc = HippynnCalculator(energy_node, en_unit=units.eV) calc.to(torch.float64) if torch.cuda.is_available(): - nrep = 30 # 27,000 atoms -- should fit in a 12 GB GPU if using custom kernels. + nrep = 25 # 31,250 atoms -- should fit in a 16 GB GPU if using custom kernels. calc.to(torch.device("cuda")) else: nrep = 10 # 1,000 atoms. 
# Build the atoms object -atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05) +atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05, orthorhombic=True) reps = nrep * np.eye(3, dtype=int) atoms = ase.build.make_supercell(atoms, reps, wrap=True) atoms.calc = calc print("Number of atoms:", len(atoms)) -atoms.rattle(0.1) -MaxwellBoltzmannDistribution(atoms, temperature_K=500) +rng = np.random.default_rng(seed=0) +atoms.rattle(0.1, rng=rng) +MaxwellBoltzmannDistribution(atoms, temperature_K=500, rng=rng) dyn = VelocityVerlet(atoms, 1 * units.fs) diff --git a/examples/ase_example_multilayer.py b/examples/ase_example_multilayer.py index c2326243..8249388a 100644 --- a/examples/ase_example_multilayer.py +++ b/examples/ase_example_multilayer.py @@ -1,12 +1,15 @@ """ Script running an aluminum model with ASE. -For training see `ani_aluminum_example.py`. -This will generate the files for a model. -Modified from ase MD example. +This script is designed to match the +LAMMPS script located at +./lammps/in.mliap.unified.hippynn.Al + +Before running this script, you must run +`ani_aluminum_example_multilayer.py` to +train the corresponding model. -If a GPU is available, this script -will use it, and run a somewhat bigger system. +Modified from ase MD example. 
""" # Imports @@ -29,7 +32,7 @@ with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False): bundle = load_checkpoint_from_cwd(map_location='cpu',restore_db=False) except FileNotFoundError: - raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!") + raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!") model = bundle["training_modules"].model @@ -38,15 +41,14 @@ energy_node = model.node_from_name("energy") calc = HippynnCalculator(energy_node, en_unit=units.eV) calc.to(torch.float64) + if torch.cuda.is_available(): - nrep = 4 calc.to(torch.device('cuda')) -else: - nrep = 10 # Build the atoms object atoms = FaceCenteredCubic(directions=np.eye(3, dtype=int), size=(1,1,1), symbol='Al', pbc=(True,True,True)) +nrep = 4 reps = nrep*np.eye(3, dtype=int) atoms = ase.build.make_supercell(atoms, reps, wrap=True) atoms.calc = calc diff --git a/examples/close_contact_finding.py b/examples/close_contact_finding.py index 3ff19266..36d3597f 100644 --- a/examples/close_contact_finding.py +++ b/examples/close_contact_finding.py @@ -15,7 +15,7 @@ """ import sys -sys.path.append("../../datasets/ani-al/readers/lib/") +sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py import pyanitools # Check if pyanitools is found early ### Loading the database diff --git a/examples/feature_extraction.py b/examples/feature_extraction.py index edd62dec..26c892d1 100644 --- a/examples/feature_extraction.py +++ b/examples/feature_extraction.py @@ -64,7 +64,7 @@ import sys -sys.path.append("../../datasets/ani-al/readers/lib/") +sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py import pyanitools # Check if pyanitools is found early from hippynn.databases.h5_pyanitools import PyAniDirectoryDB diff --git a/examples/molecular_dynamics.py b/examples/molecular_dynamics.py new file mode 100644 index 00000000..3ab4d6f1 --- /dev/null +++ b/examples/molecular_dynamics.py @@ -0,0 
"""
This script demonstrates how to use the custom MD module.
It is intended to mirror the `ase_example.py` example,
using the custom MD module rather than ASE.

Before running this script, you must run
`ani_aluminum_example.py` to train a model.

If a GPU is available, this script
will use it, and run a somewhat bigger system.
"""

import numpy as np
import torch
import ase
import time
from tqdm import trange

import ase.build
from ase import units
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

from hippynn.graphs import physics, replace_node
from hippynn.graphs.predictor import Predictor
from hippynn.graphs.nodes.pairs import KDTreePairsMemory
from hippynn.experiment.serialization import load_checkpoint_from_cwd
from hippynn.tools import active_directory
from hippynn.molecular_dynamics.md import (
    Variable,
    NullUpdater,
    VelocityVerlet,
    MolecularDynamics,
)

# Adjust size of system depending on device
if torch.cuda.is_available():
    nrep = 25
    device = torch.device("cuda")
else:
    nrep = 10
    device = torch.device("cpu")

# Load the pre-trained model
try:
    with active_directory("TEST_ALUMINUM_MODEL", create=False):
        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
except FileNotFoundError as err:
    # Re-raise with a hint for the user, keeping the original error as the cause.
    raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!") from err

# Adjust sign on force node (the HippynnCalculator does this automatically)
model = bundle["training_modules"].model
positions_node = model.node_from_name("coordinates")
energy_node = model.node_from_name("energy")
force_node = physics.GradientNode("force", (energy_node, positions_node), sign=-1)

# Replace pair-finder with more efficient one so that system can fit on GPU
old_pairs_node = model.node_from_name("PairIndexer")
species_node = model.node_from_name("species")
cell_node = model.node_from_name("cell")
new_pairs_node = KDTreePairsMemory(
    "PairIndexer",
    parents=(positions_node, species_node, cell_node),
    skin=1.0,
    dist_hard_max=7.5,
)
replace_node(old_pairs_node, new_pairs_node)

model = Predictor(inputs=model.input_nodes, outputs=[force_node])
model.to(device)
model.to(torch.float64)

# Use ASE to generate initial positions and velocities
atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05, orthorhombic=True)
reps = nrep * np.eye(3, dtype=int)
atoms = ase.build.make_supercell(atoms, reps, wrap=True)

print("Number of atoms:", len(atoms))

# A seeded generator is shared by the rattle and the velocity draw.
rng = np.random.default_rng(seed=0)
atoms.rattle(0.1, rng=rng)
MaxwellBoltzmannDistribution(atoms, temperature_K=500, rng=rng)

# Initialize MD variables
# NOTE: Setting the initial acceleration is only necessary to exactly match the results
# in `ase_example.py`. In general, it can be set to zero without impacting the statistics
# of the trajectory.
coordinates = torch.as_tensor(np.array(atoms.get_positions()), device=device).unsqueeze_(0)  # add batch axis
init_velocity = torch.as_tensor(np.array(atoms.get_velocities())).unsqueeze_(0)
cell = torch.as_tensor(np.array(atoms.get_cell()), device=device).unsqueeze_(0)
species = torch.as_tensor(np.array(atoms.get_atomic_numbers()), device=device).unsqueeze_(0)
mass = torch.as_tensor(atoms.get_masses()).unsqueeze_(0).unsqueeze_(-1)  # add a batch axis and a feature axis
init_force = model(
    coordinates=coordinates,
    cell=cell,
    species=species,
)["force"]
init_force = torch.as_tensor(init_force)
init_acceleration = init_force / mass

# Define a position "Variable" and set updater to "VelocityVerlet"
position_variable = Variable(
    name="position",
    data={
        "position": coordinates,
        "velocity": init_velocity,
        "acceleration": init_acceleration,
        "mass": mass,
        # Optional. If added, coordinates will be wrapped in each step of the
        # VelocityVerlet updater. Otherwise, they will be temporarily wrapped for
        # model evaluation only and stored in their unwrapped form.
        "cell": cell,
    },
    model_input_map={
        "coordinates": "position",
    },
    device=device,
    updater=VelocityVerlet(force_db_name="force"),
)

# Define species and cell Variables
species_variable = Variable(
    name="species",
    data={"species": species},
    model_input_map={"species": "species"},
    device=device,
    updater=NullUpdater(),
)

cell_variable = Variable(
    name="cell",
    data={"cell": cell},
    model_input_map={"cell": "cell"},
    device=device,
    updater=NullUpdater(),
)

# Set up MD driver
emdee = MolecularDynamics(
    variables=[position_variable, species_variable, cell_variable],
    model=model,
)


# This Tracker imitates the Tracker from ase_example.py and is optional to use
class Tracker:
    """Reports wall-clock cost per atom-step and the kinetic energy per atom."""

    def __init__(self):
        self.last_call_time = time.time()
        self.n_atoms = None  # filled in by the first call to update()

    def update(self, diff_steps, data):
        """Return the wall time per atom per step elapsed since the previous call.

        ``data`` is the dictionary returned by ``emdee.get_data()``; the atom
        count is read from the second-to-last axis of ``data["position_position"]``.
        """
        now = time.time()
        diff = now - self.last_call_time
        self.n_atoms = data["position_position"].shape[-2]
        time_per_atom_step = diff / (self.n_atoms * diff_steps)
        self.last_call_time = now
        return time_per_atom_step

    def print(self, diff_steps=None, data=None):
        """Print the performance and the kinetic energy per atom."""
        time_per_atom_step = self.update(diff_steps, data)
        # Push the most recently recorded frame into the module-level ASE
        # `atoms` object so that ASE computes the kinetic energy.
        atoms.set_positions(data["position_position"][-1])
        atoms.set_velocities(data["position_velocity"][-1])
        print(
            "Performance:",
            round(1e6 * time_per_atom_step, 1),
            " microseconds/(atom-step)",
        )
        ekin = atoms.get_kinetic_energy() / self.n_atoms
        print("Energy per atom: Ekin = %.7feV (T=%3.0fK)" % (ekin, ekin / (1.5 * units.kB)))


# Run MD!
tracker = Tracker()
for i in trange(100):  # Run 2 ps
    n_steps = 20
    emdee.run(dt=1 * units.fs, n_steps=n_steps, record_every=n_steps)  # Run 20 fs
    tracker.print(n_steps, emdee.get_data())
NPNeighbors(_DispatchNeighbors): @@ -343,7 +344,7 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_ inputs = (coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_index, n_molecules, n_atoms_max) outputs = self._pair_indexer(*inputs) - distflat2, pair_first, pair_second, paircoord, offsets, offset_index = outputs + distflat, pair_first, pair_second, paircoord, offsets, offset_index = outputs with torch.no_grad(): pair_mol = mol_index[pair_first] @@ -366,7 +367,7 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_ coordflat = coordinates.reshape(n_molecules * n_atoms_max, 3)[real_atoms] paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + self.pair_offsets - distflat2 = paircoord.norm(dim=1) - - return distflat2, self.pair_first, self.pair_second, paircoord, self.offsets, self.offset_index + distflat = paircoord.norm(dim=1) + # We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. + return filter_pairs(self.hard_dist_cutoff, distflat, self.pair_first, self.pair_second, paircoord, self.offsets, self.offset_index) diff --git a/hippynn/layers/pairs/indexing.py b/hippynn/layers/pairs/indexing.py index 7239e11f..a0bbddeb 100644 --- a/hippynn/layers/pairs/indexing.py +++ b/hippynn/layers/pairs/indexing.py @@ -5,6 +5,7 @@ from ...custom_kernels.utils import get_id_and_starts from .open import _PairIndexer +from .periodic import filter_pairs class ExternalNeighbors(_PairIndexer): @@ -18,14 +19,8 @@ def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell distflat = paircoord.norm(dim=1) - # Trim the list to only include relevant atoms, improving performance. 
def filter_pairs(cutoff, distflat, *addn_features):
    """
    Keep only the entries whose distance is strictly below ``cutoff``.

    :param cutoff: distance threshold; an entry is kept when ``distflat < cutoff``.
    :param distflat: array of distances used to build the boolean mask.
    :param addn_features: any number of further arrays that are indexed with
        the same mask along their first axis.
    :return: tuple ``(distflat, *addn_features)`` with the mask applied to
        every member, in the order the arguments were given.
    """
    # Named `within_cutoff` rather than `filter` so the builtin is not shadowed.
    within_cutoff = distflat < cutoff
    return tuple(array[within_cutoff] for array in (distflat, *addn_features))
class Variable:
    """
    Tracks the state of a quantity (eg. position, cell, species,
    volume) on each particle or each system in an MD simulation. Can
    also hold additional data associated to that quantity (such as its
    velocity, acceleration, etc...)
    """

    def __init__(
        self,
        name: str,
        data: dict[str, torch.Tensor],
        model_input_map: dict[str, str] | None = None,
        updater: _VariableUpdater = None,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ) -> None:
        """
        Parameters
        ----------
        name : str
            name for variable
        data : dict[str, torch.Tensor]
            dictionary of tracked data in the form `value_name: value`
        model_input_map : dict[str, str], optional
            dictionary of correspondences between data tracked by Variable
            and inputs to the HIP-NN model in the form
            `hipnn-db_name: variable-data-key`, by default an empty dict
        updater : _VariableUpdater, optional
            object which will update the data of the Variable
            over the course of the MD simulation, by default None
        device : torch.device, optional
            device on which to keep data, by default None
        dtype : torch.dtype, optional
            dtype for float type data, by default None
        """
        self.name = name
        self.data = data
        # A fresh dict per instance: the map is stored by reference, so a
        # shared default object would leak entries between Variables.
        self.model_input_map = {} if model_input_map is None else model_input_map
        self.updater = updater
        self.device = device
        self.dtype = dtype

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, data):
        # numpy arrays are converted in place, so the caller's dictionary and
        # the Variable keep referring to the same tensors.
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                data[key] = torch.as_tensor(value)

        batch_sizes = {value.shape[0] for value in data.values()}
        if len(batch_sizes) > 1:
            raise ValueError(
                f"Inconsistent batch sizes found: {batch_sizes}. The first axis of each array in 'data' must be a batch axis of the same size."
            )

        self._data = data

    @property
    def model_input_map(self):
        return self._model_input_map

    @model_input_map.setter
    def model_input_map(self, model_input_map):
        # Every mapped value has to name an existing entry of `data`.
        for value in model_input_map.values():
            if value not in self.data:
                raise ValueError(
                    f"Each value in the 'model_input_map' dictionary should correspond to a key in the 'data' dictionary. "
                    + f"Each key of the 'model_input_map' should correspond to hippynn db_name. Value {value} found in 'model_input_map', but no corresponding key in the 'data' dictionary found."
                )
        self._model_input_map = model_input_map

    @property
    def updater(self):
        return self._updater

    @updater.setter
    def updater(self, updater):
        if updater is None:
            self._updater = None
            return
        # Attaching lets the updater validate this Variable (via its own
        # `variable` setter) before it is accepted.
        updater.variable = self
        self._updater = updater

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device
        if device is None:
            return
        for key, value in self.data.items():
            self.data[key] = value.to(device)

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        self._dtype = dtype
        if dtype is None:
            return
        # Only floating point entries are converted; integer data (e.g.
        # species) keeps its dtype. `is_floating_point` also covers bfloat16.
        for key, value in self.data.items():
            if value.is_floating_point():
                self.data[key] = value.to(dtype)

    @singledispatchmethod
    def to(self, arg):
        """Move the data to a ``torch.device`` or convert float data to a ``torch.dtype``.

        Raises
        ------
        ValueError
            if `arg` is neither a torch.device nor a torch.dtype
        """
        raise ValueError(f"Argument must be of type torch.device or torch.dtype, but provided argument is of type {type(arg)}.")

    @to.register
    def _(self, arg: torch.device):
        self.device = arg

    @to.register
    def _(self, arg: torch.dtype):
        self.dtype = arg
def _wrap_positions_into_cell(variable):
    """Wrap ``variable.data["position"]`` into the periodic cell if a "cell" entry is tracked."""
    data = variable.data
    if "cell" not in data:
        return
    # Explicit membership test instead of try/except KeyError, so that a
    # KeyError raised inside the wrapping routine is not silently discarded.
    # cutoff only used for discarded outputs; can be set arbitrarily
    _, data["position"], *_ = wrap_systems_torch(coords=data["position"], cell=data["cell"], cutoff=0)


class _VariableUpdater:
    """
    Parent class for algorithms that make updates to the data of a Variable during
    each step of an MD simulation.

    Subclasses should redefine __init__, pre_step, post_step, and
    required_variable_data as needed. The inputs to pre_step and post_step
    should not be changed.
    """

    # A list of keys which must appear in Variable.data for any Variable that will be updated by objects of this class.
    # Checked for by variable.setter.
    required_variable_data = []

    def __init__(self):
        pass

    @property
    def variable(self):
        return self._variable

    @variable.setter
    def variable(self, variable):
        for key in self.required_variable_data:
            if key not in variable.data:
                raise ValueError(
                    f"Cannot attach to Variable with no assigned values for {key}. Update Variable.data to include values for {key}."
                )
        self._variable = variable

    def pre_step(self, dt):
        """Updates to variables performed during each step of MD simulation
          before HIPNN model evaluation

        Parameters
        ----------
        dt : float
            timestep
        """
        pass

    def post_step(self, dt, model_outputs):
        """Updates to variables performed during each step of MD simulation
          after HIPNN model evaluation

        Parameters
        ----------
        dt : float
            timestep
        model_outputs : dict
            dictionary of HIPNN model outputs
        """
        pass


class NullUpdater(_VariableUpdater):
    """
    Makes no change to the variable data at each step of MD.
    """

    def pre_step(self, dt):
        pass

    def post_step(self, dt, model_outputs):
        pass


class VelocityVerlet(_VariableUpdater):
    """
    Implements the Velocity Verlet algorithm
    """

    required_variable_data = ["position", "velocity", "acceleration", "mass"]

    def __init__(
        self,
        force_db_name: str,
        units_force: float = ase.units.eV,
        units_acc: float = ase.units.Ang / (1.0**2),
    ):
        """
        Parameters
        ----------
        force_db_name : str
            key which will correspond to the force on the corresponding Variable
            in the HIPNN model output dictionary
        units_force : float, optional
            amount of eV equal to one in the units used for force output
            of HIPNN model (eg. if force output in kcal, units_force =
            ase.units.kcal = 2.6114e22 since 2.6114e22 kcal = 1 eV),
            by default ase.units.eV = 1
        units_acc : float, optional
            amount of Ang/fs^2 equal to one in the units used for acceleration
            in the corresponding Variable, by default units.Ang/(1.0 ** 2) = 1
        """
        self.force_key = force_db_name
        self.force_factor = units_force / units_acc

    def pre_step(self, dt):
        """Updates to variables performed during each step of MD simulation
          before HIPNN model evaluation

        Parameters
        ----------
        dt : float
            timestep
        """
        data = self.variable.data
        # Half-step velocity update, then a full position update.
        data["velocity"] = data["velocity"] + 0.5 * dt * data["acceleration"]
        data["position"] = data["position"] + data["velocity"] * dt
        _wrap_positions_into_cell(self.variable)

    def post_step(self, dt, model_outputs):
        """Updates to variables performed during each step of MD simulation
          after HIPNN model evaluation

        Parameters
        ----------
        dt : float
            timestep
        model_outputs : dict
            dictionary of HIPNN model outputs
        """
        data = self.variable.data
        data["force"] = model_outputs[self.force_key].to(self.variable.device)

        # When mass has one axis fewer than force, add a trailing axis for
        # this division only; the stored mass is left untouched.
        mass = data["mass"]
        if len(data["force"].shape) != len(mass.shape):
            mass = mass[..., None]
        data["acceleration"] = data["force"].detach() / mass * self.force_factor

        # Second half-step velocity update with the new acceleration.
        data["velocity"] = data["velocity"] + 0.5 * dt * data["acceleration"]


class LangevinDynamics(_VariableUpdater):
    """
    Implements the Langevin algorithm
    """

    required_variable_data = ["position", "velocity", "mass"]

    def __init__(
        self,
        force_db_name: str,
        temperature: float,
        frix: float,
        units_force=ase.units.eV,
        units_acc=ase.units.Ang / (1.0**2),
        seed: int = None,
    ):
        """
        Parameters
        ----------
        force_db_name : str
            key which will correspond to the force on the corresponding Variable
            in the HIPNN model output dictionary
        temperature : float
            temperature for Langevin algorithm
        frix : float
            friction coefficient for Langevin algorithm
        units_force : float, optional
            amount of eV equal to one in the units used for force output
            of HIPNN model (eg. if force output in kcal, units_force =
            ase.units.kcal = 2.6114e22 since 2.6114e22 kcal = 1 eV),
            by default ase.units.eV = 1
        units_acc : float, optional
            amount of Ang/fs^2 equal to one in the units used for acceleration
            in the corresponding Variable, by default units.Ang/(1.0 ** 2) = 1
        seed : int, optional
            used to set seed for reproducibility, by default None
        """

        self.force_key = force_db_name
        self.force_factor = units_force / units_acc
        self.temperature = temperature
        self.frix = frix
        # NOTE(review): 0.001987204 looks like Boltzmann's constant in
        # kcal/(mol K); confirm this is consistent with the eV default of
        # `units_force`.
        self.kB = 0.001987204 * self.force_factor

        if seed is not None:
            # This seeds torch's global random number generator, not a
            # generator private to this updater.
            torch.manual_seed(seed)

    def pre_step(self, dt):
        """Updates to variables performed during each step of MD simulation
          before HIPNN model evaluation

        Parameters
        ----------
        dt : float
            timestep
        """
        data = self.variable.data
        data["position"] = data["position"] + data["velocity"] * dt
        _wrap_positions_into_cell(self.variable)

    def post_step(self, dt, model_outputs):
        """Updates to variables performed during each step of MD simulation
          after HIPNN model evaluation

        Parameters
        ----------
        dt : float
            timestep
        model_outputs : dict
            dictionary of HIPNN model outputs
        """
        data = self.variable.data
        data["force"] = model_outputs[self.force_key].to(self.variable.device)

        # Unlike VelocityVerlet, the reshaped mass is stored back because the
        # noise term below needs the same trailing axis.
        if len(data["force"].shape) != len(data["mass"].shape):
            data["mass"] = data["mass"][..., None]

        data["acceleration"] = data["force"].detach() / data["mass"] * self.force_factor

        noise_scale = torch.sqrt(2 * self.kB * self.frix * self.temperature / data["mass"] * dt)
        data["velocity"] = (
            data["velocity"]
            + dt * data["acceleration"]
            - self.frix * data["velocity"] * dt
            + noise_scale * torch.randn_like(data["velocity"], memory_format=torch.contiguous_format)
        )
class MolecularDynamics:
    """
    Driver for MD run
    """

    def __init__(
        self,
        variables: list[Variable],
        model: Predictor,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        """
        Parameters
        ----------
        variables : list[Variable]
            list of Variable objects which will be tracked during simulation
        model : Predictor
            HIPNN Predictor
        device : torch.device, optional
            device to move variables and model to, by default None
        dtype : torch.dtype, optional
            dtype to convert all float type variable data and model parameters to, by default None
        """

        # Order matters: the model setter validates against self.variables,
        # and the device/dtype setters act on both.
        self.variables = variables
        self.model = model
        self.device = device
        self.dtype = dtype

        # Maps "<variable name>_<data key>" / "output_<key>" to a list of recorded frames.
        self._data = {}

    @property
    def variables(self):
        return self._variables

    @variables.setter
    def variables(self, variables):
        # Accept a single Variable as well as a list or tuple of them.
        if isinstance(variables, (list, tuple)):
            variables = list(variables)
        else:
            variables = [variables]

        for variable in variables:
            if variable.updater is None:
                raise ValueError(f"Variable with name {variable.name} does not have a _VariableUpdater set.")

        variable_names = [variable.name for variable in variables]
        if len(variable_names) != len(set(variable_names)):
            raise ValueError(f"Duplicate name found for Variables. Each Variable must have a distinct name. Names found: {variable_names}")

        batch_sizes = {value.shape[0] for variable in variables for value in variable.data.values()}
        if len(batch_sizes) > 1:
            raise ValueError(
                f"Inconsistent batch sizes found: {batch_sizes}. The first axis of each array in 'data' represents a batch axis."
            )

        self._variables = variables

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, model):
        # Every model input must be supplied by some Variable's model_input_map.
        provided_db_names = {key for variable in self.variables for key in variable.model_input_map}
        for node in model.inputs:
            db_name = node.db_name
            if db_name not in provided_db_names:
                raise ValueError(
                    f"Model requires input for '{db_name}', but no Variable found which contains an entry for '{db_name}' in its 'model_input_map'."
                    + f" Entries in the 'model_input_map' should have the form 'hipnn-db_name: variable-data-key' where 'hipnn-db_name'"
                    + f" refers to the db_name of an input for the hippynn Predictor model,"
                    + f" and 'variable-data-key' corresponds to a key in the 'data' dictionary of one of the Variables."
                )
        self._model = model

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device
        if device is None:
            return
        self.model.to(device)
        for variable in self.variables:
            variable.to(device)

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        self._dtype = dtype
        if dtype is None:
            return
        self.model.to(dtype)
        for variable in self.variables:
            variable.to(dtype)

    @singledispatchmethod
    def to(self, arg):
        """Move model and variables to a ``torch.device`` or convert them to a ``torch.dtype``.

        Raises
        ------
        ValueError
            if `arg` is neither a torch.device nor a torch.dtype
        """
        raise ValueError(f"Argument must be of type torch.device or torch.dtype, but provided argument is of type {type(arg)}.")

    @to.register
    def _(self, arg: torch.device):
        self.device = arg

    @to.register
    def _(self, arg: torch.dtype):
        self.dtype = arg

    def _step(
        self,
        dt: float,
    ):
        """Advance by one timestep: pre_step updates, model evaluation, post_step updates."""
        for variable in self.variables:
            variable.updater.pre_step(dt)

        model_inputs = {
            hipnn_db_name: variable.data[variable_key]
            for variable in self.variables
            for hipnn_db_name, variable_key in variable.model_input_map.items()
        }

        model_outputs = self.model(**model_inputs)

        for variable in self.variables:
            variable.updater.post_step(dt, model_outputs)

        return model_outputs

    def _update_data(self, model_outputs: dict):
        """Append the first batch entry of all variable data and model outputs to the record."""
        for variable in self.variables:
            for key, value in variable.data.items():
                self._data.setdefault(f"{variable.name}_{key}", []).append(value.cpu().detach()[0])
        for key, value in model_outputs.items():
            self._data.setdefault(f"output_{key}", []).append(value.cpu().detach()[0])

    def run(self, dt: float, n_steps: int, record_every: int = None):
        """
        Run `n_steps` of MD algorithm.

        Parameters
        ----------
        dt : float
            timestep
        n_steps : int
            number of steps to execute
        record_every : int, optional
            frequency at which to store the data at a step in memory,
            record_every = 1 means every step will be stored, by default None
        """
        for i in trange(n_steps):
            model_outputs = self._step(dt)
            if record_every is not None and (i + 1) % record_every == 0:
                self._update_data(model_outputs)

    def get_data(self):
        """Returns a dictionary of the recorded data"""
        return {key: torch.stack(value) for key, value in self._data.items()}

    def reset_data(self):
        """Clear all recorded data"""
        # Drop the keys as well: leaving empty lists behind would make the
        # next get_data() call torch.stack on an empty list, which raises.
        self._data = {}