From b426f28a538f55b85ad101c7d7c9244daacff8cf Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 19 Dec 2023 15:34:58 +0100 Subject: [PATCH 01/55] multistate --- chiron/multistate.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 chiron/multistate.py diff --git a/chiron/multistate.py b/chiron/multistate.py new file mode 100644 index 0000000..e69de29 From 6af6ef6812e59e05ce7016cacb6d8f935fe9d512 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 19 Dec 2023 16:05:43 +0100 Subject: [PATCH 02/55] adopt example from openmmtools --- chiron/tests/test_minization.py | 3 +-- chiron/tests/test_multistate.py | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 chiron/tests/test_multistate.py diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index 157df1e..b5cc290 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -1,6 +1,5 @@ def test_minimization(): from chiron.minimze import minimize_energy - import jax import jax.numpy as jnp from chiron.states import SamplerState @@ -32,4 +31,4 @@ def test_minimization(): min_x = minimize_energy(sampler_state.x0, lj_potential.compute_energy, nbr_list) e = lj_potential.compute_energy(min_x, nbr_list) - assert jnp.isclose(e, -12506.332) + assert jnp.isclose(e, -12506.332) \ No newline at end of file diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py new file mode 100644 index 0000000..62c1b59 --- /dev/null +++ b/chiron/tests/test_multistate.py @@ -0,0 +1,39 @@ +def test_HO(): + import math + from openmm import unit + from openmmtools import testsystems + from chiron.multistate import MultiStateSampler + from chiron.mcmc import LangevinDynamicsMove + from chiron.states import ThermodynamicState + + testsystem = testsystems.AlanineDipeptideImplicit() + + n_replicas = 3 + T_min = 298.0 * unit.kelvin # Minimum temperature. + T_max = 600.0 * unit.kelvin # Maximum temperature. + temperatures = [ + T_min + + (T_max - T_min) + * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) + / (math.e - 1.0) + for i in range(n_replicas) + ] + temperatures = [ + T_min + + (T_max - T_min) + * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) + / (math.e - 1.0) + for i in range(n_replicas) + ] + thermodynamic_states = [ + ThermodynamicState(system=testsystem.system, temperature=T) + for T in temperatures + ] + + # Initialize simulation object with options. Run with a langevin integrator. + + move = LangevinDynamicsMove(timestep=2.0 * unit.femtoseconds, n_steps=50) + simulation = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) + + # Run the simulation + simulation.run() From a8722ec9d4679606fe711b72e115da6fe0ed74df Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 28 Dec 2023 15:56:05 +0100 Subject: [PATCH 03/55] Add tests for MultiStateSampler class and its methods --- chiron/multistate.py | 505 ++++++++++++++++++++++++++++++++ chiron/tests/test_multistate.py | 60 ++-- 2 files changed, 546 insertions(+), 19 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index e69de29..6bdd125 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -0,0 +1,505 @@ +import os +import copy +import time +from typing import List +import loguru as logger +from chiron.states import SamplerState, ThermodynamicState +import datetime +from loguru import logger as log +import numpy as np +from openmmtools.utils import time_it + +import openmm +from openmm import unit + + +class MultiStateSampler(object): + """ + Base class for samplers that sample multiple thermodynamic states using + one or more replicas. + + This base class provides a general simulation facility for multistate from multiple + thermodynamic states, allowing any set of thermodynamic states to be specified. + If instantiated on its own, the thermodynamic state indices associated with each + state are specified and replica mixing does not change any thermodynamic states, + meaning that each replica remains in its original thermodynamic state. + + Stored configurations, energies, swaps, and restart information are all written + to a single output file using the platform portable, robust, and efficient + NetCDF4 library. + + Parameters + ---------- + mcmc_moves : MCMCMove or list of MCMCMove, optional + The MCMCMove used to propagate the thermodynamic states. If a list of MCMCMoves, + they will be assigned to the correspondent thermodynamic state on + creation. If None is provided, Langevin dynamics with 2fm timestep, 5.0/ps collision rate, + and 500 steps per iteration will be used. + number_of_iterations : int or infinity, optional, default: 1 + The number of iterations to perform. Both ``float('inf')`` and + ``numpy.inf`` are accepted for infinity. If you set this to infinity, + be sure to set also ``online_analysis_interval``. + online_analysis_interval : None or Int >= 1, optional, default: 200 + Choose the interval at which to perform online analysis of the free energy. + + After every interval, the simulation will be stopped and the free energy estimated. + + If the error in the free energy estimate is at or below ``online_analysis_target_error``, then the simulation + will be considered completed. + + If set to ``None``, then no online analysis is performed + + online_analysis_target_error : float >= 0, optional, default 0.0 + The target error for the online analysis measured in kT per phase. + + Once the free energy is at or below this value, the phase will be considered complete. + + If ``online_analysis_interval`` is None, this option does nothing. + + Default is set to 0.0 since online analysis runs by default, but a finite ``number_of_iterations`` should also + be set to ensure there is some stop condition. If target error is 0 and an infinite number of iterations is set, + then the sampler will run until the user stop it manually. + + online_analysis_minimum_iterations : int >= 0, optional, default 200 + Set the minimum number of iterations which must pass before online analysis is carried out. + + Since the initial samples likely not to yield a good estimate of free energy, save time and just skip them + If ``online_analysis_interval`` is None, this does nothing + + locality : int > 0, optional, default None + If None, the energies at all states will be computed for every replica each iteration. + If int > 0, energies will only be computed for states ``range(max(0, state-locality), min(n_states, state+locality))``. + + Attributes + ---------- + n_replicas + n_states + iteration + mcmc_moves + sampler_states + metadata + is_completed + """ + + def __init__( + self, + mcmc_moves=None, + number_of_iterations=1, + ): + # These will be set on initialization. See function + # create() for explanation of single variables. + self._thermodynamic_states = None + self._unsampled_states = None + self._sampler_states = None + self._replica_thermodynamic_states = None + self._iteration = None + self._energy_thermodynamic_states = None + self._neighborhoods = None + self._energy_unsampled_states = None + self._n_accepted_matrix = None + self._n_proposed_matrix = None + self._reporter = None + self._metadata = None + self._timing_data = dict() + + # Handling default propagator. + if mcmc_moves is None: + from .mcmc import LangevinDynamicsMove + + # This will be converted to a list in create(). + self._mcmc_moves = LangevinDynamicsMove( + timestep=2.0 * unit.femtosecond, + collision_rate=5.0 / unit.picosecond, + n_steps=500, + ) + else: + self._mcmc_moves = copy.deepcopy(mcmc_moves) + + # Store constructor parameters. Everything is marked for internal + # usage because any change to these attribute implies a change + # in the storage file as well. Use properties for checks. + self.number_of_iterations = number_of_iterations + + self._last_mbar_f_k = None + self._last_err_free_energy = None + + @property + def n_states(self): + """The integer number of thermodynamic states (read-only).""" + if self._thermodynamic_states is None: + return 0 + else: + return len(self._thermodynamic_states) + + @property + def n_replicas(self): + """The integer number of replicas (read-only).""" + if self._sampler_states is None: + return 0 + else: + return len(self._sampler_states) + + @property + def iteration(self): + """The integer current iteration of the simulation (read-only). + + If the simulation has not been created yet, this is None. + + """ + return self._iteration + + @property + def mcmc_moves(self): + """A copy of the MCMCMoves list used to propagate the simulation. + + This can be set only before creation. + + """ + return copy.deepcopy(self._mcmc_moves) + + @property + def sampler_states(self): + """A copy of the sampler states list at the current iteration. + + This can be set only before running. + """ + return copy.deepcopy(self._sampler_states) + + @property + def is_periodic(self): + """Return True if system is periodic, False if not, and None if not initialized""" + if self._sampler_states is None: + return None + return self._thermodynamic_states[0].is_periodic + + @property + def metadata(self): + """A copy of the metadata dictionary passed on creation (read-only).""" + return copy.deepcopy(self._metadata) + + @property + def is_completed(self): + """Check if we have reached any of the stop target criteria (read-only)""" + return self._is_completed() + + def create( + self, + thermodynamic_states: List[ThermodynamicState], + sampler_states: List[SamplerState], + metadata=None, + ): + """Create new multistate sampler simulation. + + Parameters + ---------- + thermodynamic_states : list of ThermodynamicState + Thermodynamic states to simulate, where one replica is allocated per state. + Each state must have a system with the same number of atoms. + sampler_states : list of SamplerState + One or more sets of initial sampler states. + The number of replicas is taken to be the number of sampler states provided. + If the sampler states do not have box_vectors attached and the system is periodic, + an exception will be thrown. + metadata : dict, optional, default=None + Simulation metadata to be stored in the file. + """ + # TODO: initialize reporter here + # TODO: consider unsampled thermodynamic states for reweighting schemes + self._allocate_variables(thermodynamic_states, sampler_states) + + @classmethod + def _default_initial_thermodynamic_states( + cls, + thermodynamic_states: List[ThermodynamicState], + sampler_states: List[SamplerState], + ): + """ + Create the initial_thermodynamic_states obeying the following rules: + + * ``len(thermodynamic_states) == len(sampler_states)``: 1-to-1 distribution + """ + n_thermo = len(thermodynamic_states) + n_sampler = len(sampler_states) + assert n_thermo == n_sampler, "Must have 1-to-1 distribution of states" + initial_thermo_states = np.arange(n_thermo, dtype=int) + return initial_thermo_states + + def _allocate_variables(self, thermodynamic_states, sampler_states): + # Save thermodynamic states. This sets n_replicas. + self._thermodynamic_states = [ + copy.deepcopy(thermodynamic_state) + for thermodynamic_state in thermodynamic_states + ] + + # Deep copy sampler states. + self._sampler_states = [ + copy.deepcopy(sampler_state) for sampler_state in sampler_states + ] + + # Set initial thermodynamic state indices + initial_thermodynamic_states = self._default_initial_thermodynamic_states( + thermodynamic_states, sampler_states + ) + self._replica_thermodynamic_states = np.array( + initial_thermodynamic_states, np.int64 + ) + + # Reset statistics. + # _n_accepted_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. + # _n_proposed_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. + self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64) + self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64) + + # Allocate memory for energy matrix. energy_thermodynamic_states[k][l] + # is the reduced potential computed at the positions of SamplerState sampler_states[k] + # and ThermodynamicState thermodynamic_states[l]. + self._energy_thermodynamic_states = np.zeros( + [self.n_replicas, self.n_states], np.float64 + ) + self._neighborhoods = np.zeros([self.n_replicas, self.n_states], "i1") + + def _minimize_replica( + self, replica_id: int, tolerance: unit.Quantity, max_iterations: int + ): + from chiron.minimze import minimize_energy + + # Retrieve thermodynamic and sampler states. + thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] + thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + sampler_state = self._sampler_states[replica_id] + + # Compute the initial energy of the system for logging. + initial_energy = thermodynamic_state.get_reduced_potential(sampler_state) + print(initial_energy) + log.debug( + f"Replica {replica_id + 1}/{self.n_replicas}: initial energy {initial_energy:8.3f}kT" + ) + + results = minimize_energy( + sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0 + ) + + def minimize( + self, + tolerance=1.0 * unit.kilojoules_per_mole / unit.nanometers, + max_iterations=0, + ): + """Minimize all replicas. + + Minimized positions are stored at the end. + + Parameters + ---------- + tolerance : openmm.unit.Quantity, optional + Minimization tolerance (units of energy/mole/length, default is + ``1.0 * unit.kilojoules_per_mole / unit.nanometers``). + max_iterations : int, optional + Maximum number of iterations for minimization. If 0, minimization + continues until converged. + + """ + # Check that simulation has been created. + if self.n_replicas == 0: + raise RuntimeError( + "Cannot minimize replicas. The simulation must be created first." + ) + + log.debug("Minimizing all replicas...") + + # minimization + minimized_positions, sampler_state_ids = [], [] + for replica_id in range(self.n_replicas): + minimized_position, sampler_state_id = self._minimize_replica( + replica_id, tolerance, max_iterations + ) + minimized_positions.append(minimized_position) + sampler_state_ids.append(sampler_state_id) + + # Update all sampler states. + for sampler_state_id, minimized_pos in zip( + sampler_state_ids, minimized_positions + ): + self._sampler_states[sampler_state_id].positions = minimized_pos + + def equilibrate(self, n_iterations, mcmc_moves=None): + """Equilibrate all replicas. + + This does not increase the iteration counter. The equilibrated + positions are stored at the end. + + Parameters + ---------- + n_iterations : int + Number of equilibration iterations. + mcmc_moves : MCMCMove or list of MCMCMove, optional + Optionally, the MCMCMoves to use for equilibration can be + different from the ones used in production. + + """ + # Check that simulation has been created. + if self.n_replicas == 0: + raise RuntimeError( + "Cannot equilibrate replicas. The simulation must be created first." + ) + + # If no MCMCMove is specified, use the ones for production. + if mcmc_moves is None: + mcmc_moves = self._mcmc_moves + + # Make sure there is one MCMCMove per thermodynamic state. + if isinstance(mcmc_moves, mcmc.MCMCMove): + mcmc_moves = [copy.deepcopy(mcmc_moves) for _ in range(self.n_states)] + elif len(mcmc_moves) != self.n_states: + raise RuntimeError( + "The number of MCMCMoves ({}) and ThermodynamicStates ({}) for equilibration" + " must be the same.".format(len(self._mcmc_moves), self.n_states) + ) + from openmmtools.utils import Timer + + timer = Timer() + timer.start("Run Equilibration") + + # Temporarily set the equilibration MCMCMoves. + production_mcmc_moves = self._mcmc_moves + self._mcmc_moves = mcmc_moves + for iteration in range(1, 1 + n_iterations): + logger.info("Equilibration iteration {}/{}".format(iteration, n_iterations)) + timer.start("Equilibration Iteration") + + # NOTE: Unlike run(), do NOT increment iteration counter. + # self._iteration += 1 + + # Propagate replicas. + self._propagate_replicas() + + # Compute energies of all replicas at all states + self._compute_energies() + + # Update thermodynamic states + self._replica_thermodynamic_states = self._mix_replicas() + + # Computing timing information + iteration_time = timer.stop("Equilibration Iteration") + partial_total_time = timer.partial("Run Equilibration") + time_per_iteration = partial_total_time / iteration + estimated_time_remaining = time_per_iteration * (n_iterations - iteration) + estimated_total_time = time_per_iteration * n_iterations + estimated_finish_time = time.time() + estimated_time_remaining + # TODO: Transmit timing information + + log.info(f"Iteration took {iteration_time:.3f}s.") + if estimated_time_remaining != float("inf"): + log.info( + "Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format( + str(datetime.timedelta(seconds=estimated_time_remaining)), + time.ctime(estimated_finish_time), + str(datetime.timedelta(seconds=estimated_total_time)), + ) + ) + timer.report_timing() + + # Restore production MCMCMoves. + self._mcmc_moves = production_mcmc_moves + + # TODO: Update stored positions. + + def run(self, n_iterations=None): + """Run the replica-exchange simulation. + + This runs at most ``number_of_iterations`` iterations. + + Parameters + ---------- + n_iterations : int, optional + If specified, only at most the specified number of iterations + will be run (default is None). + """ + # If this is the first iteration, compute and store the + # starting energies of the minimized/equilibrated structures. + if self._iteration == 0: + try: + self._compute_energies() + # We're intercepting a possible initial NaN position here thrown by OpenMM, which is a simple exception + # So we have to under-specify this trap. + except Exception as e: + if "coordinate is nan" in str(e).lower(): + err_message = "Initial coordinates were NaN! Check your inputs!" + logger.critical(err_message) + raise SimulationNaNError(err_message) + else: + # If not the special case, raise the error normally + raise e + mpiplus.run_single_node( + 0, + self._reporter.write_energies, + self._energy_thermodynamic_states, + self._neighborhoods, + self._energy_unsampled_states, + self._iteration, + ) + self._check_nan_energy() + + from openmmtools.utils import Timer + + timer = Timer() + timer.start("Run ReplicaExchange") + run_initial_iteration = self._iteration + + # Handle default argument and determine number of iterations to run. + if n_iterations is None: + iteration_limit = self.number_of_iterations + else: + iteration_limit = min( + self._iteration + n_iterations, self.number_of_iterations + ) + + # Main loop. + while not self._is_completed(iteration_limit): + # Increment iteration counter. + self._iteration += 1 + + logger.info("*" * 80) + logger.info("Iteration {}/{}".format(self._iteration, iteration_limit)) + logger.info("*" * 80) + timer.start("Iteration") + + # Update thermodynamic states + self._replica_thermodynamic_states = self._mix_replicas() + + # Propagate replicas. + self._propagate_replicas() + + # Compute energies of all replicas at all states + self._compute_energies() + + # Write iteration to storage file + self._report_iteration() + + # Update analysis + self._update_analysis() + + # Computing and transmitting timing information + iteration_time = timer.stop("Iteration") + partial_total_time = timer.partial("Run ReplicaExchange") + self._update_timing( + iteration_time, + partial_total_time, + run_initial_iteration, + iteration_limit, + ) + + # Log timing data as info level -- useful for users by default + logger.info( + "Iteration took {:.3f}s.".format(self._timing_data["iteration_seconds"]) + ) + if self._timing_data["estimated_time_remaining"] != float("inf"): + logger.info( + "Estimated completion in {}, at {} (consuming total wall clock time {}).".format( + self._timing_data["estimated_time_remaining"], + self._timing_data["estimated_localtime_finish_date"], + self._timing_data["estimated_total_time"], + ) + ) + + # Perform sanity checks to see if we should terminate here. + self._check_nan_energy() diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 62c1b59..a6db757 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -1,13 +1,23 @@ -def test_HO(): +from chiron.multistate import MultiStateSampler +import pytest + + +@pytest.fixture +def ho_multistate_sampler() -> MultiStateSampler: + """ + Create a MultiStateSampler object for performing multistate simulations for a harmonic oscillator. + + Returns: + MultiStateSampler: The created MultiStateSampler object. + """ import math from openmm import unit - from openmmtools import testsystems - from chiron.multistate import MultiStateSampler from chiron.mcmc import LangevinDynamicsMove - from chiron.states import ThermodynamicState - - testsystem = testsystems.AlanineDipeptideImplicit() + from chiron.states import ThermodynamicState, SamplerState + from openmmtools.testsystems import HarmonicOscillator + from chiron.potential import HarmonicOscillatorPotential + ho = HarmonicOscillator() n_replicas = 3 T_min = 298.0 * unit.kelvin # Minimum temperature. T_max = 600.0 * unit.kelvin # Maximum temperature. @@ -18,22 +28,34 @@ def test_HO(): / (math.e - 1.0) for i in range(n_replicas) ] - temperatures = [ - T_min - + (T_max - T_min) - * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) - / (math.e - 1.0) - for i in range(n_replicas) - ] + + ho_potential = HarmonicOscillatorPotential(ho.topology) thermodynamic_states = [ - ThermodynamicState(system=testsystem.system, temperature=T) - for T in temperatures + ThermodynamicState(ho_potential, temperature=T) for T in temperatures ] + sampler_state = [SamplerState(ho.positions) for _ in temperatures] # Initialize simulation object with options. Run with a langevin integrator. - move = LangevinDynamicsMove(timestep=2.0 * unit.femtoseconds, n_steps=50) - simulation = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) + move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) + multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) + multistate_sampler.create( + thermodynamic_states=thermodynamic_states, sampler_states=sampler_state + ) + + return multistate_sampler + + +def test_multistate_class(ho_multistate_sampler): + # test the multistate_sampler object + assert ho_multistate_sampler.number_of_iterations == 2 + assert ho_multistate_sampler.n_replicas == 3 + assert ho_multistate_sampler.n_states == 3 + assert ho_multistate_sampler._energy_thermodynamic_states.shape == (3, 3) + assert ho_multistate_sampler._n_proposed_matrix.shape == (3, 3) + + +def test_multistate_minimize(ho_multistate_sampler): + ho_multistate_sampler.minimize() + - # Run the simulation - simulation.run() From 795a8e39991f8bbad71247b3458fe7c8fb5f4a4b Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 28 Dec 2023 20:50:01 +0100 Subject: [PATCH 04/55] Update MCMC sampler log messages --- chiron/mcmc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index b649afa..e78bf50 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -306,16 +306,15 @@ def run(self, n_iterations: int = 1): n_iterations : int, optional Number of iterations of the sampler to run. """ - log.info("Running Gibbs sampler") + log.info("Running MCMC sampler") log.info(f"move_schedule = {self.move.move_schedule}") - log.info("Running Gibbs sampler") for iteration in range(n_iterations): log.info(f"Iteration {iteration + 1}/{n_iterations}") for move_name, move in self.move.move_schedule: log.debug(f"Performing: {move_name}") move.run(self.sampler_state, self.thermodynamic_state) - log.info("Finished running Gibbs sampler") + log.info("Finished running MCMC sampler") log.debug("Closing reporter") for _, move in self.move.move_schedule: if move.simulation_reporter is not None: From b88b49b069315366609d1bdf619e3e63bba86744 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 28 Dec 2023 21:08:56 +0100 Subject: [PATCH 05/55] Add NeighborListNsqrd import and initialize it in MultiStateSampler constructor --- chiron/multistate.py | 31 ++++++++++++++----------------- chiron/tests/test_multistate.py | 16 +++++++++++++--- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 6bdd125..5d2ddbe 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -8,7 +8,7 @@ from loguru import logger as log import numpy as np from openmmtools.utils import time_it - +from chiron.neighbors import NeighborListNsqrd import openmm from openmm import unit @@ -119,7 +119,6 @@ def __init__( # usage because any change to these attribute implies a change # in the storage file as well. Use properties for checks. self.number_of_iterations = number_of_iterations - self._last_mbar_f_k = None self._last_err_free_energy = None @@ -186,7 +185,7 @@ def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], - metadata=None, + nbr_list: NeighborListNsqrd, ): """Create new multistate sampler simulation. @@ -206,6 +205,7 @@ def create( # TODO: initialize reporter here # TODO: consider unsampled thermodynamic states for reweighting schemes self._allocate_variables(thermodynamic_states, sampler_states) + self.nbr_list = nbr_list @classmethod def _default_initial_thermodynamic_states( @@ -270,13 +270,20 @@ def _minimize_replica( # Compute the initial energy of the system for logging. initial_energy = thermodynamic_state.get_reduced_potential(sampler_state) - print(initial_energy) log.debug( f"Replica {replica_id + 1}/{self.n_replicas}: initial energy {initial_energy:8.3f}kT" ) results = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0 + sampler_state.x0, + thermodynamic_state.potential.compute_energy, + self.nbr_list, + maxiter=0, + ) + sampler_state.positions = results.params + final_energy = thermodynamic_state.get_reduced_potential(sampler_state) + log.debug( + f"Replica {replica_id + 1}/{self.n_replicas}: final energy {final_energy:8.3f}kT" ) def minimize( @@ -306,20 +313,10 @@ def minimize( log.debug("Minimizing all replicas...") - # minimization + # minimization and update sampler states minimized_positions, sampler_state_ids = [], [] for replica_id in range(self.n_replicas): - minimized_position, sampler_state_id = self._minimize_replica( - replica_id, tolerance, max_iterations - ) - minimized_positions.append(minimized_position) - sampler_state_ids.append(sampler_state_id) - - # Update all sampler states. - for sampler_state_id, minimized_pos in zip( - sampler_state_ids, minimized_positions - ): - self._sampler_states[sampler_state_id].positions = minimized_pos + self._minimize_replica(replica_id, tolerance, max_iterations) def equilibrate(self, n_iterations, mcmc_moves=None): """Equilibrate all replicas. diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index a6db757..23155db 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -16,6 +16,7 @@ def ho_multistate_sampler() -> MultiStateSampler: from chiron.states import ThermodynamicState, SamplerState from openmmtools.testsystems import HarmonicOscillator from chiron.potential import HarmonicOscillatorPotential + from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace ho = HarmonicOscillator() n_replicas = 3 @@ -36,11 +37,22 @@ def ho_multistate_sampler() -> MultiStateSampler: sampler_state = [SamplerState(ho.positions) for _ in temperatures] # Initialize simulation object with options. Run with a langevin integrator. + # initialize the LennardJones potential in chiron + # + sigma = 0.34 * unit.nanometer + cutoff = 3.0 * sigma + skin = 0.5 * unit.nanometer + + nbr_list = NeighborListNsqrd( + OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 + ) move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) multistate_sampler.create( - thermodynamic_states=thermodynamic_states, sampler_states=sampler_state + thermodynamic_states=thermodynamic_states, + sampler_states=sampler_state, + nbr_list=nbr_list, ) return multistate_sampler @@ -57,5 +69,3 @@ def test_multistate_class(ho_multistate_sampler): def test_multistate_minimize(ho_multistate_sampler): ho_multistate_sampler.minimize() - - From d03879c6aaa8869e95612343c70122a31e75811a Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 13:31:30 +0100 Subject: [PATCH 06/55] Fix minimization and potential initialization --- chiron/multistate.py | 22 ++++++++++++++-------- chiron/potential.py | 18 +++++++++++++++++- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 5d2ddbe..9d53b9a 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -259,8 +259,18 @@ def _allocate_variables(self, thermodynamic_states, sampler_states): self._neighborhoods = np.zeros([self.n_replicas, self.n_states], "i1") def _minimize_replica( - self, replica_id: int, tolerance: unit.Quantity, max_iterations: int + self, replica_id: int, tolerance: unit.Quantity, max_iterations: int = 1_000 ): + """ + Minimizes the energy of a replica using the provided parameters. + + Parameters + ---------- + replica_id (int): The ID of the replica. + tolerance (unit.Quantity): The tolerance for convergence. + max_iterations (int, optional): The maximum number of iterations for the minimization. Defaults to 1_000. + """ + from chiron.minimze import minimize_energy # Retrieve thermodynamic and sampler states. @@ -278,9 +288,9 @@ def _minimize_replica( sampler_state.x0, thermodynamic_state.potential.compute_energy, self.nbr_list, - maxiter=0, + maxiter=max_iterations, ) - sampler_state.positions = results.params + sampler_state.x0 = results.params final_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug( f"Replica {replica_id + 1}/{self.n_replicas}: final energy {final_energy:8.3f}kT" @@ -289,7 +299,7 @@ def _minimize_replica( def minimize( self, tolerance=1.0 * unit.kilojoules_per_mole / unit.nanometers, - max_iterations=0, + max_iterations: int = 1_000, ): """Minimize all replicas. @@ -314,16 +324,12 @@ def minimize( log.debug("Minimizing all replicas...") # minimization and update sampler states - minimized_positions, sampler_state_ids = [], [] for replica_id in range(self.n_replicas): self._minimize_replica(replica_id, tolerance, max_iterations) def equilibrate(self, n_iterations, mcmc_moves=None): """Equilibrate all replicas. - This does not increase the iteration counter. The equilibrated - positions are stored at the end. - Parameters ---------- n_iterations : int diff --git a/chiron/potential.py b/chiron/potential.py index b5982f0..91c6458 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -272,9 +272,24 @@ def __init__( self, topology: Topology, k: unit.Quantity = 1.0 * unit.kilocalories_per_mole / unit.angstrom**2, - x0: unit.Quantity = 0.0 * unit.angstrom, + x0: unit.Quantity = [[0.0, 0.0, 0.0]] * unit.angstrom, U0: unit.Quantity = 0.0 * unit.kilocalories_per_mole, ): + """ + Initialize a HarmonicOscillatorPotential object. + + Parameters: + ---------- + topology : Topology + The topology object representing the molecular system. + k : unit.Quantity, optional + The spring constant of the harmonic potential. Default is 1.0 kcal/mol/Å^2. + x0 : unit.Quantity, optional + The equilibrium position of the harmonic potential. Default is [0.0,0.0,0.0] Å. + U0 : unit.Quantity, optional + The offset potential energy of the harmonic potential. Default is 0.0 kcal/mol. + """ + if not isinstance(topology, Topology): if not isinstance( topology, property @@ -298,6 +313,7 @@ def __init__( raise ValueError( f"x0 must be a unit.Quantity with units of distance, x0.unit = {x0.unit}" ) + assert x0.shape[1] == 3, f"x0 must be a NX3 vector, x0.shape = {x0.shape}" if not U0.unit.is_compatible(unit.kilocalories_per_mole): raise ValueError( f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}" From 551ce2b560998362cac60cb85736f7bf0e1b9de5 Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 13:31:37 +0100 Subject: [PATCH 07/55] Refactor multi-state sampler and add test for minimize method --- chiron/tests/test_multistate.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 23155db..4125eeb 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -5,10 +5,10 @@ @pytest.fixture def ho_multistate_sampler() -> MultiStateSampler: """ - Create a MultiStateSampler object for performing multistate simulations for a harmonic oscillator. + Create a multi-state sampler for a harmonic oscillator system. Returns: - MultiStateSampler: The created MultiStateSampler object. + MultiStateSampler: The multi-state sampler object. """ import math from openmm import unit @@ -29,10 +29,17 @@ def ho_multistate_sampler() -> MultiStateSampler: / (math.e - 1.0) for i in range(n_replicas) ] + import jax.numpy as jnp - ho_potential = HarmonicOscillatorPotential(ho.topology) + x0s = [ + unit.Quantity(jnp.array([[x0, 0.0, 0.0]]), unit.angstrom) + for x0 in jnp.linspace(0.0, 1.0, n_replicas) + ] thermodynamic_states = [ - ThermodynamicState(ho_potential, temperature=T) for T in temperatures + ThermodynamicState( + HarmonicOscillatorPotential(ho.topology, x0=x0), temperature=T + ) + for T, x0 in zip(temperatures, x0s) ] sampler_state = [SamplerState(ho.positions) for _ in temperatures] @@ -68,4 +75,21 @@ def test_multistate_class(ho_multistate_sampler): def test_multistate_minimize(ho_multistate_sampler): + """ + Test function for the `minimize` method of the `ho_multistate_sampler` object. + It checks if the sampler states are correctly minimized. + + Parameters + ---------- + ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + """ + + import numpy as np + ho_multistate_sampler.minimize() + + assert np.allclose( + ho_multistate_sampler.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) + ) + assert np.allclose(ho_multistate_sampler.sampler_states[1].x0, np.array([[0.05, 0.0, 0.0]]), atol=1e-2) + assert np.allclose(ho_multistate_sampler.sampler_states[2].x0, np.array([[0.1, 0.0, 0.0]]), atol=1e-2) From 409f5c72f1b336229a68d8c91436fa6ad3ba0397 Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 13:43:01 +0100 Subject: [PATCH 08/55] Update x0 initialization in HarmonicOscillatorPotential constructor --- chiron/potential.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chiron/potential.py b/chiron/potential.py index 91c6458..14676dd 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -272,7 +272,7 @@ def __init__( self, topology: Topology, k: unit.Quantity = 1.0 * unit.kilocalories_per_mole / unit.angstrom**2, - x0: unit.Quantity = [[0.0, 0.0, 0.0]] * unit.angstrom, + x0: unit.Quantity = jnp.array([[0.0, 0.0, 0.0]]) * unit.angstrom, U0: unit.Quantity = 0.0 * unit.kilocalories_per_mole, ): """ From 73776f5b28f8f40927a836278ed9381b48b46aa3 Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 16:25:13 +0100 Subject: [PATCH 09/55] Refactor MCMC moves and update move set class --- chiron/mcmc.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index a767f6c..c67730e 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -38,15 +38,14 @@ """ from chiron.states import SamplerState, ThermodynamicState -from chiron.potential import NeuralNetworkPotential from openmm import unit from loguru import logger as log -from typing import Dict, Union, Tuple, List, Optional +from typing import Tuple, List, Optional import jax.numpy as jnp from chiron.reporters import SimulationReporter -class StateUpdateMove: +class MCMCMove: def __init__(self, nr_of_moves: int, seed: int): """ Initialize a move within the molecular system. @@ -64,7 +63,7 @@ def __init__(self, nr_of_moves: int, seed: int): self.key = jrandom.PRNGKey(seed) # 'seed' is an integer seed value -class LangevinDynamicsMove(StateUpdateMove): +class LangevinDynamicsMove(MCMCMove): def __init__( self, stepsize=1.0 * unit.femtoseconds, @@ -110,6 +109,13 @@ def run( state_variables (StateVariablesCollection): State variables of the system. """ + assert isinstance( + sampler_state, SamplerState + ), f"Sampler state must be SamplerState, not {type(sampler_state)}" + assert isinstance( + thermodynamic_state, ThermodynamicState + ), f"Thermodynamic state must be ThermodynamicState, not {type(thermodynamic_state)}" + self.integrator.run( thermodynamic_state=thermodynamic_state, sampler_state=sampler_state, @@ -118,7 +124,7 @@ def run( ) -class MCMove(StateUpdateMove): +class MCMove(MCMCMove): def __init__(self, nr_of_moves: int, seed: int) -> None: super().__init__(nr_of_moves, seed) @@ -247,7 +253,7 @@ class MoveSet: def __init__( self, - move_schedule: List[Tuple[str, StateUpdateMove]], + move_schedule: List[Tuple[str, MCMCMove]], ) -> None: _AVAILABLE_MOVES = ["LangevinDynamicsMove"] self.move_schedule = move_schedule @@ -264,7 +270,7 @@ def _validate_sequence(self): If a move in the sequence is not present in available_moves. """ for move_name, move_class in self.move_schedule: - if not isinstance(move_class, StateUpdateMove): + if not isinstance(move_class, MCMCMove): raise ValueError(f"Move {move_name} in the sequence is not available.") From 7c4f3ae04fca462d22bcc0f8634fc39a846b608b Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 16:30:04 +0100 Subject: [PATCH 10/55] Fix LangevinIntegrator save state bug --- chiron/integrators.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/chiron/integrators.py b/chiron/integrators.py index e94aed7..b25d348 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -58,6 +58,7 @@ def __init__( self.save_frequency = save_frequency self.velocities = None + def set_velocities(self, vel: unit.Quantity) -> None: """ Set the initial velocities for the Langevin Integrator. @@ -167,7 +168,8 @@ def run( if step % self.save_frequency == 0: # log.debug(f"Saving at step {step}") - if self.reporter is not None: + # check if reporter is attribute of the class + if hasattr(self, "reporter") and self.reporter is not None: d = { "traj": x, "energy": potential.compute_energy(x, nbr_list), @@ -180,4 +182,6 @@ def run( self.reporter.report(d) log.debug("Finished running Langevin dynamics") + # save the final state of the simulation in the sampler_state object + sampler_state.x0 = x # self.reporter.close() From 0c22db9bba386946ead88e0ea88f14a56518f668 Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 16:44:36 +0100 Subject: [PATCH 11/55] Refactor MultiStateSampler class --- chiron/multistate.py | 99 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 87 insertions(+), 12 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 9d53b9a..a50f83e 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,4 +1,3 @@ -import os import copy import time from typing import List @@ -7,10 +6,10 @@ import datetime from loguru import logger as log import numpy as np -from openmmtools.utils import time_it +from openmmtools.utils import time_it, with_timer from chiron.neighbors import NeighborListNsqrd -import openmm from openmm import unit +from chiron.mcmc import MCMCMove class MultiStateSampler(object): @@ -81,11 +80,7 @@ class MultiStateSampler(object): is_completed """ - def __init__( - self, - mcmc_moves=None, - number_of_iterations=1, - ): + def __init__(self, mcmc_moves=None, number_of_iterations=1, locality=None): # These will be set on initialization. See function # create() for explanation of single variables. self._thermodynamic_states = None @@ -122,6 +117,9 @@ def __init__( self._last_mbar_f_k = None self._last_err_free_energy = None + # Store locality + self.locality = locality + @property def n_states(self): """The integer number of thermodynamic states (read-only).""" @@ -350,12 +348,11 @@ def equilibrate(self, n_iterations, mcmc_moves=None): mcmc_moves = self._mcmc_moves # Make sure there is one MCMCMove per thermodynamic state. - if isinstance(mcmc_moves, mcmc.MCMCMove): + if isinstance(mcmc_moves, MCMCMove): mcmc_moves = [copy.deepcopy(mcmc_moves) for _ in range(self.n_states)] elif len(mcmc_moves) != self.n_states: raise RuntimeError( - "The number of MCMCMoves ({}) and ThermodynamicStates ({}) for equilibration" - " must be the same.".format(len(self._mcmc_moves), self.n_states) + f"The number of MCMCMoves ({len(self._mcmc_moves)}) and ThermodynamicStates ({self.n_states}) for equilibration must be the same." ) from openmmtools.utils import Timer @@ -366,7 +363,7 @@ def equilibrate(self, n_iterations, mcmc_moves=None): production_mcmc_moves = self._mcmc_moves self._mcmc_moves = mcmc_moves for iteration in range(1, 1 + n_iterations): - logger.info("Equilibration iteration {}/{}".format(iteration, n_iterations)) + log.info("Equilibration iteration {iteration}/{n_iterations}") timer.start("Equilibration Iteration") # NOTE: Unlike run(), do NOT increment iteration counter. @@ -406,6 +403,84 @@ def equilibrate(self, n_iterations, mcmc_moves=None): # TODO: Update stored positions. + def _propagate_replica(self, replica_id): + """Propagate thermodynamic state associated to the given replica.""" + # Retrieve thermodynamic, sampler states, and MCMC move of this replica. + thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] + thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + mcmc_move = self._mcmc_moves[thermodynamic_state_id] + sampler_state = self._sampler_states[replica_id] + + # Apply MCMC move. + try: + mcmc_move.run(sampler_state, thermodynamic_state) + except Exception as e: + log.warning(e) + raise e + + @with_timer("Propagating all replicas") + def _propagate_replicas(self): + """Propagate all replicas.""" + + log.debug("Propagating all replicas...") + + for i in range(self.n_replicas): + self._propagate_replica(i) + + def _neighborhood(self, state_index): + """Compute the states in the local neighborhood determined by self.locality + + Parameters + ---------- + state_index : int + The current state + + Returns + ------- + neighborhood : list of int + The states in the local neighborhood + """ + if self.locality is None: + # Global neighborhood + return list(range(0, self.n_states)) + else: + # Local neighborhood specified by 'locality' + return list( + range( + max(0, state_index - self.locality), + min(self.n_states, state_index + self.locality + 1), + ) + ) + + @with_timer("Computing energy matrix") + def _compute_energies(self): + """Compute energies of all replicas at all states.""" + + # Determine neighborhoods (all nodes) + self._neighborhoods[:, :] = False + for replica_index, state_index in enumerate(self._replica_thermodynamic_states): + neighborhood = self._neighborhood(state_index) + self._neighborhoods[replica_index, neighborhood] = True + + # Distribute energy computation across nodes. Only node 0 receives + # all the energies since it needs to store them and mix states. + new_energies, replica_ids = [], [] + for i in range(self.n_replicas): + new_energy, replica_id = self._compute_replica_energies(i) + new_energies.append(new_energy) + replica_ids.append(replica_id) + + # Update energy matrices. Non-0 nodes update only the energies computed by this replica. + for replica_id, energies in zip(replica_ids, new_energies): + energy_thermodynamic_states, energy_unsampled_states = energies # Unpack. + neighborhood = self._neighborhood( + self._replica_thermodynamic_states[replica_id] + ) + self._energy_thermodynamic_states[ + replica_id, neighborhood + ] = energy_thermodynamic_states + self._energy_unsampled_states[replica_id] = energy_unsampled_states + def run(self, n_iterations=None): """Run the replica-exchange simulation. From 741ee0d97dec8a73fe4b2db0b330a5cd977387af Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 17:54:33 +0100 Subject: [PATCH 12/55] Refactor ThermodynamicState class to check compatibility between states --- chiron/states.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/chiron/states.py b/chiron/states.py index edfae43..b9dd8ad 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -214,26 +214,30 @@ def _check_completness(self): if self.temperature and self.pressure and self.nr_of_particles: log.info("NpT ensemble is simulated.") - @classmethod - def are_states_compatible(cls, state1, state2): - """ - Check if two simulation states are compatible. - - This method should define the criteria for compatibility, - such as matching number of particles, etc. + def is_state_compatible(self, thermodynamic_state): + """Check compatibility between ThermodynamicStates. Parameters ---------- - state1 : SimulationState - The first simulation state to compare. - state2 : SimulationState - The second simulation state to compare. + thermodynamic_state : ThermodynamicState + The thermodynamic state to test. Returns ------- - bool - True if states are compatible, False otherwise. + is_compatible : bool + True if the states are compatible, False otherwise. + + Examples + -------- + States in the same ensemble (NVT or NPT) are compatible. + States in different ensembles are not compatible. + States that store different systems (that differ by more than + barostat and thermostat pressure and temperature) are also not + compatible. """ + + # Check that the states are in the same ensemble. + # TODO: implement this pass def get_reduced_potential( From 61af8ddca58320ffcc4ff6ceb4d859f3dfd58d43 Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 17:54:53 +0100 Subject: [PATCH 13/55] Refactor MultiStateSampler class to remove unused code and improve code organization --- chiron/multistate.py | 85 +++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index a50f83e..63f47e2 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -23,10 +23,6 @@ class MultiStateSampler(object): state are specified and replica mixing does not change any thermodynamic states, meaning that each replica remains in its original thermodynamic state. - Stored configurations, energies, swaps, and restart information are all written - to a single output file using the platform portable, robust, and efficient - NetCDF4 library. - Parameters ---------- mcmc_moves : MCMCMove or list of MCMCMove, optional @@ -35,35 +31,7 @@ class MultiStateSampler(object): creation. If None is provided, Langevin dynamics with 2fm timestep, 5.0/ps collision rate, and 500 steps per iteration will be used. number_of_iterations : int or infinity, optional, default: 1 - The number of iterations to perform. Both ``float('inf')`` and - ``numpy.inf`` are accepted for infinity. If you set this to infinity, - be sure to set also ``online_analysis_interval``. - online_analysis_interval : None or Int >= 1, optional, default: 200 - Choose the interval at which to perform online analysis of the free energy. - - After every interval, the simulation will be stopped and the free energy estimated. - - If the error in the free energy estimate is at or below ``online_analysis_target_error``, then the simulation - will be considered completed. - - If set to ``None``, then no online analysis is performed - - online_analysis_target_error : float >= 0, optional, default 0.0 - The target error for the online analysis measured in kT per phase. - - Once the free energy is at or below this value, the phase will be considered complete. - - If ``online_analysis_interval`` is None, this option does nothing. - - Default is set to 0.0 since online analysis runs by default, but a finite ``number_of_iterations`` should also - be set to ensure there is some stop condition. If target error is 0 and an infinite number of iterations is set, - then the sampler will run until the user stop it manually. - - online_analysis_minimum_iterations : int >= 0, optional, default 200 - Set the minimum number of iterations which must pass before online analysis is carried out. - - Since the initial samples likely not to yield a good estimate of free energy, save time and just skip them - If ``online_analysis_interval`` is None, this does nothing + The number of iterations to perform. locality : int > 0, optional, default None If None, the energies at all states will be computed for every replica each iteration. @@ -179,6 +147,57 @@ def is_completed(self): """Check if we have reached any of the stop target criteria (read-only)""" return self._is_completed() + def _compute_replica_energies(self, replica_id): + """Compute the energy for the replica in every ThermodynamicState.""" + # Determine neighborhood + import jax.numpy as jnp + + state_index = self._replica_thermodynamic_states[replica_id] + neighborhood = self._neighborhood(state_index) + + # Only compute energies of the sampled states over neighborhoods. + energy_neighborhood_states = jnp.zeros(len(neighborhood)) + neighborhood_thermodynamic_states = [ + self._thermodynamic_states[n] for n in neighborhood + ] + + # Retrieve sampler state associated to this replica. + sampler_state = self._sampler_states[replica_id] + + # Compute energy for all thermodynamic states. + from openmmtools.states import group_by_compatibility + + for energies, the_states in [ + (energy_neighborhood_states, neighborhood_thermodynamic_states), + ]: + # Group thermodynamic states by compatibility. + compatible_groups, original_indices = group_by_compatibility(the_states) + + # Compute the reduced potentials of all the compatible states. + for compatible_group, state_indices in zip( + compatible_groups, original_indices + ): + # Get the context, any Integrator works. + context, integrator = self.energy_context_cache.get_context( + compatible_group[0] + ) + + # Update positions and box vectors. We don't need + # to set Context velocities for the potential. + sampler_state.apply_to_context(context, ignore_velocities=True) + + # Compute and update the reduced potentials. + compatible_energies = ( + states.ThermodynamicState.reduced_potential_at_states( + context, compatible_group + ) + ) + for energy_idx, state_idx in enumerate(state_indices): + energies[state_idx] = compatible_energies[energy_idx] + + # Return the new energies. + return energy_neighborhood_states, energy_unsampled_states + def create( self, thermodynamic_states: List[ThermodynamicState], From 2be00f2c856e318a8ef4ed4bcb7190ca88ae82e7 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 14:39:09 +0100 Subject: [PATCH 14/55] Add function to calculate reduced potential at different thermodynamic states --- chiron/states.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/chiron/states.py b/chiron/states.py index b9dd8ad..0a0104b 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -3,7 +3,6 @@ from jax import numpy as jnp from loguru import logger as log from .potential import NeuralNetworkPotential -from openmm.app import Topology class SamplerState: @@ -287,3 +286,30 @@ def get_reduced_potential( def kT_to_kJ_per_mol(self, energy): energy = energy * unit.AVOGADRO_CONSTANT_NA return energy / self.beta + + +def calculate_reduced_potential_at_states( + sampler_state: SamplerState, + themrodynamic_states: List[ThermodynamicState], + nbr_list=None, +): + """ + Calculate the reduced potential for a list of thermodynamic states. + + Parameters + ---------- + sampler_state : SamplerState + The sampler state for which to compute the reduced potential. + thermodynamic_states : list of ThermodynamicState + The thermodynamic states for which to compute the reduced potential. + nbr_list : NeighborList or PairList, optional + Returns + ------- + list of float + The reduced potential of the system for each thermodynamic state. + + """ + reduced_potentials = [] + for state in themrodynamic_states: + reduced_potentials.append(state.get_reduced_potential(sampler_state)) + return reduced_potentials From 2c615138f9e8a40be10348b96e7194dbd6762ba3 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 15:59:40 +0100 Subject: [PATCH 15/55] Add n_particles property to SamplerState class --- chiron/states.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chiron/states.py b/chiron/states.py index 0a0104b..3a71dbf 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -68,6 +68,10 @@ def __init__( self._box_vectors = box_vectors self._distance_unit = unit.nanometer + @property + def n_particles(self) -> int: + return self._x0.shape[0] + @property def x0(self) -> jnp.array: return self._convert_to_jnp(self._x0) From b425500cbf8684ab7c246c9a4be7f1828a9666f3 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:02:27 +0100 Subject: [PATCH 16/55] first pass on chiron multistate sampling class --- chiron/multistate.py | 578 +++++++++++++++++++++++++++++++++---------- 1 file changed, 452 insertions(+), 126 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 63f47e2..fcf1663 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,7 +1,6 @@ import copy import time -from typing import List -import loguru as logger +from typing import List, Optional from chiron.states import SamplerState, ThermodynamicState import datetime from loguru import logger as log @@ -10,6 +9,7 @@ from chiron.neighbors import NeighborListNsqrd from openmm import unit from chiron.mcmc import MCMCMove +from openmmtools.multistate import MultiStateReporter class MultiStateSampler(object): @@ -147,23 +147,38 @@ def is_completed(self): """Check if we have reached any of the stop target criteria (read-only)""" return self._is_completed() - def _compute_replica_energies(self, replica_id): - """Compute the energy for the replica in every ThermodynamicState.""" - # Determine neighborhood + def _compute_replica_energies(self, replica_id: int) -> np.ndarray: + """ + Compute the energy for the replica in every ThermodynamicState. + + Parameters + ---------- + replica_id : int + The ID of the replica to compute energies for. + + Returns + ------- + np.ndarray + Array of energies for the specified replica across all thermodynamic states. + """ import jax.numpy as jnp + from chiron.states import calculate_reduced_potential_at_states + + log.debug(f"{self._replica_thermodynamic_states=}") + # Determine neighborhood state_index = self._replica_thermodynamic_states[replica_id] neighborhood = self._neighborhood(state_index) - + log.debug(f"{neighborhood=}") # Only compute energies of the sampled states over neighborhoods. - energy_neighborhood_states = jnp.zeros(len(neighborhood)) + energy_neighborhood_states = np.zeros(len(neighborhood)) neighborhood_thermodynamic_states = [ self._thermodynamic_states[n] for n in neighborhood ] # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] - + log.debug(f"{sampler_state=}") # Compute energy for all thermodynamic states. from openmmtools.states import group_by_compatibility @@ -177,52 +192,58 @@ def _compute_replica_energies(self, replica_id): for compatible_group, state_indices in zip( compatible_groups, original_indices ): - # Get the context, any Integrator works. - context, integrator = self.energy_context_cache.get_context( - compatible_group[0] - ) - - # Update positions and box vectors. We don't need - # to set Context velocities for the potential. - sampler_state.apply_to_context(context, ignore_velocities=True) - # Compute and update the reduced potentials. - compatible_energies = ( - states.ThermodynamicState.reduced_potential_at_states( - context, compatible_group - ) + compatible_energies = calculate_reduced_potential_at_states( + sampler_state, compatible_group, self.nbr_list ) for energy_idx, state_idx in enumerate(state_indices): energies[state_idx] = compatible_energies[energy_idx] # Return the new energies. - return energy_neighborhood_states, energy_unsampled_states + log.info(f"Computed energies for replica {replica_id}") + log.info(f"{energy_neighborhood_states=}") + return energy_neighborhood_states def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd, + reporter: MultiStateReporter, + metadata: Optional[dict] = None, ): """Create new multistate sampler simulation. - Parameters - ---------- - thermodynamic_states : list of ThermodynamicState - Thermodynamic states to simulate, where one replica is allocated per state. - Each state must have a system with the same number of atoms. - sampler_states : list of SamplerState - One or more sets of initial sampler states. - The number of replicas is taken to be the number of sampler states provided. - If the sampler states do not have box_vectors attached and the system is periodic, - an exception will be thrown. - metadata : dict, optional, default=None - Simulation metadata to be stored in the file. + thermodynamic_states : List[ThermodynamicState] + List of ThermodynamicStates to simulate, with one replica allocated per state. + sampler_states : List[SamplerState] + List of initial SamplerStates. The number of replicas is taken to be the number + of sampler states provided. + nbr_list : NeighborListNsqrd + Neighbor list object to be used in the simulation. + reporter : MultiStateReporter + Reporter object to record simulation data. + metadata : dict, optional + Optional simulation metadata to be stored in the file. + + Raises + ------ + RuntimeError + If the lengths of thermodynamic_states and sampler_states are not equal. """ # TODO: initialize reporter here # TODO: consider unsampled thermodynamic states for reweighting schemes + + # Ensure the number of thermodynamic states matches the number of sampler states + if len(thermodynamic_states) != len(sampler_states): + raise RuntimeError( + "Number of thermodynamic states and sampler states must be equal." + ) + self._allocate_variables(thermodynamic_states, sampler_states) self.nbr_list = nbr_list + self._reporter = reporter + self._reporter.open(mode="a") @classmethod def _default_initial_thermodynamic_states( @@ -241,7 +262,32 @@ def _default_initial_thermodynamic_states( initial_thermo_states = np.arange(n_thermo, dtype=int) return initial_thermo_states - def _allocate_variables(self, thermodynamic_states, sampler_states): + def _allocate_variables( + self, + thermodynamic_states: List[ThermodynamicState], + sampler_states: List[SamplerState], + unsampled_thermodynamic_states: Optional[List[ThermodynamicState]] = None, + ) -> None: + """ + Allocate and initialize internal variables for the sampler. + + Parameters + ---------- + thermodynamic_states : List[ThermodynamicState] + A list of ThermodynamicState objects to be used in the sampler. + sampler_states : List[SamplerState] + A list of SamplerState objects for initializing the sampler. + unsampled_thermodynamic_states : Optional[List[ThermodynamicState]], optional + A list of additional ThermodynamicState objects that are not directly sampled but + for which energies will be computed for reweighting schemes. Defaults to None, + meaning no unsampled states are considered. + + Raises + ------ + RuntimeError + If the number of MCMC moves and ThermodynamicStates do not match. + """ + # Save thermodynamic states. This sets n_replicas. self._thermodynamic_states = [ copy.deepcopy(thermodynamic_state) @@ -253,6 +299,13 @@ def _allocate_variables(self, thermodynamic_states, sampler_states): copy.deepcopy(sampler_state) for sampler_state in sampler_states ] + # Handle default unsampled thermodynamic states. + self._unsampled_states = ( + copy.deepcopy(unsampled_thermodynamic_states) + if unsampled_thermodynamic_states is not None + else [] + ) + # Set initial thermodynamic state indices initial_thermodynamic_states = self._default_initial_thermodynamic_states( thermodynamic_states, sampler_states @@ -262,37 +315,66 @@ def _allocate_variables(self, thermodynamic_states, sampler_states): ) # Reset statistics. + # _n_accepted_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. # _n_proposed_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. - self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64) - self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64) - # Allocate memory for energy matrix. energy_thermodynamic_states[k][l] # is the reduced potential computed at the positions of SamplerState sampler_states[k] # and ThermodynamicState thermodynamic_states[l]. + + self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64) + self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64) self._energy_thermodynamic_states = np.zeros( [self.n_replicas, self.n_states], np.float64 ) self._neighborhoods = np.zeros([self.n_replicas, self.n_states], "i1") + self._energy_unsampled_states = np.zeros( + [self.n_replicas, len(self._unsampled_states)], np.float64 + ) + + # Ensure there is an MCMCMove for each thermodynamic state. + if isinstance(self._mcmc_moves, MCMCMove): + self._mcmc_moves = [ + copy.deepcopy(self._mcmc_moves) for _ in range(self.n_states) + ] + elif len(self._mcmc_moves) != self.n_states: + raise RuntimeError( + f"The number of MCMCMoves ({len(self._mcmc_moves)}) and ThermodynamicStates ({self.n_states}) must be the same." + ) + + # Reset iteration counter. + self._iteration = 0 def _minimize_replica( - self, replica_id: int, tolerance: unit.Quantity, max_iterations: int = 1_000 - ): + self, + replica_id: int, + tolerance: unit.Quantity = 1.0 * unit.kilojoules_per_mole / unit.nanometers, + max_iterations: int = 1_000, + ) -> None: """ - Minimizes the energy of a replica using the provided parameters. + Minimize the energy of a single replica. Parameters ---------- - replica_id (int): The ID of the replica. - tolerance (unit.Quantity): The tolerance for convergence. - max_iterations (int, optional): The maximum number of iterations for the minimization. Defaults to 1_000. + replica_id : int + The index of the replica to minimize. + tolerance : unit.Quantity, optional + The energy tolerance to which the system should be minimized. + Defaults to 1.0 kilojoules/mole/nanometers. + max_iterations : int, optional + The maximum number of minimization iterations. Defaults to 1000. + + Notes + ----- + The minimization modifies the SamplerState associated with the replica. """ from chiron.minimze import minimize_energy # Retrieve thermodynamic and sampler states. - thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] - thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + thermodynamic_state = self._thermodynamic_states[ + self._replica_thermodynamic_states[replica_id] + ] sampler_state = self._sampler_states[replica_id] # Compute the initial energy of the system for logging. @@ -301,13 +383,18 @@ def _minimize_replica( f"Replica {replica_id + 1}/{self.n_replicas}: initial energy {initial_energy:8.3f}kT" ) - results = minimize_energy( + # Perform minimization + minimized_state = minimize_energy( sampler_state.x0, thermodynamic_state.potential.compute_energy, self.nbr_list, maxiter=max_iterations, ) - sampler_state.x0 = results.params + + # Update the sampler state + self._sampler_states[replica_id].x0 = minimized_state.params + + # Compute and log final energy final_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug( f"Replica {replica_id + 1}/{self.n_replicas}: final energy {final_energy:8.3f}kT" @@ -317,21 +404,28 @@ def minimize( self, tolerance=1.0 * unit.kilojoules_per_mole / unit.nanometers, max_iterations: int = 1_000, - ): - """Minimize all replicas. + ) -> None: + """ + Minimize all replicas in the sampler. - Minimized positions are stored at the end. + This method minimizes the positions of all replicas to the nearest local + minimum of the potential energy surface. The minimized positions are stored + at the end of the process. Parameters ---------- - tolerance : openmm.unit.Quantity, optional - Minimization tolerance (units of energy/mole/length, default is - ``1.0 * unit.kilojoules_per_mole / unit.nanometers``). + tolerance : unit.Quantity, optional + The energy tolerance for the minimization. Default is 1.0 kJ/mol/nm. max_iterations : int, optional - Maximum number of iterations for minimization. If 0, minimization - continues until converged. + The maximum number of iterations for the minimization process. + Default is 1000. + Raises + ------ + RuntimeError + If the simulation has not been created before calling this method. """ + # Check that simulation has been created. if self.n_replicas == 0: raise RuntimeError( @@ -340,21 +434,51 @@ def minimize( log.debug("Minimizing all replicas...") - # minimization and update sampler states + # Iterate over all replicas and minimize them for replica_id in range(self.n_replicas): self._minimize_replica(replica_id, tolerance, max_iterations) - def equilibrate(self, n_iterations, mcmc_moves=None): - """Equilibrate all replicas. + def _equilibration_timings(self, timer, iteration: int, n_iterations: int): + iteration_time = timer.stop("Equilibration Iteration") + partial_total_time = timer.partial("Run Equilibration") + time_per_iteration = partial_total_time / iteration + estimated_time_remaining = time_per_iteration * (n_iterations - iteration) + estimated_total_time = time_per_iteration * n_iterations + estimated_finish_time = time.time() + estimated_time_remaining + # TODO: Transmit timing information + + log.info(f"Iteration took {iteration_time:.3f}s.") + if estimated_time_remaining != float("inf"): + log.info( + "Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format( + str(datetime.timedelta(seconds=estimated_time_remaining)), + time.ctime(estimated_finish_time), + str(datetime.timedelta(seconds=estimated_total_time)), + ) + ) + + def equilibrate( + self, n_iterations: int, mcmc_moves: Optional[List[MCMCMove]] = None + ): + """ + Equilibrate all replicas in the sampler. + + This method equilibrates the system by running a specified number of + MCMC iterations. The equilibration uses either the provided MCMC moves + or the default ones set during initialization. Parameters ---------- n_iterations : int - Number of equilibration iterations. - mcmc_moves : MCMCMove or list of MCMCMove, optional - Optionally, the MCMCMoves to use for equilibration can be - different from the ones used in production. - + The number of equilibration iterations to perform. + mcmc_moves : Optional[List[mcmc.MCMCMove]], optional + A list of MCMCMove objects to use for equilibration. If None, the + MCMC moves used in production will be used. Defaults to None. + + Raises + ------ + RuntimeError + If the simulation has not been created before calling this method. """ # Check that simulation has been created. if self.n_replicas == 0: @@ -362,14 +486,14 @@ def equilibrate(self, n_iterations, mcmc_moves=None): "Cannot equilibrate replicas. The simulation must be created first." ) - # If no MCMCMove is specified, use the ones for production. - if mcmc_moves is None: - mcmc_moves = self._mcmc_moves + # Use production MCMC moves if none are provided + mcmc_moves = mcmc_moves or self._mcmc_moves # Make sure there is one MCMCMove per thermodynamic state. if isinstance(mcmc_moves, MCMCMove): mcmc_moves = [copy.deepcopy(mcmc_moves) for _ in range(self.n_states)] - elif len(mcmc_moves) != self.n_states: + + if len(mcmc_moves) != self.n_states: raise RuntimeError( f"The number of MCMCMoves ({len(self._mcmc_moves)}) and ThermodynamicStates ({self.n_states}) for equilibration must be the same." ) @@ -381,8 +505,9 @@ def equilibrate(self, n_iterations, mcmc_moves=None): # Temporarily set the equilibration MCMCMoves. production_mcmc_moves = self._mcmc_moves self._mcmc_moves = mcmc_moves - for iteration in range(1, 1 + n_iterations): - log.info("Equilibration iteration {iteration}/{n_iterations}") + + for iteration in range(1, n_iterations + 1): + log.info(f"Equilibration iteration {iteration}/{n_iterations}") timer.start("Equilibration Iteration") # NOTE: Unlike run(), do NOT increment iteration counter. @@ -398,23 +523,9 @@ def equilibrate(self, n_iterations, mcmc_moves=None): self._replica_thermodynamic_states = self._mix_replicas() # Computing timing information - iteration_time = timer.stop("Equilibration Iteration") - partial_total_time = timer.partial("Run Equilibration") - time_per_iteration = partial_total_time / iteration - estimated_time_remaining = time_per_iteration * (n_iterations - iteration) - estimated_total_time = time_per_iteration * n_iterations - estimated_finish_time = time.time() + estimated_time_remaining - # TODO: Transmit timing information - - log.info(f"Iteration took {iteration_time:.3f}s.") - if estimated_time_remaining != float("inf"): - log.info( - "Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format( - str(datetime.timedelta(seconds=estimated_time_remaining)), - time.ctime(estimated_finish_time), - str(datetime.timedelta(seconds=estimated_total_time)), - ) - ) + self._equilibration_timings( + timer, iteration=iteration, n_iterations=n_iterations + ) timer.report_timing() # Restore production MCMCMoves. @@ -422,13 +533,29 @@ def equilibrate(self, n_iterations, mcmc_moves=None): # TODO: Update stored positions. - def _propagate_replica(self, replica_id): - """Propagate thermodynamic state associated to the given replica.""" + def _propagate_replica(self, replica_id: int): + """ + Propagate the state of a single replica. + + This method applies the MCMC move to the replica to change its state + according to the specified thermodynamic state. + + Parameters + ---------- + replica_id : int + The index of the replica to propagate. + Raises + ------ + RuntimeError + If an error occurs during the propagation of the replica. + """ # Retrieve thermodynamic, sampler states, and MCMC move of this replica. + # Retrieve thermodynamic and sampler states for the replica thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] + sampler_state = self._sampler_states[replica_id] + thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] mcmc_move = self._mcmc_moves[thermodynamic_state_id] - sampler_state = self._sampler_states[replica_id] # Apply MCMC move. try: @@ -437,6 +564,59 @@ def _propagate_replica(self, replica_id): log.warning(e) raise e + def _perform_swap_proposals(self): + """ + Perform swap proposals between replicas. + + Placeholder method for replica swapping logic. Subclasses should + override this method with specific swapping algorithms. + + Returns + ------- + np.ndarray + An array of updated thermodynamic state indices for each replica. + """ + + # Placeholder implementation, should be overridden by subclasses + # For this example, we'll just return the current state indices + return self._replica_thermodynamic_states + + def _mix_replicas(self) -> np.ndarray: + """ + Propose and execute swaps between replicas. + + This method is responsible for enhancing sampling efficiency by proposing + swaps between different thermodynamic states of the replicas. The actual + swapping algorithm depends on the specific subclass implementation. + + Returns + ------- + np.ndarray + An array of updated thermodynamic state indices for each replica. + """ + + log.debug("Mixing replicas (does nothing for MultiStateSampler)...") + + # Reset storage to keep track of swap attempts this iteration. + self._n_accepted_matrix[:, :] = 0 + self._n_proposed_matrix[:, :] = 0 + + # Perform replica mixing (swap proposals and acceptances) + # The actual swapping logic would depend on subclass implementations + # Here, we assume a placeholder implementation + new_replica_states = self._perform_swap_proposals() + + # Calculate swap acceptance statistics + n_swaps_proposed = self._n_proposed_matrix.sum() + n_swaps_accepted = self._n_accepted_matrix.sum() + swap_fraction_accepted = 0.0 + if n_swaps_proposed > 0: + swap_fraction_accepted = n_swaps_accepted / n_swaps_proposed + log.debug( + f"Accepted {n_swaps_accepted}/{n_swaps_proposed} attempted swaps ({swap_fraction_accepted * 100.0:.1f}%)" + ) + return new_replica_states + @with_timer("Propagating all replicas") def _propagate_replicas(self): """Propagate all replicas.""" @@ -481,24 +661,42 @@ def _compute_energies(self): neighborhood = self._neighborhood(state_index) self._neighborhoods[replica_index, neighborhood] = True - # Distribute energy computation across nodes. Only node 0 receives - # all the energies since it needs to store them and mix states. + # Calculate energies for all replicas. new_energies, replica_ids = [], [] - for i in range(self.n_replicas): - new_energy, replica_id = self._compute_replica_energies(i) + for replica_id in range(self.n_replicas): + new_energy = self._compute_replica_energies(replica_id) new_energies.append(new_energy) replica_ids.append(replica_id) - # Update energy matrices. Non-0 nodes update only the energies computed by this replica. + # Update energy matrices. for replica_id, energies in zip(replica_ids, new_energies): - energy_thermodynamic_states, energy_unsampled_states = energies # Unpack. + energy_thermodynamic_states = energies # Unpack. neighborhood = self._neighborhood( self._replica_thermodynamic_states[replica_id] ) self._energy_thermodynamic_states[ replica_id, neighborhood ] = energy_thermodynamic_states - self._energy_unsampled_states[replica_id] = energy_unsampled_states + + def _is_completed(self, iteration_limit=None): + """Check if we have reached any of the stop target criteria. + + Parameters + ---------- + iteration_limit : int, optional + If specified, the simulation will stop if the iteration counter reaches this value. + + Returns + ------- + is_completed : bool + If True, the simulation is completed and should be terminated. + """ + if iteration_limit is not None and self._iteration >= iteration_limit: + log.info( + f"Reached iteration limit {iteration_limit} (current iteration {self._iteration})" + ) + return True + return False def run(self, n_iterations=None): """Run the replica-exchange simulation. @@ -513,28 +711,21 @@ def run(self, n_iterations=None): """ # If this is the first iteration, compute and store the # starting energies of the minimized/equilibrated structures. + + log.info("Running simulation...") if self._iteration == 0: try: self._compute_energies() - # We're intercepting a possible initial NaN position here thrown by OpenMM, which is a simple exception - # So we have to under-specify this trap. except Exception as e: - if "coordinate is nan" in str(e).lower(): - err_message = "Initial coordinates were NaN! Check your inputs!" - logger.critical(err_message) - raise SimulationNaNError(err_message) - else: - # If not the special case, raise the error normally - raise e - mpiplus.run_single_node( - 0, - self._reporter.write_energies, - self._energy_thermodynamic_states, - self._neighborhoods, - self._energy_unsampled_states, - self._iteration, + log.critical(e) + raise e + + self._reporter.write_energies( + energy_thermodynamic_states=self._energy_thermodynamic_states, + energy_neighborhoods=self._neighborhoods, + energy_unsampled_states=self._energy_unsampled_states, + iteration=self._iteration, ) - self._check_nan_energy() from openmmtools.utils import Timer @@ -555,9 +746,9 @@ def run(self, n_iterations=None): # Increment iteration counter. self._iteration += 1 - logger.info("*" * 80) - logger.info("Iteration {}/{}".format(self._iteration, iteration_limit)) - logger.info("*" * 80) + log.info("-" * 80) + log.info(f"Iteration {self._iteration}/{iteration_limit}") + log.info("-" * 80) timer.start("Iteration") # Update thermodynamic states @@ -572,8 +763,8 @@ def run(self, n_iterations=None): # Write iteration to storage file self._report_iteration() - # Update analysis - self._update_analysis() + # TODO: Update analysis + # self._update_analysis() # Computing and transmitting timing information iteration_time = timer.stop("Iteration") @@ -586,17 +777,152 @@ def run(self, n_iterations=None): ) # Log timing data as info level -- useful for users by default - logger.info( + log.info( "Iteration took {:.3f}s.".format(self._timing_data["iteration_seconds"]) ) if self._timing_data["estimated_time_remaining"] != float("inf"): - logger.info( - "Estimated completion in {}, at {} (consuming total wall clock time {}).".format( - self._timing_data["estimated_time_remaining"], - self._timing_data["estimated_localtime_finish_date"], - self._timing_data["estimated_total_time"], - ) + log.info( + f"Estimated completion in {self._timing_data['estimated_time_remaining']}, at {self._timing_data['estimated_localtime_finish_date']} (consuming total wall clock time {self._timing_data['estimated_total_time']})." ) # Perform sanity checks to see if we should terminate here. self._check_nan_energy() + + @with_timer("Writing iteration information to storage") + def _report_iteration(self): + """Store positions, states, and energies of current iteration.n""" + # Call report_iteration_items for a subclass-friendly function + self._report_iteration_items() + self._reporter.write_timestamp(self._iteration) + self._reporter.write_last_iteration(self._iteration) + + def _report_iteration_items(self): + """ + Sub-function of :func:`_report_iteration` which handles all the actual individual item reporting in a + sub-class friendly way. The final actions of writing timestamp, last-good-iteration, and syncing + should be left to the :func:`_report_iteration` and subclasses should extend this function instead + """ + self._reporter.write_sampler_states(self._sampler_states, self._iteration) + self._reporter.write_replica_thermodynamic_states( + self._replica_thermodynamic_states, self._iteration + ) + self._reporter.write_mcmc_moves( + self._mcmc_moves + ) # MCMCMoves can store internal statistics. + self._reporter.write_energies( + self._energy_thermodynamic_states, + self._neighborhoods, + self._energy_unsampled_states, + self._iteration, + ) + self._reporter.write_mixing_statistics( + self._n_accepted_matrix, self._n_proposed_matrix, self._iteration + ) + + def _update_timing( + self, iteration_time, partial_total_time, run_initial_iteration, iteration_limit + ): + """ + Function that computes and transmits timing information to reporter. + + Parameters + ---------- + iteration_time : float + Time took in the iteration. + partial_total_time : float + Partial total time elapsed. + run_initial_iteration : int + Iteration where to start/resume the simulation. + iteration_limit : int + Hard limit on number of iterations to be run by the sampler. + """ + self._timing_data["iteration_seconds"] = iteration_time + self._timing_data["average_seconds_per_iteration"] = partial_total_time / ( + self._iteration - run_initial_iteration + ) + estimated_timedelta_remaining = datetime.timedelta( + seconds=self._timing_data["average_seconds_per_iteration"] + * (iteration_limit - self._iteration) + ) + estimated_finish_date = datetime.datetime.now() + estimated_timedelta_remaining + self._timing_data["estimated_time_remaining"] = str( + estimated_timedelta_remaining + ) # Putting it in dict as str + self._timing_data[ + "estimated_localtime_finish_date" + ] = estimated_finish_date.strftime("%Y-%b-%d-%H:%M:%S") + total_time_in_seconds = datetime.timedelta( + seconds=self._timing_data["average_seconds_per_iteration"] * iteration_limit + ) + self._timing_data["estimated_total_time"] = str(total_time_in_seconds) + + # Estimate performance + moves_iterator = self._flatten_moves_iterator() + # Only consider "dynamic" moves (timestep and n_steps attributes) + moves_times = [ + move.timestep.value_in_unit(unit.nanosecond) * move.n_steps + for move in moves_iterator + if hasattr(move, "timestep") and hasattr(move, "n_steps") + ] + iteration_simulated_nanoseconds = sum(moves_times) + seconds_in_a_day = (1 * unit.day).value_in_unit(unit.seconds) + self._timing_data["ns_per_day"] = iteration_simulated_nanoseconds / ( + self._timing_data["average_seconds_per_iteration"] / seconds_in_a_day + ) + + def _flatten_moves_iterator(self): + """Recursively flatten MCMC moves. Handles the cases where each move can be a set of moves, for example with + SequenceMove or WeightedMove objects.""" + + def flatten(iterator): + try: + yield from [ + inner_move for move in iterator for inner_move in flatten(move) + ] + except TypeError: # Inner object is not iterable, finish flattening. + yield iterator + + return flatten(self.mcmc_moves) + + def _check_nan_energy(self): + """Checks that energies are finite and abort otherwise. + + Checks both sampled and unsampled thermodynamic states. + + """ + # Find faulty replicas to create error message. + nan_replicas = [] + + # Check sampled thermodynamic states first. + state_type = "thermodynamic state" + for replica_id, state_id in enumerate(self._replica_thermodynamic_states): + neighborhood = self._neighborhood(state_id) + energies_neighborhood = self._energy_thermodynamic_states[ + replica_id, neighborhood + ] + if np.any(np.isnan(energies_neighborhood)): + nan_replicas.append((replica_id, energies_neighborhood)) + + # If there are no NaNs in energies, look for NaNs in the unsampled states energies. + if (len(nan_replicas) == 0) and (self._energy_unsampled_states.shape[1] > 0): + state_type = "unsampled thermodynamic state" + for replica_id in range(self.n_replicas): + if np.any(np.isnan(self._energy_unsampled_states[replica_id])): + nan_replicas.append( + (replica_id, self._energy_unsampled_states[replica_id]) + ) + + # Raise exception if we have found some NaN energies. + if len(nan_replicas) > 0: + # Log failed replica, its thermo state, and the energy matrix row. + err_msg = "NaN encountered in {} energies for the following replicas and states".format( + state_type + ) + for replica_id, energy_row in nan_replicas: + err_msg += "\n\tEnergies for positions at replica {} (current state {}): {} kT".format( + replica_id, + self._replica_thermodynamic_states[replica_id], + energy_row, + ) + log.critical(err_msg) + raise RuntimeError(err_msg) From ba286c6b4c086a364cc4b6c62cc547f4ef406e05 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:02:47 +0100 Subject: [PATCH 17/55] Add MultiStateReporter and test functions for equilibration and running --- chiron/tests/test_multistate.py | 45 +++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 4125eeb..50c460d 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -55,11 +55,17 @@ def ho_multistate_sampler() -> MultiStateSampler: ) move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) + + from openmmtools.multistate import MultiStateReporter + + reporter = MultiStateReporter("test.nc") + multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, nbr_list=nbr_list, + reporter=reporter, ) return multistate_sampler @@ -91,5 +97,40 @@ def test_multistate_minimize(ho_multistate_sampler): assert np.allclose( ho_multistate_sampler.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) ) - assert np.allclose(ho_multistate_sampler.sampler_states[1].x0, np.array([[0.05, 0.0, 0.0]]), atol=1e-2) - assert np.allclose(ho_multistate_sampler.sampler_states[2].x0, np.array([[0.1, 0.0, 0.0]]), atol=1e-2) + assert np.allclose( + ho_multistate_sampler.sampler_states[1].x0, + np.array([[0.05, 0.0, 0.0]]), + atol=1e-2, + ) + assert np.allclose( + ho_multistate_sampler.sampler_states[2].x0, + np.array([[0.1, 0.0, 0.0]]), + atol=1e-2, + ) + + +def test_multistate_equilibration(ho_multistate_sampler): + import numpy as np + + ho_multistate_sampler.equilibrate(10) + + assert np.allclose( + ho_multistate_sampler._replica_thermodynamic_states, np.array([0, 1, 2]) + ) + assert np.allclose( + ho_multistate_sampler._energy_thermodynamic_states, + np.array( + [ + [4.81132936, 3.84872651, 3.10585403], + [6.54490519, 5.0176239, 3.85019779], + [9.48260307, 7.07196712, 5.21255827], + ] + ), + ) + + +def test_multistate_run(ho_multistate_sampler): + import numpy as np + + ho_multistate_sampler.equilibrate(10) + ho_multistate_sampler.run(10) From f38665d0df24b3f9a8b0caaa06045522baa4194a Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:05:16 +0100 Subject: [PATCH 18/55] Refactor _propagate_replicas method to iterate over replica_id --- chiron/multistate.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index fcf1663..d2cbb85 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -618,13 +618,18 @@ def _mix_replicas(self) -> np.ndarray: return new_replica_states @with_timer("Propagating all replicas") - def _propagate_replicas(self): - """Propagate all replicas.""" + def _propagate_replicas(self) -> None: + """ + Propagate all replicas through their respective MCMC moves. + + This method iterates over all replicas and applies the corresponding MCMC move + to each one, based on its current thermodynamic state. + """ log.debug("Propagating all replicas...") - for i in range(self.n_replicas): - self._propagate_replica(i) + for replica_id in range(self.n_replicas): + self._propagate_replica(replica_id) def _neighborhood(self, state_index): """Compute the states in the local neighborhood determined by self.locality From 078e7c649367f5770137c7a951ed97d150af455d Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:07:29 +0100 Subject: [PATCH 19/55] Refactor _neighborhood method to calculate the neighborhood of states --- chiron/multistate.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index d2cbb85..f980e80 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -631,28 +631,35 @@ def _propagate_replicas(self) -> None: for replica_id in range(self.n_replicas): self._propagate_replica(replica_id) - def _neighborhood(self, state_index): - """Compute the states in the local neighborhood determined by self.locality + def _neighborhood(self, state_index: int) -> List[int]: + """ + Compute the indices of neighboring states for a given state. + + This method determines the neighborhood of states around a given state index, + considering the 'locality' parameter. If 'locality' is None, the neighborhood + includes all states; otherwise, it includes states within the 'locality' range. Parameters ---------- state_index : int - The current state + The index of the state for which the neighborhood is to be calculated. Returns ------- - neighborhood : list of int - The states in the local neighborhood + List[int] + A list of state indices that are considered neighbors of the given state. """ if self.locality is None: # Global neighborhood - return list(range(0, self.n_states)) + return list(range(self.n_states)) else: # Local neighborhood specified by 'locality' + lower_bound = max(0, state_index - self.locality) + upper_bound = min(self.n_states, state_index + self.locality + 1) return list( range( - max(0, state_index - self.locality), - min(self.n_states, state_index + self.locality + 1), + lower_bound, + upper_bound, ) ) From 617a4360e6dca4abedb4d6434b4937ca6d949075 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:33:34 +0100 Subject: [PATCH 20/55] Refactor energy computation in MultiStateSampler --- chiron/multistate.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index f980e80..23ed9f1 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -664,31 +664,34 @@ def _neighborhood(self, state_index: int) -> List[int]: ) @with_timer("Computing energy matrix") - def _compute_energies(self): - """Compute energies of all replicas at all states.""" + def _compute_energies(self) -> None: + """ + Compute the energies of all replicas at all thermodynamic states. - # Determine neighborhoods (all nodes) - self._neighborhoods[:, :] = False - for replica_index, state_index in enumerate(self._replica_thermodynamic_states): - neighborhood = self._neighborhood(state_index) - self._neighborhoods[replica_index, neighborhood] = True + This method calculates the energy for each replica in every thermodynamic state, + considering the defined neighborhoods to optimize the computation. The energies + are stored in the internal energy matrix of the sampler. + """ - # Calculate energies for all replicas. - new_energies, replica_ids = [], [] - for replica_id in range(self.n_replicas): - new_energy = self._compute_replica_energies(replica_id) - new_energies.append(new_energy) - replica_ids.append(replica_id) + log.debug("Computing energy matrix for all replicas...") + # Initialize the energy matrix and neighborhoods + self._energy_thermodynamic_states = np.zeros((self.n_replicas, self.n_states)) + self._neighborhoods = np.zeros((self.n_replicas, self.n_states), dtype=bool) - # Update energy matrices. - for replica_id, energies in zip(replica_ids, new_energies): - energy_thermodynamic_states = energies # Unpack. + # Calculate energies for each replica + for replica_id in range(self.n_replicas): neighborhood = self._neighborhood( self._replica_thermodynamic_states[replica_id] ) + self._neighborhoods[replica_id, neighborhood] = True + + # Compute and store energies for the neighborhood states self._energy_thermodynamic_states[ replica_id, neighborhood - ] = energy_thermodynamic_states + ] = self._compute_replica_energies(replica_id) + + log.debug(self._energy_thermodynamic_states) + log.debug(self._neighborhoods) def _is_completed(self, iteration_limit=None): """Check if we have reached any of the stop target criteria. From 56e91ee2e7da76e4549dce509dda45ef5c8c326a Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:35:40 +0100 Subject: [PATCH 21/55] Refactor _is_completed method in MultiStateSampler class --- chiron/multistate.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 23ed9f1..d7c401a 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -693,24 +693,35 @@ def _compute_energies(self) -> None: log.debug(self._energy_thermodynamic_states) log.debug(self._neighborhoods) - def _is_completed(self, iteration_limit=None): - """Check if we have reached any of the stop target criteria. + def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: + """ + Determine if the sampling process has met its completion criteria. + + This method checks if the simulation has reached a specified iteration limit + or any other predefined stopping condition. Parameters ---------- - iteration_limit : int, optional - If specified, the simulation will stop if the iteration counter reaches this value. + iteration_limit : Optional[int], default=None + An optional iteration limit. If specified, the method checks if the + current iteration number has reached this limit. Returns ------- - is_completed : bool - If True, the simulation is completed and should be terminated. + bool + True if the simulation has completed based on the stopping criteria, + False otherwise. """ + + # Check if iteration limit has been reached if iteration_limit is not None and self._iteration >= iteration_limit: log.info( f"Reached iteration limit {iteration_limit} (current iteration {self._iteration})" ) return True + + # Additional stopping criteria can be implemented here + return False def run(self, n_iterations=None): From 7046330b4d932250035062418fcc8b54d55a816f Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 16:45:10 +0100 Subject: [PATCH 22/55] Refactor run method and add progress updates --- chiron/multistate.py | 90 +++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index d7c401a..4a071d8 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -724,27 +724,57 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: return False - def run(self, n_iterations=None): - """Run the replica-exchange simulation. + def _update_run_progress(self, timer, run_initial_iteration, iteration_limit): + # Computing and transmitting timing information + iteration_time = timer.stop("Iteration") + partial_total_time = timer.partial("Run ReplicaExchange") + self._update_timing( + iteration_time, + partial_total_time, + run_initial_iteration, + iteration_limit, + ) + + # Log timing data as info level -- useful for users by default + log.info( + "Iteration took {:.3f}s.".format(self._timing_data["iteration_seconds"]) + ) + if self._timing_data["estimated_time_remaining"] != float("inf"): + log.info( + f"Estimated completion in {self._timing_data['estimated_time_remaining']}, at {self._timing_data['estimated_localtime_finish_date']} (consuming total wall clock time {self._timing_data['estimated_total_time']})." + ) + + # Perform sanity checks to see if we should terminate here. + self._check_nan_energy() - This runs at most ``number_of_iterations`` iterations. + def run(self, n_iterations: Optional[int] = None) -> None: + """ + Execute the replica-exchange simulation. + + Run the simulation for a specified number of iterations. If no number is + specified, it runs for the number of iterations set during the initialization + of the sampler. Parameters ---------- - n_iterations : int, optional - If specified, only at most the specified number of iterations - will be run (default is None). + n_iterations : Optional[int], default=None + The number of iterations to run. If None, the sampler runs for the + number of iterations specified at initialization. + + Raises + ------ + RuntimeError + If an error occurs during the computation of energies. """ + # If this is the first iteration, compute and store the # starting energies of the minimized/equilibrated structures. log.info("Running simulation...") + + # Initialize energies if this is the first iteration if self._iteration == 0: - try: - self._compute_energies() - except Exception as e: - log.critical(e) - raise e + self._compute_energies() self._reporter.write_energies( energy_thermodynamic_states=self._energy_thermodynamic_states, @@ -759,13 +789,12 @@ def run(self, n_iterations=None): timer.start("Run ReplicaExchange") run_initial_iteration = self._iteration - # Handle default argument and determine number of iterations to run. - if n_iterations is None: - iteration_limit = self.number_of_iterations - else: - iteration_limit = min( - self._iteration + n_iterations, self.number_of_iterations - ) + # Determine the number of iterations to run before stopping. + iteration_limit = ( + min(self._iteration + n_iterations, self.number_of_iterations) + if n_iterations is not None + else self.number_of_iterations + ) # Main loop. while not self._is_completed(iteration_limit): @@ -792,27 +821,12 @@ def run(self, n_iterations=None): # TODO: Update analysis # self._update_analysis() - # Computing and transmitting timing information - iteration_time = timer.stop("Iteration") - partial_total_time = timer.partial("Run ReplicaExchange") - self._update_timing( - iteration_time, - partial_total_time, - run_initial_iteration, - iteration_limit, - ) - - # Log timing data as info level -- useful for users by default - log.info( - "Iteration took {:.3f}s.".format(self._timing_data["iteration_seconds"]) + # Update timing and progress information + self._update_run_progress( + timer=timer, + run_initial_iteration=run_initial_iteration, + iteration_limit=iteration_limit, ) - if self._timing_data["estimated_time_remaining"] != float("inf"): - log.info( - f"Estimated completion in {self._timing_data['estimated_time_remaining']}, at {self._timing_data['estimated_localtime_finish_date']} (consuming total wall clock time {self._timing_data['estimated_total_time']})." - ) - - # Perform sanity checks to see if we should terminate here. - self._check_nan_energy() @with_timer("Writing iteration information to storage") def _report_iteration(self): From c610b94aeae555d707200641b264309bb2e824db Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 2 Jan 2024 18:13:47 +0100 Subject: [PATCH 23/55] Add test cases for multistate equilibration and run --- chiron/tests/test_multistate.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 50c460d..8499e31 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -128,9 +128,22 @@ def test_multistate_equilibration(ho_multistate_sampler): ), ) + a = 7 + def test_multistate_run(ho_multistate_sampler): import numpy as np ho_multistate_sampler.equilibrate(10) + assert np.allclose( + ho_multistate_sampler._energy_thermodynamic_states, + np.array( + [ + [4.81132936, 3.84872651, 3.10585403], + [6.54490519, 5.0176239, 3.85019779], + [9.48260307, 7.07196712, 5.21255827], + ] + ), + ) ho_multistate_sampler.run(10) + From 9afc05bb0ab638f2745daf7e9639c54ef7289b1f Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 3 Jan 2024 17:12:19 +0100 Subject: [PATCH 24/55] Add mbar analysis functionality to MultiStateSampler class --- chiron/multistate.py | 70 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 4a071d8..d7606a4 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -48,7 +48,13 @@ class MultiStateSampler(object): is_completed """ - def __init__(self, mcmc_moves=None, number_of_iterations=1, locality=None): + def __init__( + self, + mcmc_moves=None, + number_of_iterations=1, + locality=None, + online_analysis_interval=5, + ): # These will be set on initialization. See function # create() for explanation of single variables. self._thermodynamic_states = None @@ -63,6 +69,7 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1, locality=None): self._n_proposed_matrix = None self._reporter = None self._metadata = None + self._online_analysis_interval = online_analysis_interval self._timing_data = dict() # Handling default propagator. @@ -818,8 +825,8 @@ def run(self, n_iterations: Optional[int] = None) -> None: # Write iteration to storage file self._report_iteration() - # TODO: Update analysis - # self._update_analysis() + # Update analysis + self._update_analysis() # Update timing and progress information self._update_run_progress( @@ -966,3 +973,60 @@ def _check_nan_energy(self): ) log.critical(err_msg) raise RuntimeError(err_msg) + + def _update_analysis(self): + """Update analysis of free energies""" + + if self._online_analysis_interval is None: + log.debug("No online analysis requested") + # Perform no analysis and exit function + return + + # Always perform fast online analysis + if self.free_energy_estimator == "mbar": + self._last_err_free_energy = self._mbar_analysis() + + return + + def _mbar_analysis(self): + """ + Perform online analysis of the simulation. + + This method performs online analysis of the simulation, including + the calculation of free energies and other thermodynamic properties. + """ + import jax.numpy as jnp + from jax.scipy.special import logsumexp + + gamma = 1.0 / self._iteration + 1 + + self._last_mbar_f_k = np.zeros([self.n_states], np.float64) + + logZ = -self._last_mbar_f_k + + for replica_index, state_index in enumerate(self._replica_thermodynamic_states): + neighborhood = self._neighborhood(state_index) + u_k = self._energy_thermodynamic_states[replica_index, :] + log_P_k = np.zeros([self.n_states], np.float64) + log_pi_k = np.zeros([self.n_states], np.float64) + log_weights = np.zeros([self.n_states], np.float64) + log_P_k[neighborhood] = log_weights[neighborhood] - u_k[neighborhood] + log_P_k[neighborhood] -= logsumexp(log_P_k[neighborhood]) + logZ[neighborhood] += gamma * np.exp( + log_P_k[neighborhood] - log_pi_k[neighborhood] + ) + + # Subtract off logZ[0] to prevent logZ from growing without bound + logZ[:] -= logZ[0] + self._last_mbar_f_k = -logZ + free_energy = self._last_mbar_f_k[-1] - self._last_mbar_f_k[0] + self._last_err_free_energy = np.Inf + + # Report online analysis to debug log + log.debug("*** MBAR analysis free energies:") + msg = " " + for x in self._last_mbar_f_k: + msg += "%8.1f" % x + log.debug(msg) + log.debug(free_energy) + return self._last_err_free_energy From 67999732c7731d68bb8f0cbe6e2d24611923c433 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 3 Jan 2024 19:34:48 +0100 Subject: [PATCH 25/55] Add test_multistate_sampler_single_sampler_state fixture --- chiron/tests/test_multistate.py | 87 ++++++++++++++++++++++++++++----- 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 8499e31..d20b37e 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -128,22 +128,83 @@ def test_multistate_equilibration(ho_multistate_sampler): ), ) - a = 7 +@pytest.fixture +def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: + """ + Create a multi-state sampler for a harmonic oscillator system. -def test_multistate_run(ho_multistate_sampler): + Returns: + MultiStateSampler: The multi-state sampler object. + """ + from openmm import unit + from chiron.mcmc import LangevinDynamicsMove + from chiron.states import ThermodynamicState, SamplerState + from openmmtools.testsystems import HarmonicOscillator + from chiron.potential import HarmonicOscillatorPotential + from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + + ho = HarmonicOscillator() + n_states = 4 + + T = 300.0 * unit.kelvin # Minimum temperature. + kT = unit.BOLTZMANN_CONSTANT_kB * T * unit.AVOGADRO_CONSTANT_NA + sigmas = [ + unit.Quantity(1.0 + 0.2 * state_index, unit.angstrom) + for state_index in range(n_states) + ] + Ks = [kT / sigma ** 2 for sigma in sigmas] + thermodynamic_states = [ + ThermodynamicState( + HarmonicOscillatorPotential(ho.topology, k=k), temperature=T + ) + for k in Ks + ] + sampler_state = [SamplerState(ho.positions) for _ in sigmas] import numpy as np - ho_multistate_sampler.equilibrate(10) - assert np.allclose( - ho_multistate_sampler._energy_thermodynamic_states, - np.array( - [ - [4.81132936, 3.84872651, 3.10585403], - [6.54490519, 5.0176239, 3.85019779], - [9.48260307, 7.07196712, 5.21255827], - ] - ), + f_i = np.array( + [ + -np.log(2 * np.pi * (sigma / unit.angstroms) ** 2) * (3.0 / 2.0) + for sigma in sigmas + ] + ) + + # Initialize simulation object with options. Run with a langevin integrator. + # initialize the LennardJones potential in chiron + # + sigma = 0.34 * unit.nanometer + cutoff = 3.0 * sigma + skin = 0.5 * unit.nanometer + + nbr_list = NeighborListNsqrd( + OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=10 + ) + + move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) + + from openmmtools.multistate import MultiStateReporter + + reporter = MultiStateReporter("test.nc") + + multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=10) + multistate_sampler.create( + thermodynamic_states=thermodynamic_states, + sampler_states=sampler_state, + nbr_list=nbr_list, + reporter=reporter, ) - ho_multistate_sampler.run(10) + multistate_sampler.analytical_f_i = f_i + multistate_sampler.delta_f_ij_analytical = f_i - f_i[:, np.newaxis] + return multistate_sampler + +def test_multistate_run(ho_multistate_sampler_single_sampler_state): + ho_sampler = ho_multistate_sampler_single_sampler_state + import numpy as np + + ho_sampler.equilibrate(10) + ho_sampler.run(200) + print(ho_sampler.analytical_f_i) + print(ho_sampler.delta_f_ij_analytical) + a = 7 From 9cb9b617da125171e41e1400db13725a3082752c Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 3 Jan 2024 19:34:59 +0100 Subject: [PATCH 26/55] Add free energy estimator and update gamma value --- chiron/multistate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chiron/multistate.py b/chiron/multistate.py index d7606a4..503200e 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -71,6 +71,7 @@ def __init__( self._metadata = None self._online_analysis_interval = online_analysis_interval self._timing_data = dict() + self.free_energy_estimator = None # Handling default propagator. if mcmc_moves is None: @@ -240,6 +241,7 @@ def create( """ # TODO: initialize reporter here # TODO: consider unsampled thermodynamic states for reweighting schemes + self.free_energy_estimator = "mbar" # Ensure the number of thermodynamic states matches the number of sampler states if len(thermodynamic_states) != len(sampler_states): @@ -999,6 +1001,7 @@ def _mbar_analysis(self): from jax.scipy.special import logsumexp gamma = 1.0 / self._iteration + 1 + gamma = 1.0 self._last_mbar_f_k = np.zeros([self.n_states], np.float64) From 5cc16269a32489250be106b1888fa23e89cb042d Mon Sep 17 00:00:00 2001 From: wiederm Date: Mon, 8 Jan 2024 15:30:02 +0100 Subject: [PATCH 27/55] Refactor code and remove debug statements --- chiron/multistate.py | 139 +++++++++++++------------------- chiron/reporters.py | 4 +- chiron/states.py | 4 +- chiron/tests/test_multistate.py | 25 ++++-- 4 files changed, 75 insertions(+), 97 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 503200e..20f84fe 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -5,7 +5,7 @@ import datetime from loguru import logger as log import numpy as np -from openmmtools.utils import time_it, with_timer +from openmmtools.utils import with_timer from chiron.neighbors import NeighborListNsqrd from openmm import unit from chiron.mcmc import MCMCMove @@ -30,8 +30,6 @@ class MultiStateSampler(object): they will be assigned to the correspondent thermodynamic state on creation. If None is provided, Langevin dynamics with 2fm timestep, 5.0/ps collision rate, and 500 steps per iteration will be used. - number_of_iterations : int or infinity, optional, default: 1 - The number of iterations to perform. locality : int > 0, optional, default None If None, the energies at all states will be computed for every replica each iteration. @@ -41,7 +39,6 @@ class MultiStateSampler(object): ---------- n_replicas n_states - iteration mcmc_moves sampler_states metadata @@ -51,7 +48,6 @@ class MultiStateSampler(object): def __init__( self, mcmc_moves=None, - number_of_iterations=1, locality=None, online_analysis_interval=5, ): @@ -63,6 +59,7 @@ def __init__( self._replica_thermodynamic_states = None self._iteration = None self._energy_thermodynamic_states = None + self._energy_thermodynamic_states_for_each_iteration = None self._neighborhoods = None self._energy_unsampled_states = None self._n_accepted_matrix = None @@ -86,10 +83,6 @@ def __init__( else: self._mcmc_moves = copy.deepcopy(mcmc_moves) - # Store constructor parameters. Everything is marked for internal - # usage because any change to these attribute implies a change - # in the storage file as well. Use properties for checks. - self.number_of_iterations = number_of_iterations self._last_mbar_f_k = None self._last_err_free_energy = None @@ -699,9 +692,6 @@ def _compute_energies(self) -> None: replica_id, neighborhood ] = self._compute_replica_energies(replica_id) - log.debug(self._energy_thermodynamic_states) - log.debug(self._neighborhoods) - def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: """ Determine if the sampling process has met its completion criteria. @@ -756,7 +746,7 @@ def _update_run_progress(self, timer, run_initial_iteration, iteration_limit): # Perform sanity checks to see if we should terminate here. self._check_nan_energy() - def run(self, n_iterations: Optional[int] = None) -> None: + def run(self, n_iterations: int = 10) -> None: """ Execute the replica-exchange simulation. @@ -766,9 +756,8 @@ def run(self, n_iterations: Optional[int] = None) -> None: Parameters ---------- - n_iterations : Optional[int], default=None - The number of iterations to run. If None, the sampler runs for the - number of iterations specified at initialization. + n_iterations : int, default=10 + The number of iterations to run. Raises ------ @@ -778,34 +767,31 @@ def run(self, n_iterations: Optional[int] = None) -> None: # If this is the first iteration, compute and store the # starting energies of the minimized/equilibrated structures. + self.number_of_iterations = n_iterations log.info("Running simulation...") + self._energy_thermodynamic_states_for_each_iteration_in_run = np.zeros( + [self.n_replicas, self.n_states, n_iterations + 1], np.float64 + ) # Initialize energies if this is the first iteration if self._iteration == 0: self._compute_energies() - - self._reporter.write_energies( - energy_thermodynamic_states=self._energy_thermodynamic_states, - energy_neighborhoods=self._neighborhoods, - energy_unsampled_states=self._energy_unsampled_states, - iteration=self._iteration, - ) + # store energies for mbar analysis + self._energy_thermodynamic_states_for_each_iteration_in_run[ + :, :, self._iteration + ] = self._energy_thermodynamic_states + # TODO report energies from openmmtools.utils import Timer timer = Timer() timer.start("Run ReplicaExchange") - run_initial_iteration = self._iteration - # Determine the number of iterations to run before stopping. - iteration_limit = ( - min(self._iteration + n_iterations, self.number_of_iterations) - if n_iterations is not None - else self.number_of_iterations - ) + iteration_limit = n_iterations - # Main loop. + # start the sampling loop + log.debug(f"{iteration_limit=}") while not self._is_completed(iteration_limit): # Increment iteration counter. self._iteration += 1 @@ -824,26 +810,35 @@ def run(self, n_iterations: Optional[int] = None) -> None: # Compute energies of all replicas at all states self._compute_energies() + # Add energies to the energy matrix + self._energy_thermodynamic_states_for_each_iteration_in_run[ + :, :, self._iteration + ] = self._energy_thermodynamic_states + log.info( + self._energy_thermodynamic_states_for_each_iteration_in_run[:, :, 1] + ) + log.info(self._energy_thermodynamic_states) # Write iteration to storage file - self._report_iteration() + # TODO + # self._report_iteration() # Update analysis self._update_analysis() - # Update timing and progress information - self._update_run_progress( - timer=timer, - run_initial_iteration=run_initial_iteration, - iteration_limit=iteration_limit, - ) - - @with_timer("Writing iteration information to storage") def _report_iteration(self): - """Store positions, states, and energies of current iteration.n""" - # Call report_iteration_items for a subclass-friendly function - self._report_iteration_items() - self._reporter.write_timestamp(self._iteration) - self._reporter.write_last_iteration(self._iteration) + """Store positions, states, and energies of current iteration.""" + + # TODO: write energies + + # TODO: write trajectory + + # TODO: write mixing statistics + self._reporter.write_energies( + self._energy_thermodynamic_states, + self._neighborhoods, + self._energy_unsampled_states, + self._iteration, + ) def _report_iteration_items(self): """ @@ -984,7 +979,7 @@ def _update_analysis(self): # Perform no analysis and exit function return - # Always perform fast online analysis + # Perform offline free energy estimate if requested if self.free_energy_estimator == "mbar": self._last_err_free_energy = self._mbar_analysis() @@ -992,44 +987,20 @@ def _update_analysis(self): def _mbar_analysis(self): """ - Perform online analysis of the simulation. - - This method performs online analysis of the simulation, including - the calculation of free energies and other thermodynamic properties. + Perform mbar analysis """ - import jax.numpy as jnp - from jax.scipy.special import logsumexp + from pymbar import MBAR - gamma = 1.0 / self._iteration + 1 - gamma = 1.0 + self._last_mbar_f_k_offline = np.zeros(len(self._thermodynamic_states)) - self._last_mbar_f_k = np.zeros([self.n_states], np.float64) - - logZ = -self._last_mbar_f_k - - for replica_index, state_index in enumerate(self._replica_thermodynamic_states): - neighborhood = self._neighborhood(state_index) - u_k = self._energy_thermodynamic_states[replica_index, :] - log_P_k = np.zeros([self.n_states], np.float64) - log_pi_k = np.zeros([self.n_states], np.float64) - log_weights = np.zeros([self.n_states], np.float64) - log_P_k[neighborhood] = log_weights[neighborhood] - u_k[neighborhood] - log_P_k[neighborhood] -= logsumexp(log_P_k[neighborhood]) - logZ[neighborhood] += gamma * np.exp( - log_P_k[neighborhood] - log_pi_k[neighborhood] - ) - - # Subtract off logZ[0] to prevent logZ from growing without bound - logZ[:] -= logZ[0] - self._last_mbar_f_k = -logZ - free_energy = self._last_mbar_f_k[-1] - self._last_mbar_f_k[0] - self._last_err_free_energy = np.Inf - - # Report online analysis to debug log - log.debug("*** MBAR analysis free energies:") - msg = " " - for x in self._last_mbar_f_k: - msg += "%8.1f" % x - log.debug(msg) - log.debug(free_energy) - return self._last_err_free_energy + log.debug( + f"{self._energy_thermodynamic_states_for_each_iteration_in_run.shape=}" + ) + log.debug(f"{self.n_states=}") + u_kn = self._energy_thermodynamic_states_for_each_iteration_in_run + log.debug(f"{self._iteration=}") + N_k = [self._iteration] * self.n_states + log.debug(f"{N_k=}") + mbar = MBAR(u_kn=u_kn, N_k=N_k) + log.debug(mbar.f_k) + self._last_mbar_f_k_offline = mbar.f_k diff --git a/chiron/reporters.py b/chiron/reporters.py index 0ec29b4..b2fae35 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -21,8 +21,6 @@ def __init__(self, filename: str, topology: Topology, buffer_size: int = 1): Number of data points to buffer before writing to disk (default is 1). """ - import mdtraj as md - self.filename = filename self.buffer_size = buffer_size self.topology = topology @@ -52,7 +50,7 @@ def report(self, data_dict): if len(self.buffer[key]) >= self.buffer_size: self._write_to_disk(key) - def _write_to_disk(self, key): + def _write_to_disk(self, key:str): """ Write buffered data of a given key to the HDF5 file. diff --git a/chiron/states.py b/chiron/states.py index 3a71dbf..10f07f4 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -274,14 +274,14 @@ def get_reduced_potential( self.beta = 1.0 / ( unit.BOLTZMANN_CONSTANT_kB * (self.temperature * unit.kelvin) ) - log.debug(f"sample state: {sampler_state.x0}") + # log.debug(f"sample state: {sampler_state.x0}") reduced_potential = ( unit.Quantity( self.potential.compute_energy(sampler_state.x0, nbr_list), unit.kilojoule_per_mole, ) ) / unit.AVOGADRO_CONSTANT_NA - log.debug(f"reduced potential: {reduced_potential}") + # log.debug(f"reduced potential: {reduced_potential}") if self.pressure is not None: reduced_potential += self.pressure * self.volume diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index d20b37e..ba02528 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -146,18 +146,16 @@ def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: ho = HarmonicOscillator() n_states = 4 - + T = 300.0 * unit.kelvin # Minimum temperature. kT = unit.BOLTZMANN_CONSTANT_kB * T * unit.AVOGADRO_CONSTANT_NA sigmas = [ unit.Quantity(1.0 + 0.2 * state_index, unit.angstrom) for state_index in range(n_states) ] - Ks = [kT / sigma ** 2 for sigma in sigmas] + Ks = [kT / sigma**2 for sigma in sigmas] thermodynamic_states = [ - ThermodynamicState( - HarmonicOscillatorPotential(ho.topology, k=k), temperature=T - ) + ThermodynamicState(HarmonicOscillatorPotential(ho.topology, k=k), temperature=T) for k in Ks ] sampler_state = [SamplerState(ho.positions) for _ in sigmas] @@ -187,7 +185,9 @@ def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: reporter = MultiStateReporter("test.nc") - multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=10) + multistate_sampler = MultiStateSampler( + mcmc_moves=move, + ) multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, @@ -203,8 +203,17 @@ def test_multistate_run(ho_multistate_sampler_single_sampler_state): ho_sampler = ho_multistate_sampler_single_sampler_state import numpy as np - ho_sampler.equilibrate(10) - ho_sampler.run(200) + n_iteratinos = 100 + ho_sampler.run(n_iteratinos) + + # check that we have the correct number of iterations, replicas and states + assert ho_sampler.iteration == n_iteratinos + assert ho_sampler._iteration == n_iteratinos + assert ho_sampler.n_replicas == 4 + assert ho_sampler.n_states == 4 + + # check that the free energies are correct print(ho_sampler.analytical_f_i) print(ho_sampler.delta_f_ij_analytical) + print(ho_sampler._last_mbar_f_k_offline) a = 7 From 8431899f158a615b763fa17a1568b4658893ee05 Mon Sep 17 00:00:00 2001 From: wiederm Date: Mon, 8 Jan 2024 17:47:34 +0100 Subject: [PATCH 28/55] Refactor ThermodynamicState class and calculate_reduced_potential_at_states function --- chiron/multistate.py | 267 +++---------------------------------------- chiron/states.py | 37 ++---- 2 files changed, 22 insertions(+), 282 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 20f84fe..f4b2d3b 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -165,45 +165,16 @@ def _compute_replica_energies(self, replica_id: int) -> np.ndarray: import jax.numpy as jnp from chiron.states import calculate_reduced_potential_at_states - log.debug(f"{self._replica_thermodynamic_states=}") - - # Determine neighborhood - state_index = self._replica_thermodynamic_states[replica_id] - neighborhood = self._neighborhood(state_index) - log.debug(f"{neighborhood=}") # Only compute energies of the sampled states over neighborhoods. - energy_neighborhood_states = np.zeros(len(neighborhood)) - neighborhood_thermodynamic_states = [ - self._thermodynamic_states[n] for n in neighborhood + thermodynamic_states = [ + self._thermodynamic_states[n] for n in range(self.n_states) ] - # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] - log.debug(f"{sampler_state=}") # Compute energy for all thermodynamic states. - from openmmtools.states import group_by_compatibility - - for energies, the_states in [ - (energy_neighborhood_states, neighborhood_thermodynamic_states), - ]: - # Group thermodynamic states by compatibility. - compatible_groups, original_indices = group_by_compatibility(the_states) - - # Compute the reduced potentials of all the compatible states. - for compatible_group, state_indices in zip( - compatible_groups, original_indices - ): - # Compute and update the reduced potentials. - compatible_energies = calculate_reduced_potential_at_states( - sampler_state, compatible_group, self.nbr_list - ) - for energy_idx, state_idx in enumerate(state_indices): - energies[state_idx] = compatible_energies[energy_idx] - - # Return the new energies. - log.info(f"Computed energies for replica {replica_id}") - log.info(f"{energy_neighborhood_states=}") - return energy_neighborhood_states + return calculate_reduced_potential_at_states( + sampler_state, thermodynamic_states, self.nbr_list + ) def create( self, @@ -247,28 +218,10 @@ def create( self._reporter = reporter self._reporter.open(mode="a") - @classmethod - def _default_initial_thermodynamic_states( - cls, - thermodynamic_states: List[ThermodynamicState], - sampler_states: List[SamplerState], - ): - """ - Create the initial_thermodynamic_states obeying the following rules: - - * ``len(thermodynamic_states) == len(sampler_states)``: 1-to-1 distribution - """ - n_thermo = len(thermodynamic_states) - n_sampler = len(sampler_states) - assert n_thermo == n_sampler, "Must have 1-to-1 distribution of states" - initial_thermo_states = np.arange(n_thermo, dtype=int) - return initial_thermo_states - def _allocate_variables( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], - unsampled_thermodynamic_states: Optional[List[ThermodynamicState]] = None, ) -> None: """ Allocate and initialize internal variables for the sampler. @@ -301,16 +254,10 @@ def _allocate_variables( copy.deepcopy(sampler_state) for sampler_state in sampler_states ] - # Handle default unsampled thermodynamic states. - self._unsampled_states = ( - copy.deepcopy(unsampled_thermodynamic_states) - if unsampled_thermodynamic_states is not None - else [] - ) - + assert len(self._thermodynamic_states) == len(self._sampler_states) # Set initial thermodynamic state indices - initial_thermodynamic_states = self._default_initial_thermodynamic_states( - thermodynamic_states, sampler_states + initial_thermodynamic_states = np.arange( + len(self._thermodynamic_states), dtype=int ) self._replica_thermodynamic_states = np.array( initial_thermodynamic_states, np.int64 @@ -329,10 +276,6 @@ def _allocate_variables( self._energy_thermodynamic_states = np.zeros( [self.n_replicas, self.n_states], np.float64 ) - self._neighborhoods = np.zeros([self.n_replicas, self.n_states], "i1") - self._energy_unsampled_states = np.zeros( - [self.n_replicas, len(self._unsampled_states)], np.float64 - ) # Ensure there is an MCMCMove for each thermodynamic state. if isinstance(self._mcmc_moves, MCMCMove): @@ -440,101 +383,6 @@ def minimize( for replica_id in range(self.n_replicas): self._minimize_replica(replica_id, tolerance, max_iterations) - def _equilibration_timings(self, timer, iteration: int, n_iterations: int): - iteration_time = timer.stop("Equilibration Iteration") - partial_total_time = timer.partial("Run Equilibration") - time_per_iteration = partial_total_time / iteration - estimated_time_remaining = time_per_iteration * (n_iterations - iteration) - estimated_total_time = time_per_iteration * n_iterations - estimated_finish_time = time.time() + estimated_time_remaining - # TODO: Transmit timing information - - log.info(f"Iteration took {iteration_time:.3f}s.") - if estimated_time_remaining != float("inf"): - log.info( - "Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format( - str(datetime.timedelta(seconds=estimated_time_remaining)), - time.ctime(estimated_finish_time), - str(datetime.timedelta(seconds=estimated_total_time)), - ) - ) - - def equilibrate( - self, n_iterations: int, mcmc_moves: Optional[List[MCMCMove]] = None - ): - """ - Equilibrate all replicas in the sampler. - - This method equilibrates the system by running a specified number of - MCMC iterations. The equilibration uses either the provided MCMC moves - or the default ones set during initialization. - - Parameters - ---------- - n_iterations : int - The number of equilibration iterations to perform. - mcmc_moves : Optional[List[mcmc.MCMCMove]], optional - A list of MCMCMove objects to use for equilibration. If None, the - MCMC moves used in production will be used. Defaults to None. - - Raises - ------ - RuntimeError - If the simulation has not been created before calling this method. - """ - # Check that simulation has been created. - if self.n_replicas == 0: - raise RuntimeError( - "Cannot equilibrate replicas. The simulation must be created first." - ) - - # Use production MCMC moves if none are provided - mcmc_moves = mcmc_moves or self._mcmc_moves - - # Make sure there is one MCMCMove per thermodynamic state. - if isinstance(mcmc_moves, MCMCMove): - mcmc_moves = [copy.deepcopy(mcmc_moves) for _ in range(self.n_states)] - - if len(mcmc_moves) != self.n_states: - raise RuntimeError( - f"The number of MCMCMoves ({len(self._mcmc_moves)}) and ThermodynamicStates ({self.n_states}) for equilibration must be the same." - ) - from openmmtools.utils import Timer - - timer = Timer() - timer.start("Run Equilibration") - - # Temporarily set the equilibration MCMCMoves. - production_mcmc_moves = self._mcmc_moves - self._mcmc_moves = mcmc_moves - - for iteration in range(1, n_iterations + 1): - log.info(f"Equilibration iteration {iteration}/{n_iterations}") - timer.start("Equilibration Iteration") - - # NOTE: Unlike run(), do NOT increment iteration counter. - # self._iteration += 1 - - # Propagate replicas. - self._propagate_replicas() - - # Compute energies of all replicas at all states - self._compute_energies() - - # Update thermodynamic states - self._replica_thermodynamic_states = self._mix_replicas() - - # Computing timing information - self._equilibration_timings( - timer, iteration=iteration, n_iterations=n_iterations - ) - timer.report_timing() - - # Restore production MCMCMoves. - self._mcmc_moves = production_mcmc_moves - - # TODO: Update stored positions. - def _propagate_replica(self, replica_id: int): """ Propagate the state of a single replica. @@ -552,19 +400,18 @@ def _propagate_replica(self, replica_id: int): If an error occurs during the propagation of the replica. """ # Retrieve thermodynamic, sampler states, and MCMC move of this replica. - # Retrieve thermodynamic and sampler states for the replica thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + log.info(thermodynamic_state.potential.K) mcmc_move = self._mcmc_moves[thermodynamic_state_id] - + log.info(f"Position before move: {sampler_state.x0}") # Apply MCMC move. - try: - mcmc_move.run(sampler_state, thermodynamic_state) - except Exception as e: - log.warning(e) - raise e + mcmc_move.run(sampler_state, thermodynamic_state) + self._sampler_states[replica_id] = sampler_state + log.info(f"Position after move: {self._sampler_states[replica_id].x0}") + a = 6 def _perform_swap_proposals(self): """ @@ -633,38 +480,6 @@ def _propagate_replicas(self) -> None: for replica_id in range(self.n_replicas): self._propagate_replica(replica_id) - def _neighborhood(self, state_index: int) -> List[int]: - """ - Compute the indices of neighboring states for a given state. - - This method determines the neighborhood of states around a given state index, - considering the 'locality' parameter. If 'locality' is None, the neighborhood - includes all states; otherwise, it includes states within the 'locality' range. - - Parameters - ---------- - state_index : int - The index of the state for which the neighborhood is to be calculated. - - Returns - ------- - List[int] - A list of state indices that are considered neighbors of the given state. - """ - if self.locality is None: - # Global neighborhood - return list(range(self.n_states)) - else: - # Local neighborhood specified by 'locality' - lower_bound = max(0, state_index - self.locality) - upper_bound = min(self.n_states, state_index + self.locality + 1) - return list( - range( - lower_bound, - upper_bound, - ) - ) - @with_timer("Computing energy matrix") def _compute_energies(self) -> None: """ @@ -678,18 +493,12 @@ def _compute_energies(self) -> None: log.debug("Computing energy matrix for all replicas...") # Initialize the energy matrix and neighborhoods self._energy_thermodynamic_states = np.zeros((self.n_replicas, self.n_states)) - self._neighborhoods = np.zeros((self.n_replicas, self.n_states), dtype=bool) # Calculate energies for each replica for replica_id in range(self.n_replicas): - neighborhood = self._neighborhood( - self._replica_thermodynamic_states[replica_id] - ) - self._neighborhoods[replica_id, neighborhood] = True - # Compute and store energies for the neighborhood states self._energy_thermodynamic_states[ - replica_id, neighborhood + replica_id, : ] = self._compute_replica_energies(replica_id) def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: @@ -743,9 +552,6 @@ def _update_run_progress(self, timer, run_initial_iteration, iteration_limit): f"Estimated completion in {self._timing_data['estimated_time_remaining']}, at {self._timing_data['estimated_localtime_finish_date']} (consuming total wall clock time {self._timing_data['estimated_total_time']})." ) - # Perform sanity checks to see if we should terminate here. - self._check_nan_energy() - def run(self, n_iterations: int = 10) -> None: """ Execute the replica-exchange simulation. @@ -928,49 +734,6 @@ def flatten(iterator): return flatten(self.mcmc_moves) - def _check_nan_energy(self): - """Checks that energies are finite and abort otherwise. - - Checks both sampled and unsampled thermodynamic states. - - """ - # Find faulty replicas to create error message. - nan_replicas = [] - - # Check sampled thermodynamic states first. - state_type = "thermodynamic state" - for replica_id, state_id in enumerate(self._replica_thermodynamic_states): - neighborhood = self._neighborhood(state_id) - energies_neighborhood = self._energy_thermodynamic_states[ - replica_id, neighborhood - ] - if np.any(np.isnan(energies_neighborhood)): - nan_replicas.append((replica_id, energies_neighborhood)) - - # If there are no NaNs in energies, look for NaNs in the unsampled states energies. - if (len(nan_replicas) == 0) and (self._energy_unsampled_states.shape[1] > 0): - state_type = "unsampled thermodynamic state" - for replica_id in range(self.n_replicas): - if np.any(np.isnan(self._energy_unsampled_states[replica_id])): - nan_replicas.append( - (replica_id, self._energy_unsampled_states[replica_id]) - ) - - # Raise exception if we have found some NaN energies. - if len(nan_replicas) > 0: - # Log failed replica, its thermo state, and the energy matrix row. - err_msg = "NaN encountered in {} energies for the following replicas and states".format( - state_type - ) - for replica_id, energy_row in nan_replicas: - err_msg += "\n\tEnergies for positions at replica {} (current state {}): {} kT".format( - replica_id, - self._replica_thermodynamic_states[replica_id], - energy_row, - ) - log.critical(err_msg) - raise RuntimeError(err_msg) - def _update_analysis(self): """Update analysis of free energies""" diff --git a/chiron/states.py b/chiron/states.py index 10f07f4..8ba5e52 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -217,32 +217,6 @@ def _check_completness(self): if self.temperature and self.pressure and self.nr_of_particles: log.info("NpT ensemble is simulated.") - def is_state_compatible(self, thermodynamic_state): - """Check compatibility between ThermodynamicStates. - - Parameters - ---------- - thermodynamic_state : ThermodynamicState - The thermodynamic state to test. - - Returns - ------- - is_compatible : bool - True if the states are compatible, False otherwise. - - Examples - -------- - States in the same ensemble (NVT or NPT) are compatible. - States in different ensembles are not compatible. - States that store different systems (that differ by more than - barostat and thermostat pressure and temperature) are also not - compatible. - """ - - # Check that the states are in the same ensemble. - # TODO: implement this - pass - def get_reduced_potential( self, sampler_state: SamplerState, nbr_list=None ) -> float: @@ -294,7 +268,7 @@ def kT_to_kJ_per_mol(self, energy): def calculate_reduced_potential_at_states( sampler_state: SamplerState, - themrodynamic_states: List[ThermodynamicState], + thermodynamic_states: List[ThermodynamicState], nbr_list=None, ): """ @@ -313,7 +287,10 @@ def calculate_reduced_potential_at_states( The reduced potential of the system for each thermodynamic state. """ - reduced_potentials = [] - for state in themrodynamic_states: - reduced_potentials.append(state.get_reduced_potential(sampler_state)) + import numpy as np + + reduced_potentials = np.zeros(len(thermodynamic_states)) + for state_idx, state in enumerate(thermodynamic_states): + reduced_potentials[state_idx] = state.get_reduced_potential(sampler_state) + log.debug(f"reduced potentials per sampler sate: {reduced_potentials}") return reduced_potentials From c850b43ccc22e1c848a6d6601c923f2066b20d67 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 11:23:59 +0100 Subject: [PATCH 29/55] Add save_traj_in_memory option to LangevinIntegrator and LangevinDynamicsMove --- chiron/integrators.py | 17 ++++++++++++----- chiron/mcmc.py | 8 +++++++- chiron/multistate.py | 19 +++++-------------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/chiron/integrators.py b/chiron/integrators.py index b25d348..1870867 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -29,6 +29,7 @@ def __init__( collision_rate=1.0 / unit.picoseconds, save_frequency: int = 100, reporter: Optional[SimulationReporter] = None, + save_traj_in_memory: bool = False, ) -> None: """ Initialize the LangevinIntegrator object. @@ -56,7 +57,8 @@ def __init__( log.info(f"Using reporter {reporter} saving to {reporter.filename}") self.reporter = reporter self.save_frequency = save_frequency - + self.save_traj_in_memory = save_traj_in_memory + self.traj = [] self.velocities = None def set_velocities(self, vel: unit.Quantity) -> None: @@ -109,10 +111,10 @@ def run( temperature = thermodynamic_state.temperature x0 = sampler_state.x0 - log.info("Running Langevin dynamics") - log.info(f"n_steps = {n_steps}") - log.info(f"temperature = {temperature}") - log.info(f"Using seed: {key}") + log.debug("Running Langevin dynamics") + log.debug(f"n_steps = {n_steps}") + log.debug(f"temperature = {temperature}") + log.debug(f"Using seed: {key}") kbT_unitless = (self.kB * temperature).value_in_unit_system(unit.md_unit_system) mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[ @@ -169,6 +171,9 @@ def run( if step % self.save_frequency == 0: # log.debug(f"Saving at step {step}") # check if reporter is attribute of the class + # log.debug(f"step {step} energy {potential.compute_energy(x, nbr_list)}") + # log.debug(f"step {step} force {F}") + if hasattr(self, "reporter") and self.reporter is not None: d = { "traj": x, @@ -180,6 +185,8 @@ def run( # log.debug(d) self.reporter.report(d) + if self.save_traj_in_memory: + self.traj.append(x) log.debug("Finished running Langevin dynamics") # save the final state of the simulation in the sampler_state object diff --git a/chiron/mcmc.py b/chiron/mcmc.py index c67730e..cbd9979 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -71,6 +71,7 @@ def __init__( simulation_reporter: Optional[SimulationReporter] = None, nr_of_steps=1_000, seed: int = 1234, + save_traj_in_memory: bool = False, ): """ Initialize the LangevinDynamicsMove with a molecular system. @@ -88,13 +89,15 @@ def __init__( self.stepsize = stepsize self.collision_rate = collision_rate self.simulation_reporter = simulation_reporter - + self.save_traj_in_memory = save_traj_in_memory + self.traj = [] from chiron.integrators import LangevinIntegrator self.integrator = LangevinIntegrator( stepsize=self.stepsize, collision_rate=self.collision_rate, reporter=self.simulation_reporter, + save_traj_in_memory=save_traj_in_memory, ) def run( @@ -122,6 +125,9 @@ def run( n_steps=self.nr_of_moves, key=self.key, ) + if self.save_traj_in_memory: + self.traj.append(self.integrator.traj) + self.integrator.traj = [] class MCMove(MCMCMove): diff --git a/chiron/multistate.py b/chiron/multistate.py index f4b2d3b..0243c97 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,5 +1,4 @@ import copy -import time from typing import List, Optional from chiron.states import SamplerState, ThermodynamicState import datetime @@ -69,6 +68,7 @@ def __init__( self._online_analysis_interval = online_analysis_interval self._timing_data = dict() self.free_energy_estimator = None + self._traj = None # Handling default propagator. if mcmc_moves is None: @@ -181,7 +181,6 @@ def create( thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd, - reporter: MultiStateReporter, metadata: Optional[dict] = None, ): """Create new multistate sampler simulation. @@ -193,8 +192,6 @@ def create( of sampler states provided. nbr_list : NeighborListNsqrd Neighbor list object to be used in the simulation. - reporter : MultiStateReporter - Reporter object to record simulation data. metadata : dict, optional Optional simulation metadata to be stored in the file. @@ -215,8 +212,7 @@ def create( self._allocate_variables(thermodynamic_states, sampler_states) self.nbr_list = nbr_list - self._reporter = reporter - self._reporter.open(mode="a") + self._reporter = None def _allocate_variables( self, @@ -276,7 +272,7 @@ def _allocate_variables( self._energy_thermodynamic_states = np.zeros( [self.n_replicas, self.n_states], np.float64 ) - + self._traj = [[] for _ in range(self.n_replicas)] # Ensure there is an MCMCMove for each thermodynamic state. if isinstance(self._mcmc_moves, MCMCMove): self._mcmc_moves = [ @@ -404,14 +400,10 @@ def _propagate_replica(self, replica_id: int): sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] - log.info(thermodynamic_state.potential.K) mcmc_move = self._mcmc_moves[thermodynamic_state_id] - log.info(f"Position before move: {sampler_state.x0}") # Apply MCMC move. mcmc_move.run(sampler_state, thermodynamic_state) - self._sampler_states[replica_id] = sampler_state - log.info(f"Position after move: {self._sampler_states[replica_id].x0}") - a = 6 + self._traj[replica_id].append(sampler_state.x0) def _perform_swap_proposals(self): """ @@ -464,7 +456,6 @@ def _mix_replicas(self) -> np.ndarray: log.debug( f"Accepted {n_swaps_accepted}/{n_swaps_proposed} attempted swaps ({swap_fraction_accepted * 100.0:.1f}%)" ) - return new_replica_states @with_timer("Propagating all replicas") def _propagate_replicas(self) -> None: @@ -608,7 +599,7 @@ def run(self, n_iterations: int = 10) -> None: timer.start("Iteration") # Update thermodynamic states - self._replica_thermodynamic_states = self._mix_replicas() + self._mix_replicas() # Propagate replicas. self._propagate_replicas() From ef18482df113f982029c58e1e1f2cbe787f9a13f Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 11:39:35 +0100 Subject: [PATCH 30/55] Remove unnecessary log statements in MultiStateSampler class --- chiron/multistate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 0243c97..08d36d0 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -611,10 +611,6 @@ def run(self, n_iterations: int = 10) -> None: self._energy_thermodynamic_states_for_each_iteration_in_run[ :, :, self._iteration ] = self._energy_thermodynamic_states - log.info( - self._energy_thermodynamic_states_for_each_iteration_in_run[:, :, 1] - ) - log.info(self._energy_thermodynamic_states) # Write iteration to storage file # TODO # self._report_iteration() From d0252af473f7e920088a31f7d8bdafa85fa697c9 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 11:48:58 +0100 Subject: [PATCH 31/55] refactor tests --- chiron/tests/test_multistate.py | 206 +++++++++++++------------------- 1 file changed, 82 insertions(+), 124 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index ba02528..fc3e13f 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -2,47 +2,11 @@ import pytest -@pytest.fixture -def ho_multistate_sampler() -> MultiStateSampler: - """ - Create a multi-state sampler for a harmonic oscillator system. - - Returns: - MultiStateSampler: The multi-state sampler object. - """ - import math +def setup_sampler(): from openmm import unit from chiron.mcmc import LangevinDynamicsMove - from chiron.states import ThermodynamicState, SamplerState - from openmmtools.testsystems import HarmonicOscillator - from chiron.potential import HarmonicOscillatorPotential from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace - ho = HarmonicOscillator() - n_replicas = 3 - T_min = 298.0 * unit.kelvin # Minimum temperature. - T_max = 600.0 * unit.kelvin # Maximum temperature. - temperatures = [ - T_min - + (T_max - T_min) - * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) - / (math.e - 1.0) - for i in range(n_replicas) - ] - import jax.numpy as jnp - - x0s = [ - unit.Quantity(jnp.array([[x0, 0.0, 0.0]]), unit.angstrom) - for x0 in jnp.linspace(0.0, 1.0, n_replicas) - ] - thermodynamic_states = [ - ThermodynamicState( - HarmonicOscillatorPotential(ho.topology, x0=x0), temperature=T - ) - for T, x0 in zip(temperatures, x0s) - ] - sampler_state = [SamplerState(ho.positions) for _ in temperatures] - # Initialize simulation object with options. Run with a langevin integrator. # initialize the LennardJones potential in chiron # @@ -56,81 +20,53 @@ def ho_multistate_sampler() -> MultiStateSampler: move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) - from openmmtools.multistate import MultiStateReporter - - reporter = MultiStateReporter("test.nc") - - multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) - multistate_sampler.create( - thermodynamic_states=thermodynamic_states, - sampler_states=sampler_state, - nbr_list=nbr_list, - reporter=reporter, - ) - - return multistate_sampler - + multistate_sampler = MultiStateSampler(mcmc_moves=move) + return nbr_list, multistate_sampler -def test_multistate_class(ho_multistate_sampler): - # test the multistate_sampler object - assert ho_multistate_sampler.number_of_iterations == 2 - assert ho_multistate_sampler.n_replicas == 3 - assert ho_multistate_sampler.n_states == 3 - assert ho_multistate_sampler._energy_thermodynamic_states.shape == (3, 3) - assert ho_multistate_sampler._n_proposed_matrix.shape == (3, 3) - -def test_multistate_minimize(ho_multistate_sampler): +@pytest.fixture +def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: """ - Test function for the `minimize` method of the `ho_multistate_sampler` object. - It checks if the sampler states are correctly minimized. + Create a multi-state sampler for a harmonic oscillator system. - Parameters - ---------- - ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + Returns: + MultiStateSampler: The multi-state sampler object. """ + from chiron.states import ThermodynamicState, SamplerState + from chiron.potential import HarmonicOscillatorPotential + import jax.numpy as jnp + from openmm import unit - import numpy as np - - ho_multistate_sampler.minimize() - - assert np.allclose( - ho_multistate_sampler.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) - ) - assert np.allclose( - ho_multistate_sampler.sampler_states[1].x0, - np.array([[0.05, 0.0, 0.0]]), - atol=1e-2, - ) - assert np.allclose( - ho_multistate_sampler.sampler_states[2].x0, - np.array([[0.1, 0.0, 0.0]]), - atol=1e-2, - ) - + n_replicas = 3 + T = 300.0 * unit.kelvin # Minimum temperature. + x0s = [ + unit.Quantity(jnp.array([[x0, 0.0, 0.0]]), unit.angstrom) + for x0 in jnp.linspace(0.0, 1.0, n_replicas) + ] -def test_multistate_equilibration(ho_multistate_sampler): - import numpy as np + from openmmtools.testsystems import HarmonicOscillator - ho_multistate_sampler.equilibrate(10) + ho = HarmonicOscillator() - assert np.allclose( - ho_multistate_sampler._replica_thermodynamic_states, np.array([0, 1, 2]) - ) - assert np.allclose( - ho_multistate_sampler._energy_thermodynamic_states, - np.array( - [ - [4.81132936, 3.84872651, 3.10585403], - [6.54490519, 5.0176239, 3.85019779], - [9.48260307, 7.07196712, 5.21255827], - ] - ), + thermodynamic_states = [ + ThermodynamicState( + HarmonicOscillatorPotential(ho.topology, x0=x0), temperature=T + ) + for x0 in x0s + ] + sampler_state = [SamplerState(ho.positions) for _ in x0s] + nbr_list, multistate_sampler = setup_sampler() + multistate_sampler.create( + thermodynamic_states=thermodynamic_states, + sampler_states=sampler_state, + nbr_list=nbr_list, ) + return multistate_sampler + @pytest.fixture -def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: +def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: """ Create a multi-state sampler for a harmonic oscillator system. @@ -138,11 +74,9 @@ def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: MultiStateSampler: The multi-state sampler object. """ from openmm import unit - from chiron.mcmc import LangevinDynamicsMove from chiron.states import ThermodynamicState, SamplerState from openmmtools.testsystems import HarmonicOscillator from chiron.potential import HarmonicOscillatorPotential - from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace ho = HarmonicOscillator() n_states = 4 @@ -150,14 +84,19 @@ def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: T = 300.0 * unit.kelvin # Minimum temperature. kT = unit.BOLTZMANN_CONSTANT_kB * T * unit.AVOGADRO_CONSTANT_NA sigmas = [ - unit.Quantity(1.0 + 0.2 * state_index, unit.angstrom) + unit.Quantity(2.0 + 0.2 * state_index, unit.angstrom) for state_index in range(n_states) ] Ks = [kT / sigma**2 for sigma in sigmas] + thermodynamic_states = [ ThermodynamicState(HarmonicOscillatorPotential(ho.topology, k=k), temperature=T) for k in Ks ] + from loguru import logger as log + + log.info(f"Initialize harmonic oscillator with {n_states} states and ks {Ks}") + sampler_state = [SamplerState(ho.positions) for _ in sigmas] import numpy as np @@ -168,39 +107,58 @@ def ho_multistate_sampler_single_sampler_state() -> MultiStateSampler: ] ) - # Initialize simulation object with options. Run with a langevin integrator. - # initialize the LennardJones potential in chiron - # - sigma = 0.34 * unit.nanometer - cutoff = 3.0 * sigma - skin = 0.5 * unit.nanometer - - nbr_list = NeighborListNsqrd( - OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=10 - ) - - move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) - - from openmmtools.multistate import MultiStateReporter - - reporter = MultiStateReporter("test.nc") + nbr_list, multistate_sampler = setup_sampler() - multistate_sampler = MultiStateSampler( - mcmc_moves=move, - ) multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, nbr_list=nbr_list, - reporter=reporter, ) multistate_sampler.analytical_f_i = f_i multistate_sampler.delta_f_ij_analytical = f_i - f_i[:, np.newaxis] return multistate_sampler -def test_multistate_run(ho_multistate_sampler_single_sampler_state): - ho_sampler = ho_multistate_sampler_single_sampler_state +def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampler): + # test the multistate_sampler object + assert ho_multistate_sampler_multiple_minima._iteration == 0 + assert ho_multistate_sampler_multiple_minima.n_replicas == 3 + assert ho_multistate_sampler_multiple_minima.n_states == 3 + assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == (3, 3) + assert ho_multistate_sampler_multiple_minima._n_proposed_matrix.shape == (3, 3) + + +def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSampler): + """ + Test function for the `minimize` method of the `ho_multistate_sampler` object. + It checks if the sampler states are correctly minimized. + + Parameters + ---------- + ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + """ + + import numpy as np + + ho_multistate_sampler_multiple_minima.minimize() + + assert np.allclose( + ho_multistate_sampler_multiple_minima.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) + ) + assert np.allclose( + ho_multistate_sampler_multiple_minima.sampler_states[1].x0, + np.array([[0.05, 0.0, 0.0]]), + atol=1e-2, + ) + assert np.allclose( + ho_multistate_sampler_multiple_minima.sampler_states[2].x0, + np.array([[0.1, 0.0, 0.0]]), + atol=1e-2, + ) + + +def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): + ho_sampler = ho_multistate_sampler_multiple_ks import numpy as np n_iteratinos = 100 From b92c89f71a61a9b4539b9cbe7a9e00c29cff7ee1 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 12:10:52 +0100 Subject: [PATCH 32/55] Refactor multistate sampler setup and add test cases --- chiron/tests/test_multistate.py | 74 +++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index fc3e13f..d38f357 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -1,15 +1,20 @@ from chiron.multistate import MultiStateSampler +from chiron.neighbors import NeighborListNsqrd import pytest +from typing import Tuple -def setup_sampler(): +def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: + """ + Set up the neighbor list and multistate sampler for the simulation. + + Returns: + Tuple: A tuple containing the neighbor list and multistate sampler objects. + """ from openmm import unit from chiron.mcmc import LangevinDynamicsMove from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace - # Initialize simulation object with options. Run with a langevin integrator. - # initialize the LennardJones potential in chiron - # sigma = 0.34 * unit.nanometer cutoff = 3.0 * sigma skin = 0.5 * unit.nanometer @@ -18,7 +23,7 @@ def setup_sampler(): OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) + move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=500) multistate_sampler = MultiStateSampler(mcmc_moves=move) return nbr_list, multistate_sampler @@ -27,7 +32,7 @@ def setup_sampler(): @pytest.fixture def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: """ - Create a multi-state sampler for a harmonic oscillator system. + Create a multi-state sampler for multiple harmonic oscillators with different minimum values. Returns: MultiStateSampler: The multi-state sampler object. @@ -68,10 +73,11 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: @pytest.fixture def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: """ - Create a multi-state sampler for a harmonic oscillator system. - - Returns: - MultiStateSampler: The multi-state sampler object. + Create a multi-state sampler for a harmonic oscillator system with different spring constants. + Returns + ------- + MultiStateSampler + The multi-state sampler object. """ from openmm import unit from chiron.states import ThermodynamicState, SamplerState @@ -120,22 +126,37 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampler): - # test the multistate_sampler object + """ + Test initialization for the MultiStateSampler class. + + Parameters: + ------- + ho_multistate_sampler_multiple_minima: MultiStateSampler + An instance of the MultiStateSampler class. + Raises: + ------- + AssertionError: + If any of the assertions fail. + + """ assert ho_multistate_sampler_multiple_minima._iteration == 0 assert ho_multistate_sampler_multiple_minima.n_replicas == 3 assert ho_multistate_sampler_multiple_minima.n_states == 3 - assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == (3, 3) + assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == ( + 3, + 3, + ) assert ho_multistate_sampler_multiple_minima._n_proposed_matrix.shape == (3, 3) def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSampler): """ Test function for the `minimize` method of the `ho_multistate_sampler` object. - It checks if the sampler states are correctly minimized. + Check if the sampler states are correctly minimized. Parameters ---------- - ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + ho_multistate_sampler: MultiStateSampler """ import numpy as np @@ -143,7 +164,8 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa ho_multistate_sampler_multiple_minima.minimize() assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) + ho_multistate_sampler_multiple_minima.sampler_states[0].x0, + np.array([[0.0, 0.0, 0.0]]), ) assert np.allclose( ho_multistate_sampler_multiple_minima.sampler_states[1].x0, @@ -158,10 +180,25 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): + """ + Test function for running the multistate sampler. + + Parameters + ---------- + ho_multistate_sampler_multiple_ks: MultiStateSampler + The multistate sampler object. + Raises + ------- + AssertionError: If free energy does not converge to the analytical free energy difference. + + """ + ho_sampler = ho_multistate_sampler_multiple_ks import numpy as np - n_iteratinos = 100 + print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") + + n_iteratinos = 25 ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states @@ -174,4 +211,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(ho_sampler.analytical_f_i) print(ho_sampler.delta_f_ij_analytical) print(ho_sampler._last_mbar_f_k_offline) - a = 7 + + assert np.isclose( + ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline + ) From 5e07290f7f184c966d75d70eca57c7f023031ca6 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 14:27:04 +0100 Subject: [PATCH 33/55] Fix assertions in test_multistate_class and test_sampler_state_inputs --- chiron/tests/test_multistate.py | 10 +++++----- chiron/tests/test_states.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index d38f357..ce74e1d 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -76,7 +76,7 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: Create a multi-state sampler for a harmonic oscillator system with different spring constants. Returns ------- - MultiStateSampler + MultiStateSampler The multi-state sampler object. """ from openmm import unit @@ -131,11 +131,11 @@ def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampl Parameters: ------- - ho_multistate_sampler_multiple_minima: MultiStateSampler + ho_multistate_sampler_multiple_minima: MultiStateSampler An instance of the MultiStateSampler class. Raises: ------- - AssertionError: + AssertionError: If any of the assertions fail. """ @@ -212,6 +212,6 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(ho_sampler.delta_f_ij_analytical) print(ho_sampler._last_mbar_f_k_offline) - assert np.isclose( - ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline + assert np.allclose( + ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline, atol=0.1 ) diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index c9640a2..2e314a9 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -124,11 +124,12 @@ def test_sampler_state_inputs(): x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), box_vectors=openmm_box, ) - assert jnp.all( + assert jnp.allclose( state.box_vectors == jnp.array( [[4.0311456, 0.0, 0.0], [0.0, 4.0311456, 0.0], [0.0, 0.0, 4.0311456]] - ) + ), + atol=1e-4, ) # openmm box vectors end up as a list with contents; check to make sure we capture an error if we pass a bad list From 8e4dc0553c6b3f03cb1adabad90d473240cfa0a9 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 14:46:07 +0100 Subject: [PATCH 34/55] Fix box_vectors comparison in test_sampler_state_inputs() --- chiron/tests/test_states.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index 2e314a9..d94869a 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -125,8 +125,8 @@ def test_sampler_state_inputs(): box_vectors=openmm_box, ) assert jnp.allclose( - state.box_vectors - == jnp.array( + state.box_vectors, + jnp.array( [[4.0311456, 0.0, 0.0], [0.0, 4.0311456, 0.0], [0.0, 0.0, 4.0311456]] ), atol=1e-4, From 08dcf4b63f731e72aac2f9b716e859159b9fe2a7 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 15:48:30 +0100 Subject: [PATCH 35/55] Refactor move_set to MoveSchedule --- chiron/mcmc.py | 6 +- chiron/multistate.py | 163 ++++--------------------- chiron/tests/test_convergence_tests.py | 4 +- chiron/tests/test_mcmc.py | 12 +- 4 files changed, 35 insertions(+), 150 deletions(-) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index cbd9979..d303ec7 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -242,9 +242,9 @@ def apply_move(self): pass -class MoveSet: +class MoveSchedule: """ - Represents a set of moves for a Markov Chain Monte Carlo (MCMC) algorithm. + Represents an (optimizable) series of moves for a Markov Chain Monte Carlo (MCMC) algorithm. Parameters ---------- @@ -298,7 +298,7 @@ class MCMCSampler(object): def __init__( self, - move_set: MoveSet, + move_set: MoveSchedule, sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, ): diff --git a/chiron/multistate.py b/chiron/multistate.py index 08d36d0..9e55fd5 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,17 +1,11 @@ -import copy from typing import List, Optional from chiron.states import SamplerState, ThermodynamicState -import datetime -from loguru import logger as log -import numpy as np -from openmmtools.utils import with_timer from chiron.neighbors import NeighborListNsqrd from openmm import unit -from chiron.mcmc import MCMCMove -from openmmtools.multistate import MultiStateReporter +import numpy as np -class MultiStateSampler(object): +class MultiStateSampler: """ Base class for samplers that sample multiple thermodynamic states using one or more replicas. @@ -30,26 +24,23 @@ class MultiStateSampler(object): creation. If None is provided, Langevin dynamics with 2fm timestep, 5.0/ps collision rate, and 500 steps per iteration will be used. - locality : int > 0, optional, default None - If None, the energies at all states will be computed for every replica each iteration. - If int > 0, energies will only be computed for states ``range(max(0, state-locality), min(n_states, state+locality))``. - Attributes ---------- n_replicas n_states mcmc_moves sampler_states - metadata is_completed """ def __init__( self, mcmc_moves=None, - locality=None, online_analysis_interval=5, ): + import copy + from openmm import unit + # These will be set on initialization. See function # create() for explanation of single variables. self._thermodynamic_states = None @@ -86,9 +77,6 @@ def __init__( self._last_mbar_f_k = None self._last_err_free_energy = None - # Store locality - self.locality = locality - @property def n_states(self): """The integer number of thermodynamic states (read-only).""" @@ -121,6 +109,8 @@ def mcmc_moves(self): This can be set only before creation. """ + import copy + return copy.deepcopy(self._mcmc_moves) @property @@ -129,6 +119,8 @@ def sampler_states(self): This can be set only before running. """ + import copy + return copy.deepcopy(self._sampler_states) @property @@ -141,6 +133,8 @@ def is_periodic(self): @property def metadata(self): """A copy of the metadata dictionary passed on creation (read-only).""" + import copy + return copy.deepcopy(self._metadata) @property @@ -238,6 +232,8 @@ def _allocate_variables( RuntimeError If the number of MCMC moves and ThermodynamicStates do not match. """ + import copy + import numpy as np # Save thermodynamic states. This sets n_replicas. self._thermodynamic_states = [ @@ -274,6 +270,8 @@ def _allocate_variables( ) self._traj = [[] for _ in range(self.n_replicas)] # Ensure there is an MCMCMove for each thermodynamic state. + from chiron.mcmc import MCMCMove + if isinstance(self._mcmc_moves, MCMCMove): self._mcmc_moves = [ copy.deepcopy(self._mcmc_moves) for _ in range(self.n_states) @@ -311,6 +309,7 @@ def _minimize_replica( """ from chiron.minimze import minimize_energy + from loguru import logger as log # Retrieve thermodynamic and sampler states. thermodynamic_state = self._thermodynamic_states[ @@ -366,6 +365,7 @@ def minimize( RuntimeError If the simulation has not been created before calling this method. """ + from loguru import logger as log # Check that simulation has been created. if self.n_replicas == 0: @@ -435,6 +435,7 @@ def _mix_replicas(self) -> np.ndarray: np.ndarray An array of updated thermodynamic state indices for each replica. """ + from loguru import logger as log log.debug("Mixing replicas (does nothing for MultiStateSampler)...") @@ -457,7 +458,6 @@ def _mix_replicas(self) -> np.ndarray: f"Accepted {n_swaps_accepted}/{n_swaps_proposed} attempted swaps ({swap_fraction_accepted * 100.0:.1f}%)" ) - @with_timer("Propagating all replicas") def _propagate_replicas(self) -> None: """ Propagate all replicas through their respective MCMC moves. @@ -465,13 +465,13 @@ def _propagate_replicas(self) -> None: This method iterates over all replicas and applies the corresponding MCMC move to each one, based on its current thermodynamic state. """ + from loguru import logger as log log.debug("Propagating all replicas...") for replica_id in range(self.n_replicas): self._propagate_replica(replica_id) - @with_timer("Computing energy matrix") def _compute_energies(self) -> None: """ Compute the energies of all replicas at all thermodynamic states. @@ -480,6 +480,7 @@ def _compute_energies(self) -> None: considering the defined neighborhoods to optimize the computation. The energies are stored in the internal energy matrix of the sampler. """ + from loguru import logger as log log.debug("Computing energy matrix for all replicas...") # Initialize the energy matrix and neighborhoods @@ -511,6 +512,7 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: True if the simulation has completed based on the stopping criteria, False otherwise. """ + from loguru import logger as log # Check if iteration limit has been reached if iteration_limit is not None and self._iteration >= iteration_limit: @@ -523,26 +525,6 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: return False - def _update_run_progress(self, timer, run_initial_iteration, iteration_limit): - # Computing and transmitting timing information - iteration_time = timer.stop("Iteration") - partial_total_time = timer.partial("Run ReplicaExchange") - self._update_timing( - iteration_time, - partial_total_time, - run_initial_iteration, - iteration_limit, - ) - - # Log timing data as info level -- useful for users by default - log.info( - "Iteration took {:.3f}s.".format(self._timing_data["iteration_seconds"]) - ) - if self._timing_data["estimated_time_remaining"] != float("inf"): - log.info( - f"Estimated completion in {self._timing_data['estimated_time_remaining']}, at {self._timing_data['estimated_localtime_finish_date']} (consuming total wall clock time {self._timing_data['estimated_total_time']})." - ) - def run(self, n_iterations: int = 10) -> None: """ Execute the replica-exchange simulation. @@ -561,6 +543,7 @@ def run(self, n_iterations: int = 10) -> None: RuntimeError If an error occurs during the computation of energies. """ + from loguru import logger as log # If this is the first iteration, compute and store the # starting energies of the minimized/equilibrated structures. @@ -580,11 +563,6 @@ def run(self, n_iterations: int = 10) -> None: ] = self._energy_thermodynamic_states # TODO report energies - from openmmtools.utils import Timer - - timer = Timer() - timer.start("Run ReplicaExchange") - iteration_limit = n_iterations # start the sampling loop @@ -596,7 +574,6 @@ def run(self, n_iterations: int = 10) -> None: log.info("-" * 80) log.info(f"Iteration {self._iteration}/{iteration_limit}") log.info("-" * 80) - timer.start("Iteration") # Update thermodynamic states self._mix_replicas() @@ -626,103 +603,10 @@ def _report_iteration(self): # TODO: write trajectory # TODO: write mixing statistics - self._reporter.write_energies( - self._energy_thermodynamic_states, - self._neighborhoods, - self._energy_unsampled_states, - self._iteration, - ) - - def _report_iteration_items(self): - """ - Sub-function of :func:`_report_iteration` which handles all the actual individual item reporting in a - sub-class friendly way. The final actions of writing timestamp, last-good-iteration, and syncing - should be left to the :func:`_report_iteration` and subclasses should extend this function instead - """ - self._reporter.write_sampler_states(self._sampler_states, self._iteration) - self._reporter.write_replica_thermodynamic_states( - self._replica_thermodynamic_states, self._iteration - ) - self._reporter.write_mcmc_moves( - self._mcmc_moves - ) # MCMCMoves can store internal statistics. - self._reporter.write_energies( - self._energy_thermodynamic_states, - self._neighborhoods, - self._energy_unsampled_states, - self._iteration, - ) - self._reporter.write_mixing_statistics( - self._n_accepted_matrix, self._n_proposed_matrix, self._iteration - ) - - def _update_timing( - self, iteration_time, partial_total_time, run_initial_iteration, iteration_limit - ): - """ - Function that computes and transmits timing information to reporter. - - Parameters - ---------- - iteration_time : float - Time took in the iteration. - partial_total_time : float - Partial total time elapsed. - run_initial_iteration : int - Iteration where to start/resume the simulation. - iteration_limit : int - Hard limit on number of iterations to be run by the sampler. - """ - self._timing_data["iteration_seconds"] = iteration_time - self._timing_data["average_seconds_per_iteration"] = partial_total_time / ( - self._iteration - run_initial_iteration - ) - estimated_timedelta_remaining = datetime.timedelta( - seconds=self._timing_data["average_seconds_per_iteration"] - * (iteration_limit - self._iteration) - ) - estimated_finish_date = datetime.datetime.now() + estimated_timedelta_remaining - self._timing_data["estimated_time_remaining"] = str( - estimated_timedelta_remaining - ) # Putting it in dict as str - self._timing_data[ - "estimated_localtime_finish_date" - ] = estimated_finish_date.strftime("%Y-%b-%d-%H:%M:%S") - total_time_in_seconds = datetime.timedelta( - seconds=self._timing_data["average_seconds_per_iteration"] * iteration_limit - ) - self._timing_data["estimated_total_time"] = str(total_time_in_seconds) - - # Estimate performance - moves_iterator = self._flatten_moves_iterator() - # Only consider "dynamic" moves (timestep and n_steps attributes) - moves_times = [ - move.timestep.value_in_unit(unit.nanosecond) * move.n_steps - for move in moves_iterator - if hasattr(move, "timestep") and hasattr(move, "n_steps") - ] - iteration_simulated_nanoseconds = sum(moves_times) - seconds_in_a_day = (1 * unit.day).value_in_unit(unit.seconds) - self._timing_data["ns_per_day"] = iteration_simulated_nanoseconds / ( - self._timing_data["average_seconds_per_iteration"] / seconds_in_a_day - ) - - def _flatten_moves_iterator(self): - """Recursively flatten MCMC moves. Handles the cases where each move can be a set of moves, for example with - SequenceMove or WeightedMove objects.""" - - def flatten(iterator): - try: - yield from [ - inner_move for move in iterator for inner_move in flatten(move) - ] - except TypeError: # Inner object is not iterable, finish flattening. - yield iterator - - return flatten(self.mcmc_moves) def _update_analysis(self): """Update analysis of free energies""" + from loguru import logger as log if self._online_analysis_interval is None: log.debug("No online analysis requested") @@ -740,6 +624,7 @@ def _mbar_analysis(self): Perform mbar analysis """ from pymbar import MBAR + from loguru import logger as log self._last_mbar_f_k_offline = np.zeros(len(self._thermodynamic_states)) diff --git a/chiron/tests/test_convergence_tests.py b/chiron/tests/test_convergence_tests.py index 3deff74..86c881c 100644 --- a/chiron/tests/test_convergence_tests.py +++ b/chiron/tests/test_convergence_tests.py @@ -55,7 +55,7 @@ def test_convergence_of_MC_estimator(prep_temp_dir): simulation_reporter = SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") # Initalize the move set (here only LangevinDynamicsMove) - from chiron.mcmc import MetropolisDisplacementMove, MoveSet, MCMCSampler + from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler mc_displacement_move = MetropolisDisplacementMove( nr_of_moves=100_000, @@ -64,7 +64,7 @@ def test_convergence_of_MC_estimator(prep_temp_dir): simulation_reporter=simulation_reporter, ) - move_set = MoveSet([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index 950b530..63c3f92 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -85,7 +85,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics """ from openmm import unit from chiron.potential import HarmonicOscillatorPotential - from chiron.mcmc import LangevinDynamicsMove, MoveSet, MCMCSampler + from chiron.mcmc import LangevinDynamicsMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillatorArray @@ -117,7 +117,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics nr_of_steps=10, seed=1234, simulation_reporter=simulation_reporter ) - move_set = MoveSet([("LangevinMove", langevin_move)]) + move_set = MoveSchedule([("LangevinMove", langevin_move)]) # Initalize the sampler sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) @@ -137,7 +137,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla """ from openmm import unit from chiron.potential import HarmonicOscillatorPotential - from chiron.mcmc import MetropolisDisplacementMove, MoveSet, MCMCSampler + from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillator @@ -173,7 +173,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla simulation_reporter=simulation_reporter, ) - move_set = MoveSet([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) @@ -192,7 +192,7 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis sampler states, and uses the Metropolis displacement move in an MCMC sampling scheme. """ from openmm import unit - from chiron.mcmc import MetropolisDisplacementMove, MoveSet, MCMCSampler + from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillatorArray @@ -228,7 +228,7 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis simulation_reporter=simulation_reporter, ) - move_set = MoveSet([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) From 8e9a88f16cacccd99964a9d527b037b0c9370051 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 16:15:04 +0100 Subject: [PATCH 36/55] Refactor MCMC module and update imports --- chiron/mcmc.py | 53 ++++++++------------------------------------------ 1 file changed, 8 insertions(+), 45 deletions(-) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index d303ec7..2cb7d75 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -1,50 +1,9 @@ -"""Markov chain Monte Carlo simulation framework. - -This module provides a framework for equilibrium sampling from a given -thermodynamic state of a biomolecule using a Markov chain Monte Carlo scheme. - -It currently offer supports for -* Langevin dynamics, -* Monte Carlo, - -which can be combined through the SequenceMove classes. - ->>> from chiron import unit ->>> from openmmtools.testsystems import AlanineDipeptideVacuum ->>> from chiron.states import ThermodynamicState, SamplerState ->>> from chiron.potential import NeuralNetworkPotential ->>> from modelforge.potential.pretrained_models import SchNetModel ->>> from chiron.mcmc import MCMCSampler, SequenceMove, MCMove, LangevinDynamicsMove - -Create the initial state for an alanine -dipeptide system in vacuum. - ->>> alanine_dipeptide = AlanineDipeptideVacuum() ->>> potential = NeuralNetworkPotential(SchNetModel, alanine_dipeptide.topology) ->>> thermodynamic_state = ThermodynamicState(temperature=298*unit.kelvin) ->>> simulation_state = SamplerState(positions=test.positions) - -Create an MCMC move to sample the equilibrium distribution. - ->>> langevin_move = LangevinDynamicsMove(n_steps=10) - ->>> mc_move = MCMove(timestep=1.0*unit.femtosecond, n_steps=50) ->>> sampler = MCMCSampler(state, move=ghmc_move) - -You can combine them to form a sequence of moves - ->>> sequence_move = SequenceMove([ghmc_move, langevin_move]) ->>> sampler = MCMCSampler(thermodynamic_state, sampler_state, move=sequence_move) - -""" from chiron.states import SamplerState, ThermodynamicState from openmm import unit -from loguru import logger as log from typing import Tuple, List, Optional import jax.numpy as jnp from chiron.reporters import SimulationReporter - class MCMCMove: def __init__(self, nr_of_moves: int, seed: int): """ @@ -303,6 +262,7 @@ def __init__( thermodynamic_state: ThermodynamicState, ): from copy import deepcopy + from loguru import logger as log log.info("Initializing Gibbs sampler") self.move = move_set @@ -318,6 +278,8 @@ def run(self, n_iterations: int = 1): n_iterations : int, optional Number of iterations of the sampler to run. """ + from loguru import logger as log + log.info("Running MCMC sampler") log.info(f"move_schedule = {self.move.move_schedule}") for iteration in range(n_iterations): @@ -369,6 +331,8 @@ def __init__( self.n_proposed = 0 self.atom_subset = atom_subset super().__init__(nr_of_moves=nr_of_moves, seed=seed) + from loguru import logger as log + log.debug(f"Atom subset is {atom_subset}.") @property @@ -405,6 +369,7 @@ def apply( Default is None and will use an unoptimized pairlist without PBC """ import jax.numpy as jnp + from loguru import logger as log # Compute initial energy initial_energy = thermodynamic_state.get_reduced_potential( @@ -553,6 +518,7 @@ def __init__( ------- None """ + from loguru import logger as log super().__init__(nr_of_moves=nr_of_moves, seed=seed) self.displacement_sigma = displacement_sigma @@ -586,17 +552,13 @@ def displace_positions( self.key, subkey = jrandom.split(self.key) nr_of_atoms = positions.shape[0] - # log.debug(f"Number of atoms is {nr_of_atoms}.") unitless_displacement_sigma = displacement_sigma.value_in_unit_system( unit.md_unit_system ) - # log.debug(f"Displacement sigma is {unitless_displacement_sigma}.") displacement_vector = ( jrandom.normal(subkey, shape=(nr_of_atoms, 3)) * 0.1 ) # NOTE: convert from Angstrom to nm scaled_displacement_vector = displacement_vector * unitless_displacement_sigma - # log.debug(f"Unscaled Displacement vector is {displacement_vector}.") - # log.debug(f"Scaled Displacement vector is {scaled_displacement_vector}.") updated_position = positions + scaled_displacement_vector return updated_position @@ -613,6 +575,7 @@ def run( progress_bar=True, ): from tqdm import tqdm + from loguru import logger as log for trials in ( tqdm(range(self.nr_of_moves)) if progress_bar else range(self.nr_of_moves) From 50d38df7a6c5e0062059fc658f8e55d3c4ae73b8 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 16:23:14 +0100 Subject: [PATCH 37/55] Remove unused imports and add missing imports --- chiron/integrators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chiron/integrators.py b/chiron/integrators.py index 1870867..cc08ae6 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -2,11 +2,8 @@ import jax.numpy as jnp from jax import random -from tqdm import tqdm from openmm import unit from .states import SamplerState, ThermodynamicState -from typing import Dict -from loguru import logger as log from .reporters import SimulationReporter from typing import Optional @@ -45,6 +42,7 @@ def __init__( reporter : SimulationReporter, optional Reporter object for saving the simulation data. Default is None. """ + from loguru import logger as log self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA log.info(f"stepsize = {stepsize}") @@ -101,6 +99,8 @@ def run( """ from .utils import get_list_of_mass + from tqdm import tqdm + from loguru import logger as log potential = thermodynamic_state.potential From fc4512147eba41899c42be1e126c747cfe49cd3d Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 16:28:02 +0100 Subject: [PATCH 38/55] move import in local scope where possible --- chiron/minimze.py | 16 ++++++++++------ chiron/neighbors.py | 4 ++-- chiron/potential.py | 17 ++++++++++------- chiron/states.py | 3 ++- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/chiron/minimze.py b/chiron/minimze.py index 0547e81..ea8fccf 100644 --- a/chiron/minimze.py +++ b/chiron/minimze.py @@ -1,16 +1,16 @@ -import jax -import jax.numpy as jnp -from jaxopt import GradientDescent -from loguru import logger as log +from typing import Callable +from jax import numpy as jnp -def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000): +def minimize_energy( + coordinates: jnp.array, potential_fn: Callable, nbr_list=None, maxiter: int = 1000 +): """ Minimize the potential energy of a system using JAXopt. Parameters ---------- - coordinates : jnp.ndarray + coordinates : jnp.array The initial coordinates of the system. potential_fn : callable The potential energy function of the system, which takes coordinates as input. @@ -24,6 +24,7 @@ def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000): jnp.ndarray The optimized coordinates. """ + from loguru import logger as log def objective_fn(x): if nbr_list is not None: @@ -33,6 +34,9 @@ def objective_fn(x): log.debug("Using NO neighbor list") return potential_fn(x) + from jaxopt import GradientDescent + import jax + optimizer = GradientDescent( fun=jax.value_and_grad(objective_fn), value_and_grad=True, maxiter=maxiter ) diff --git a/chiron/neighbors.py b/chiron/neighbors.py index 69a9502..9941e44 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -3,9 +3,8 @@ import jax import jax.numpy as jnp from functools import partial -from typing import Tuple, Optional, Union +from typing import Tuple, Union from .states import SamplerState -from loguru import logger as log from openmm import unit @@ -580,6 +579,7 @@ def build( ) self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) + from loguru import logger as log while jnp.any(self.n_neighbors == self.n_max_neighbors).block_until_ready(): log.debug( diff --git a/chiron/potential.py b/chiron/potential.py index 14676dd..6d9b415 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -1,13 +1,13 @@ import jax import jax.numpy as jnp -from loguru import logger as log from openmm import unit from openmm.app import Topology -from typing import Optional class NeuralNetworkPotential: def __init__(self, model, **kwargs): + from loguru import logger as log + if model is None: log.warning("No model provided, using default model") else: @@ -166,6 +166,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): """ # Compute the pair distances and displacement vectors + from loguru import logger as log if nbr_list is None: log.debug( @@ -319,11 +320,13 @@ def __init__( f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}" ) - log.info("Initializing HarmonicOscillatorPotential") - log.info(f"k = {k}") - log.info(f"x0 = {x0}") - log.info(f"U0 = {U0}") - log.info("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0") + from loguru import logger as log + + log.debug("Initializing HarmonicOscillatorPotential") + log.debug(f"k = {k}") + log.debug(f"x0 = {x0}") + log.debug(f"U0 = {U0}") + log.debug("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0") self.k = jnp.array( k.value_in_unit_system(unit.md_unit_system) ) # spring constant diff --git a/chiron/states.py b/chiron/states.py index 8ba5e52..99ee8f9 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -1,7 +1,6 @@ from openmm import unit from typing import List, Optional, Union from jax import numpy as jnp -from loguru import logger as log from .potential import NeuralNetworkPotential @@ -204,6 +203,7 @@ def check_variables(self) -> None: def _check_completness(self): # check which variables are set set_variables = self.check_variables() + from loguru import logger as log if len(set_variables) == 0: log.info("No variables are set.") @@ -288,6 +288,7 @@ def calculate_reduced_potential_at_states( """ import numpy as np + from loguru import logger as log reduced_potentials = np.zeros(len(thermodynamic_states)) for state_idx, state in enumerate(thermodynamic_states): From c50ff8c4b2ae70470d27f328e9bbc6a51eccc4f2 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 18:13:45 +0100 Subject: [PATCH 39/55] Add save_traj_in_memory option to LangevinIntegrator and update MultiStateSampler constructor signature --- chiron/integrators.py | 2 ++ chiron/multistate.py | 50 +++++++++++++++++-------------------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/chiron/integrators.py b/chiron/integrators.py index cc08ae6..b9120fd 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -41,6 +41,8 @@ def __init__( Frequency of saving the simulation data. Default is 100. reporter : SimulationReporter, optional Reporter object for saving the simulation data. Default is None. + save_traj_in_memory : bool + Whether to save the trajectory in memory. For debugging purposes only. """ from loguru import logger as log diff --git a/chiron/multistate.py b/chiron/multistate.py index 9e55fd5..13b7798 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,8 +1,9 @@ -from typing import List, Optional +from typing import List, Optional, Union from chiron.states import SamplerState, ThermodynamicState from chiron.neighbors import NeighborListNsqrd from openmm import unit import numpy as np +from chiron.mcmc import MCMCMove class MultiStateSampler: @@ -15,29 +16,29 @@ class MultiStateSampler: If instantiated on its own, the thermodynamic state indices associated with each state are specified and replica mixing does not change any thermodynamic states, meaning that each replica remains in its original thermodynamic state. - - Parameters - ---------- - mcmc_moves : MCMCMove or list of MCMCMove, optional - The MCMCMove used to propagate the thermodynamic states. If a list of MCMCMoves, - they will be assigned to the correspondent thermodynamic state on - creation. If None is provided, Langevin dynamics with 2fm timestep, 5.0/ps collision rate, - and 500 steps per iteration will be used. - - Attributes - ---------- - n_replicas - n_states - mcmc_moves - sampler_states - is_completed """ def __init__( self, - mcmc_moves=None, + mcmc_moves=Union[MCMCMove, List[MCMCMove]], online_analysis_interval=5, ): + """ + Parameters + ---------- + mcmc_moves : MCMCMove or list of MCMCMove + The MCMCMove used to propagate the thermodynamic states. If a list of MCMCMoves, + they will be assigned to the correspondent thermodynamic state on + creation. + + Attributes + ---------- + n_replicas + n_states + mcmc_moves + sampler_states + is_completed + """ import copy from openmm import unit @@ -61,18 +62,7 @@ def __init__( self.free_energy_estimator = None self._traj = None - # Handling default propagator. - if mcmc_moves is None: - from .mcmc import LangevinDynamicsMove - - # This will be converted to a list in create(). - self._mcmc_moves = LangevinDynamicsMove( - timestep=2.0 * unit.femtosecond, - collision_rate=5.0 / unit.picosecond, - n_steps=500, - ) - else: - self._mcmc_moves = copy.deepcopy(mcmc_moves) + self._mcmc_moves = copy.deepcopy(mcmc_moves) self._last_mbar_f_k = None self._last_err_free_energy = None From c3bd752a3b4a9fb87e31412d289bdfdfd2357abe Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 18:28:46 +0100 Subject: [PATCH 40/55] Add MBAREstimator class and MultistateReporter class --- chiron/analyzer.py | 3 +++ chiron/reporters.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 chiron/analyzer.py diff --git a/chiron/analyzer.py b/chiron/analyzer.py new file mode 100644 index 0000000..eef0e25 --- /dev/null +++ b/chiron/analyzer.py @@ -0,0 +1,3 @@ +class MBAREstimator: + def __init__(self) -> None: + pass diff --git a/chiron/reporters.py b/chiron/reporters.py index b2fae35..6230a44 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -113,3 +113,23 @@ def get_mdtraj_trajectory(self): unitcell_lengths=self.get_property("box_vectors"), unitcell_angles=self.get_property("box_angles"), ) + + + +class MultistateReporter: + + def __init__(self, path_to_dir:str) -> None: + self.path_to_dir = path_to_dir + + def _write_trajectories(): + pass + + def _write_energies(): + pass + + def _write_states(): + pass + + + + \ No newline at end of file From 84a85be1be7a7716608700686608e9a711680bf0 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 18:46:56 +0100 Subject: [PATCH 41/55] Add MBAREstimator class for performing mbar analysis --- chiron/analysis.py | 18 ++++++++++++++++++ chiron/analyzer.py | 3 --- 2 files changed, 18 insertions(+), 3 deletions(-) create mode 100644 chiron/analysis.py delete mode 100644 chiron/analyzer.py diff --git a/chiron/analysis.py b/chiron/analysis.py new file mode 100644 index 0000000..06c54dc --- /dev/null +++ b/chiron/analysis.py @@ -0,0 +1,18 @@ +import numpy as np + + +class MBAREstimator: + def __init__(self, N_u: int) -> None: + self.mbar_f_k = np.zeros(len(N_u)) + + def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): + """ + Perform mbar analysis + """ + from pymbar import MBAR + from loguru import logger as log + + log.debug(f"{N_k=}") + mbar = MBAR(u_kn=u_kn, N_k=N_k) + log.debug(mbar.f_k) + self.mbar_f_k = mbar.f_k diff --git a/chiron/analyzer.py b/chiron/analyzer.py deleted file mode 100644 index eef0e25..0000000 --- a/chiron/analyzer.py +++ /dev/null @@ -1,3 +0,0 @@ -class MBAREstimator: - def __init__(self) -> None: - pass From 6ce2d70d79749da4cb7cf79035eb46cd9c6ba78d Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 18:47:02 +0100 Subject: [PATCH 42/55] Remove online analysis and add offline estimator --- chiron/multistate.py | 46 +++++++++++--------------------------------- 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index 13b7798..40f5c0e 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -57,7 +57,6 @@ def __init__( self._n_proposed_matrix = None self._reporter = None self._metadata = None - self._online_analysis_interval = online_analysis_interval self._timing_data = dict() self.free_energy_estimator = None self._traj = None @@ -67,6 +66,11 @@ def __init__( self._last_mbar_f_k = None self._last_err_free_energy = None + self._online_estimator = None + + from chiron.analysis import MBAREstimator + self._offline_estimator = MBAREstimator() + @property def n_states(self): """The integer number of thermodynamic states (read-only).""" @@ -553,16 +557,14 @@ def run(self, n_iterations: int = 10) -> None: ] = self._energy_thermodynamic_states # TODO report energies - iteration_limit = n_iterations - # start the sampling loop - log.debug(f"{iteration_limit=}") - while not self._is_completed(iteration_limit): + log.debug(f"{n_iterations=}") + while not self._is_completed(n_iterations): # Increment iteration counter. self._iteration += 1 log.info("-" * 80) - log.info(f"Iteration {self._iteration}/{iteration_limit}") + log.info(f"Iteration {self._iteration}/{n_iterations}") log.info("-" * 80) # Update thermodynamic states @@ -598,34 +600,8 @@ def _update_analysis(self): """Update analysis of free energies""" from loguru import logger as log - if self._online_analysis_interval is None: - log.debug("No online analysis requested") - # Perform no analysis and exit function - return - # Perform offline free energy estimate if requested - if self.free_energy_estimator == "mbar": - self._last_err_free_energy = self._mbar_analysis() - - return - - def _mbar_analysis(self): - """ - Perform mbar analysis - """ - from pymbar import MBAR - from loguru import logger as log + if self._offline_estimator: + log.debug("Performing offline free energy estimate...") + self._offline_estimator.initialize(self._energy_thermodynamic_states_for_each_iteration_in_run) - self._last_mbar_f_k_offline = np.zeros(len(self._thermodynamic_states)) - - log.debug( - f"{self._energy_thermodynamic_states_for_each_iteration_in_run.shape=}" - ) - log.debug(f"{self.n_states=}") - u_kn = self._energy_thermodynamic_states_for_each_iteration_in_run - log.debug(f"{self._iteration=}") - N_k = [self._iteration] * self.n_states - log.debug(f"{N_k=}") - mbar = MBAR(u_kn=u_kn, N_k=N_k) - log.debug(mbar.f_k) - self._last_mbar_f_k_offline = mbar.f_k From 1006be5eb227afff3e3080d5b13ae51640b22e07 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 18:54:20 +0100 Subject: [PATCH 43/55] Fix initialization of MBAREstimator and update offline free energy estimation --- chiron/analysis.py | 2 +- chiron/multistate.py | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/chiron/analysis.py b/chiron/analysis.py index 06c54dc..6550901 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -3,7 +3,7 @@ class MBAREstimator: def __init__(self, N_u: int) -> None: - self.mbar_f_k = np.zeros(len(N_u)) + self.mbar_f_k = np.zeros(N_u) def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): """ diff --git a/chiron/multistate.py b/chiron/multistate.py index 40f5c0e..713f1d0 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -66,11 +66,6 @@ def __init__( self._last_mbar_f_k = None self._last_err_free_energy = None - self._online_estimator = None - - from chiron.analysis import MBAREstimator - self._offline_estimator = MBAREstimator() - @property def n_states(self): """The integer number of thermodynamic states (read-only).""" @@ -190,10 +185,17 @@ def create( """ # TODO: initialize reporter here # TODO: consider unsampled thermodynamic states for reweighting schemes - self.free_energy_estimator = "mbar" + self._online_estimator = None + + from chiron.analysis import MBAREstimator + + n_thermodynamic_states = len(thermodynamic_states) + n_sampler_states = len(sampler_states) + + self._offline_estimator = MBAREstimator(N_u=n_thermodynamic_states) # Ensure the number of thermodynamic states matches the number of sampler states - if len(thermodynamic_states) != len(sampler_states): + if n_thermodynamic_states != n_sampler_states: raise RuntimeError( "Number of thermodynamic states and sampler states must be equal." ) @@ -603,5 +605,9 @@ def _update_analysis(self): # Perform offline free energy estimate if requested if self._offline_estimator: log.debug("Performing offline free energy estimate...") - self._offline_estimator.initialize(self._energy_thermodynamic_states_for_each_iteration_in_run) - + N_k = [self._iteration] * self.n_states + log.debug(f"{N_k=}") + self._offline_estimator.initialize( + u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run, + N_k=N_k, + ) From 0d7c944c776b278d7154352d1600c4bd8e8eca4f Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jan 2024 10:18:16 +0100 Subject: [PATCH 44/55] Add MBAR class and update free energy estimators --- chiron/analysis.py | 19 ++++++++++++++++--- chiron/multistate.py | 19 ++++++++++++++++++- chiron/tests/test_multistate.py | 6 ++---- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/chiron/analysis.py b/chiron/analysis.py index 6550901..9608a81 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -4,6 +4,7 @@ class MBAREstimator: def __init__(self, N_u: int) -> None: self.mbar_f_k = np.zeros(N_u) + self.mbar = None def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): """ @@ -13,6 +14,18 @@ def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): from loguru import logger as log log.debug(f"{N_k=}") - mbar = MBAR(u_kn=u_kn, N_k=N_k) - log.debug(mbar.f_k) - self.mbar_f_k = mbar.f_k + self.mbar = MBAR(u_kn=u_kn, N_k=N_k) + + @property + def f_k(self): + from loguru import logger as log + + log.debug(self.mbar.f_k) + return self.mbar.f_k + + def get_free_energy_difference(self): + from loguru import logger as log + + log.debug(self.mbar.f_k[-1]) + self.f_k = self.mbar.f_k + return self.mbar_f_k[-1] diff --git a/chiron/multistate.py b/chiron/multistate.py index 713f1d0..6a0ba75 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -606,8 +606,25 @@ def _update_analysis(self): if self._offline_estimator: log.debug("Performing offline free energy estimate...") N_k = [self._iteration] * self.n_states - log.debug(f"{N_k=}") self._offline_estimator.initialize( u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run, N_k=N_k, ) + elif self._online_estimator: + log.debug("Performing online free energy estimate...") + self._online_estimator.update( + u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run[ + :, :, self._iteration + ] + ) + else: + raise RuntimeError("No free energy estimator provided.") + + @property + def f_k(self): + if self._offline_estimator: + return self._offline_estimator.f_k + elif self._online_estimator: + return self._online_estimator.f_k + else: + raise RuntimeError("No free energy estimator found.") diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index ce74e1d..c316046 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -210,8 +210,6 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): # check that the free energies are correct print(ho_sampler.analytical_f_i) print(ho_sampler.delta_f_ij_analytical) - print(ho_sampler._last_mbar_f_k_offline) + print(ho_sampler.f_k) - assert np.allclose( - ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline, atol=0.1 - ) + assert np.allclose(ho_sampler.delta_f_ij_analytical[0], ho_sampler.f_k, atol=0.1) From f3340a3b3a9037aca2a6bd8a5e1e1efcdec3986e Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jan 2024 10:27:41 +0100 Subject: [PATCH 45/55] Refactor MBAREstimator class in analysis.py --- chiron/analysis.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/chiron/analysis.py b/chiron/analysis.py index 9608a81..47fdb86 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -2,13 +2,27 @@ class MBAREstimator: - def __init__(self, N_u: int) -> None: - self.mbar_f_k = np.zeros(N_u) + def __init__(self) -> None: + """ + Initialize the MBAR analysis class. + + Returns: + - None + """ + self.mbar_f_k = None self.mbar = None def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): """ - Perform mbar analysis + Initialize the analysis object. + + Parameters + ---------- + u_kn: np.ndarray + Array of dimensionless reduced potentials for each state. + N_k: np.ndarray + Array of number of samples for each state. + """ from pymbar import MBAR from loguru import logger as log @@ -18,12 +32,27 @@ def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): @property def f_k(self): + """ + Free energy for each state. + + Returns + ------- + mbar.f_k. + """ + from loguru import logger as log log.debug(self.mbar.f_k) return self.mbar.f_k def get_free_energy_difference(self): + """ + Calculate the free energy difference between the endstates. + + Returns + ------- + float + """ from loguru import logger as log log.debug(self.mbar.f_k[-1]) From ab0d114ccc03070e350f46ab4e19eee2b7eaaf4e Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Thu, 18 Jan 2024 14:42:14 +0100 Subject: [PATCH 46/55] Adding Simulation and MultistateSimulation reporter (#18) * Add MBAREstimator class and MultistateReporter class * Add MBAR class and update free energy estimators * Update LangevinIntegrator class in integrators.py * Update MCMCMove and LangevinDynamicsMove constructors * Fix MBAREstimator initialization in MultiStateSampler * Refactor reporters and tests * Add new reporters and update tests * Wrap and rebuild neighborlist in LangevinIntegrator * Refactor code to transpose u_kn array in MultiStateSampler and _SimulationReporter * Refactor _SimulationReporter class to improve code readability and maintainability * Refactor code and add random seed functionality * Fix reporter visibility and add test for multistate reporter --- Examples/LJ_langevin.py | 6 +- Examples/LJ_mcmove.py | 6 +- .../data/{test_md.h5 => langevin_reporter.h5} | Bin chiron/integrators.py | 131 ++++-- chiron/mcmc.py | 163 ++++--- chiron/multistate.py | 444 ++++++++++-------- chiron/reporters.py | 411 +++++++++++++--- chiron/states.py | 15 + chiron/tests/data/langevin_reporter.h5 | Bin 0 -> 12872 bytes chiron/tests/test_convergence_tests.py | 12 +- chiron/tests/test_integrators.py | 18 +- chiron/tests/test_mcmc.py | 79 ++-- chiron/tests/test_minization.py | 11 +- chiron/tests/test_multistate.py | 27 +- chiron/tests/test_pairs.py | 28 +- chiron/tests/test_potential.py | 94 ++-- chiron/tests/test_states.py | 35 +- chiron/tests/test_testsystems.py | 4 + chiron/tests/test_utils.py | 135 ++++-- chiron/utils.py | 35 ++ 20 files changed, 1151 insertions(+), 503 deletions(-) rename chiron/data/{test_md.h5 => langevin_reporter.h5} (100%) create mode 100644 chiron/tests/data/langevin_reporter.h5 diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py index e2d66ca..d769b1a 100644 --- a/Examples/LJ_langevin.py +++ b/Examples/LJ_langevin.py @@ -43,7 +43,7 @@ # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import SimulationReporter +from chiron.reporters import _SimulationReporter # initialize a reporter to save the simulation data filename = "test_lj.h5" @@ -51,12 +51,12 @@ if os.path.isfile(filename): os.remove(filename) -reporter = SimulationReporter("test_lj.h5", lj_fluid.topology, 1) +reporter = _SimulationReporter("test_lj.h5", lj_fluid.topology, 1) from chiron.integrators import LangevinIntegrator # initialize the Langevin integrator -integrator = LangevinIntegrator(reporter=reporter, save_frequency=100) +integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list)) integrator.run( diff --git a/Examples/LJ_mcmove.py b/Examples/LJ_mcmove.py index 0ed407e..bc673f6 100644 --- a/Examples/LJ_mcmove.py +++ b/Examples/LJ_mcmove.py @@ -45,7 +45,7 @@ # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import SimulationReporter +from chiron.reporters import _SimulationReporter # initialize a reporter to save the simulation data filename = "test_lj.h5" @@ -53,7 +53,7 @@ if os.path.isfile(filename): os.remove(filename) -reporter = SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1) +reporter = _SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1) from chiron.mcmc import MetropolisDisplacementMove @@ -61,7 +61,7 @@ seed=1234, displacement_sigma=0.01 * unit.nanometer, nr_of_moves=1000, - simulation_reporter=reporter, + reporter=reporter, ) mc_move.run(sampler_state, thermodynamic_state, nbr_list, True) diff --git a/chiron/data/test_md.h5 b/chiron/data/langevin_reporter.h5 similarity index 100% rename from chiron/data/test_md.h5 rename to chiron/data/langevin_reporter.h5 diff --git a/chiron/integrators.py b/chiron/integrators.py index b9120fd..0d77452 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -4,8 +4,10 @@ from jax import random from openmm import unit from .states import SamplerState, ThermodynamicState -from .reporters import SimulationReporter +from .reporters import LangevinDynamicsReporter from typing import Optional +from .potential import NeuralNetworkPotential +from .neighbors import PairsBase class LangevinIntegrator: @@ -24,8 +26,8 @@ def __init__( self, stepsize=1.0 * unit.femtoseconds, collision_rate=1.0 / unit.picoseconds, - save_frequency: int = 100, - reporter: Optional[SimulationReporter] = None, + report_frequency: int = 100, + reporter: Optional[LangevinDynamicsReporter] = None, save_traj_in_memory: bool = False, ) -> None: """ @@ -37,29 +39,33 @@ def __init__( Time step of integration with units of time. Default is 1.0 * unit.femtoseconds. collision_rate : unit.Quantity, optional Collision rate for the Langevin dynamics, with units 1/time. Default is 1.0 / unit.picoseconds. - save_frequency : int, optional + report_frequency : int, optional Frequency of saving the simulation data. Default is 100. reporter : SimulationReporter, optional Reporter object for saving the simulation data. Default is None. - save_traj_in_memory : bool - Whether to save the trajectory in memory. For debugging purposes only. + save_traj_in_memory: bool + Flag indicating whether to save the trajectory in memory. + Default is False. NOTE: Only for debugging purposes. """ from loguru import logger as log self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA log.info(f"stepsize = {stepsize}") log.info(f"collision_rate = {collision_rate}") - log.info(f"save_frequency = {save_frequency}") + log.info(f"report_frequency = {report_frequency}") self.stepsize = stepsize self.collision_rate = collision_rate - if reporter is not None: - log.info(f"Using reporter {reporter} saving to {reporter.filename}") + if reporter: + log.info( + f"Using reporter {reporter} saving trajectory to {reporter.xtc_file_path}" + ) + log.info(f"and logging to {reporter.log_file_path}") self.reporter = reporter - self.save_frequency = save_frequency + self.report_frequency = report_frequency + self.velocities = None self.save_traj_in_memory = save_traj_in_memory self.traj = [] - self.velocities = None def set_velocities(self, vel: unit.Quantity) -> None: """ @@ -77,8 +83,7 @@ def run( sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, n_steps: int = 5_000, - key=random.PRNGKey(0), - nbr_list=None, + nbr_list: Optional[PairsBase] = None, progress_bar=False, ): """ @@ -92,9 +97,7 @@ def run( The thermodynamic state of the system, including temperature and potential. n_steps : int, optional Number of simulation steps to perform. - key : jax.random.PRNGKey, optional - Random key for generating random numbers. - nbr_list : NeighborListNsqrd, optional + nbr_list : PairBase, optional Neighbor list for the system. progress_bar : bool, optional Flag indicating whether to display a progress bar during integration. @@ -116,8 +119,11 @@ def run( log.debug("Running Langevin dynamics") log.debug(f"n_steps = {n_steps}") log.debug(f"temperature = {temperature}") - log.debug(f"Using seed: {key}") + # Initialize the random number generator + key = sampler_state.new_PRNG_key + + # Convert to dimensionless quantities kbT_unitless = (self.kB * temperature).value_in_unit_system(unit.md_unit_system) mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[ :, None @@ -127,22 +133,24 @@ def run( collision_rate_unitless = self.collision_rate.value_in_unit_system( unit.md_unit_system ) + a = jnp.exp((-collision_rate_unitless * stepsize_unitless)) + b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless)) # Initialize velocities if self.velocities is None: v0 = sigma_v * random.normal(key, x0.shape) else: v0 = self.velocities.value_in_unit_system(unit.md_unit_system) - # Convert to dimensionless quantities - a = jnp.exp((-collision_rate_unitless * stepsize_unitless)) - b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless)) x = x0 v = v0 + if nbr_list is not None: nbr_list.build_from_state(sampler_state) F = potential.compute_force(x, nbr_list) + + # propagation loop for step in tqdm(range(n_steps)) if self.progress_bar else range(n_steps): key, subkey = random.split(key) # v @@ -151,46 +159,79 @@ def run( x += (stepsize_unitless * 0.5) * v if nbr_list is not None: - x = nbr_list.space.wrap(x) - # check if we need to rebuild the neighborlist after moving the particles - if nbr_list.check(x): - nbr_list.build(x, self.box_vectors) + x = self._wrap_and_rebuild_neighborlist(x, nbr_list) # o random_noise_v = random.normal(subkey, x.shape) v = (a * v) + (b * sigma_v * random_noise_v) x += (stepsize_unitless * 0.5) * v if nbr_list is not None: - x = nbr_list.space.wrap(x) - # check if we need to rebuild the neighborlist after moving the particles - if nbr_list.check(x): - nbr_list.build(x, self.box_vectors) + x = self._wrap_and_rebuild_neighborlist(x, nbr_list) F = potential.compute_force(x, nbr_list) # v v += (stepsize_unitless * 0.5) * F / mass_unitless - if step % self.save_frequency == 0: - # log.debug(f"Saving at step {step}") - # check if reporter is attribute of the class - # log.debug(f"step {step} energy {potential.compute_energy(x, nbr_list)}") - # log.debug(f"step {step} force {F}") - + if step % self.report_frequency == 0: if hasattr(self, "reporter") and self.reporter is not None: - d = { - "traj": x, - "energy": potential.compute_energy(x, nbr_list), - "step": step, - } - if nbr_list is not None: - d["box_vectors"] = nbr_list.space.box_vectors - - # log.debug(d) - self.reporter.report(d) + self._report(x, potential, nbr_list, step) + if self.save_traj_in_memory: self.traj.append(x) log.debug("Finished running Langevin dynamics") # save the final state of the simulation in the sampler_state object sampler_state.x0 = x - # self.reporter.close() + sampler_state.v0 = v + + def _wrap_and_rebuild_neighborlist(self, x: jnp.array, nbr_list: PairsBase): + """ + Wrap the coordinates and rebuild the neighborlist if necessary. + + Parameters + ---------- + x: jnp.array + The coordinates of the particles. + nbr_list: PairsBsse + The neighborlist object. + """ + + x = nbr_list.space.wrap(x) + # check if we need to rebuild the neighborlist after moving the particles + if nbr_list.check(x): + nbr_list.build(x, self.box_vectors) + return x + + def _report( + self, + x: jnp.array, + potential: NeuralNetworkPotential, + nbr_list: PairsBase, + step: int, + ): + """ + Reports the trajectory, energy, step, and box vectors (if available) to the reporter. + + Parameters + ---------- + x : jnp.array + current coordinate set + potential: NeuralNetworkPotential + potential used to compute the energy and force + nbr_list: PairsBase + The neighbor list + step: int + The current time step. + + Returns: + None + """ + d = { + "positions": x, + "potential_energy": potential.compute_energy(x, nbr_list), + "step": step, + } + if nbr_list is not None: + d["box_vectors"] = nbr_list.space.box_vectors + + self.reporter.report(d) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 2cb7d75..8063012 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -2,10 +2,16 @@ from openmm import unit from typing import Tuple, List, Optional import jax.numpy as jnp -from chiron.reporters import SimulationReporter +from chiron.reporters import LangevinDynamicsReporter, _SimulationReporter + class MCMCMove: - def __init__(self, nr_of_moves: int, seed: int): + def __init__( + self, + nr_of_moves: int, + reporter: Optional[_SimulationReporter] = None, + report_frequency: Optional[int] = 100, + ): """ Initialize a move within the molecular system. @@ -13,13 +19,22 @@ def __init__(self, nr_of_moves: int, seed: int): ---------- nr_of_moves : int Number of moves to be applied. - seed : int - Seed for random number generation. + reporter : _SimulationReporter, optional + Reporter object for saving the simulation data. + Default is None. + report_frequency : int, optional """ - import jax.random as jrandom self.nr_of_moves = nr_of_moves - self.key = jrandom.PRNGKey(seed) # 'seed' is an integer seed value + self.reporter = reporter + self.report_frequency = report_frequency + from loguru import logger as log + + if self.reporter is not None: + log.info( + f"Using reporter {self.reporter} saving to {self.reporter.workdir}" + ) + assert self.report_frequency is not None class LangevinDynamicsMove(MCMCMove): @@ -27,9 +42,9 @@ def __init__( self, stepsize=1.0 * unit.femtoseconds, collision_rate=1.0 / unit.picoseconds, - simulation_reporter: Optional[SimulationReporter] = None, + reporter: Optional[LangevinDynamicsReporter] = None, + report_frequency: int = 100, nr_of_steps=1_000, - seed: int = 1234, save_traj_in_memory: bool = False, ): """ @@ -41,13 +56,27 @@ def __init__( Time step size for the integration. collision_rate : unit.Quantity Collision rate for the Langevin dynamics. - nr_of_steps : int + reporter : LangevinDynamicsReporter, optional + Reporter object for saving the simulation data. + Default is None. + report_frequency : int + Frequency of saving the simulation data. + Default is 100. + nr_of_steps : int, optional Number of steps to run the integrator for. + Default is 1_000. + save_traj_in_memory: bool + Flag indicating whether to save the trajectory in memory. + Default is False. NOTE: Only for debugging purposes. """ - super().__init__(nr_of_steps, seed) + super().__init__( + nr_of_moves=nr_of_steps, + reporter=reporter, + report_frequency=report_frequency, + ) + self.stepsize = stepsize self.collision_rate = collision_rate - self.simulation_reporter = simulation_reporter self.save_traj_in_memory = save_traj_in_memory self.traj = [] from chiron.integrators import LangevinIntegrator @@ -55,7 +84,8 @@ def __init__( self.integrator = LangevinIntegrator( stepsize=self.stepsize, collision_rate=self.collision_rate, - reporter=self.simulation_reporter, + report_frequency=report_frequency, + reporter=reporter, save_traj_in_memory=save_traj_in_memory, ) @@ -67,8 +97,12 @@ def run( """ Run the integrator to perform molecular dynamics simulation. - Args: - state_variables (StateVariablesCollection): State variables of the system. + Parameters + ---------- + sampler_state : SamplerState + The sampler state to run the integrator on. + thermodynamic_state : ThermodynamicState + The thermodynamic state to run the integrator on. """ assert isinstance( @@ -82,38 +116,18 @@ def run( thermodynamic_state=thermodynamic_state, sampler_state=sampler_state, n_steps=self.nr_of_moves, - key=self.key, ) + if self.save_traj_in_memory: self.traj.append(self.integrator.traj) self.integrator.traj = [] class MCMove(MCMCMove): - def __init__(self, nr_of_moves: int, seed: int) -> None: - super().__init__(nr_of_moves, seed) - - def _check_state_compatiblity( - self, - old_state: SamplerState, - new_state: SamplerState, - ): - """ - Check if the states are compatible. - - Parameters - ---------- - old_state : StateVariablesCollection - The state of the system before the move. - new_state : StateVariablesCollection - The state of the system after the move. - - Raises - ------ - ValueError - If the states are not compatible. - """ - pass + def __init__( + self, nr_of_moves: int, reporter: Optional[_SimulationReporter] + ) -> None: + super().__init__(nr_of_moves, reporter=reporter) def apply_move(self): """ @@ -239,9 +253,9 @@ def _validate_sequence(self): raise ValueError(f"Move {move_name} in the sequence is not available.") -class MCMCSampler(object): +class MCMCSampler: """ - Basic Markov chain Monte Carlo Gibbs sampler. + Basic Markov chain Monte Carlo sampler. Parameters ---------- @@ -291,9 +305,13 @@ def run(self, n_iterations: int = 1): log.info("Finished running MCMC sampler") log.debug("Closing reporter") for _, move in self.move.move_schedule: - if move.simulation_reporter is not None: - move.simulation_reporter.close() - log.debug(f"Closed reporter {move.simulation_reporter.filename}") + if move.reporter is not None: + move.reporter.flush_buffer() + # TODO: flush reporter + log.debug(f"Closed reporter {move.reporter.log_file_path}") + + +from .neighbors import PairsBase class MetropolizedMove(MCMove): @@ -323,16 +341,18 @@ class MetropolizedMove(MCMove): def __init__( self, - seed: int = 1234, atom_subset: Optional[List[int]] = None, nr_of_moves: int = 100, + reporter: Optional[_SimulationReporter] = None, + report_frequency: int = 1, ): self.n_accepted = 0 self.n_proposed = 0 self.atom_subset = atom_subset - super().__init__(nr_of_moves=nr_of_moves, seed=seed) + super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) from loguru import logger as log + self.report_frequency = report_frequency log.debug(f"Atom subset is {atom_subset}.") @property @@ -349,8 +369,7 @@ def apply( self, thermodynamic_state: ThermodynamicState, sampler_state: SamplerState, - reporter: SimulationReporter, - nbr_list=None, + nbr_list=Optional[PairsBase], ): """Apply a metropolized move to the sampler state. @@ -362,8 +381,6 @@ def apply( The thermodynamic state to use to apply the move. sampler_state : SamplerState The initial sampler state to apply the move to. This is modified. - reporter: SimulationReporter - The reporter to write the data to. nbr_list: Neighbor List or Pair List routine, The routine to use to calculate the interacting atoms. Default is None and will use an unoptimized pairlist without PBC @@ -376,9 +393,8 @@ def apply( sampler_state, nbr_list ) # NOTE: in kT log.debug(f"Initial energy is {initial_energy} kT.") - # Store initial positions of the atoms that are moved. - # We'll use this also to recover in case the move is rejected. + # Store initial positions of the atoms that are moved. x0 = sampler_state.x0 atom_subset = self.atom_subset if atom_subset is None: @@ -420,15 +436,14 @@ def apply( log.debug( f"Move accepted. Energy change: {delta_energy:.3f} kT. Number of accepted moves: {self.n_accepted}." ) - reporter.report( - { - "energy": thermodynamic_state.kT_to_kJ_per_mol( - proposed_energy - ).value_in_unit_system(unit.md_unit_system), - "step": self.n_proposed, - "traj": sampler_state.x0, - } - ) + if self.n_proposed % self.report_frequency == 0: + self.reporter.report( + { + "energy": proposed_energy, # in kT + "step": self.n_proposed, + "traj": sampler_state.x0, + } + ) else: # Restore original positions. if atom_subset is None: @@ -493,11 +508,10 @@ class MetropolisDisplacementMove(MetropolizedMove): def __init__( self, - seed: int = 1234, displacement_sigma=1.0 * unit.nanometer, nr_of_moves: int = 100, atom_subset: Optional[List[int]] = None, - simulation_reporter: Optional[SimulationReporter] = None, + reporter: Optional[LangevinDynamicsReporter] = None, ): """ Initialize the MCMC class. @@ -512,22 +526,16 @@ def __init__( The number of moves to perform. Default is 100. atom_subset : list of int, optional A subset of atom indices to consider for the moves. Default is None. - simulation_reporter : SimulationReporter, optional + reporter : SimulationReporter, optional The reporter to write the data to. Default is None. Returns ------- None """ - from loguru import logger as log - - super().__init__(nr_of_moves=nr_of_moves, seed=seed) + super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) self.displacement_sigma = displacement_sigma self.atom_subset = atom_subset - self.simulation_reporter = simulation_reporter - if self.simulation_reporter is not None: - log.info( - f"Using reporter {self.simulation_reporter} saving to {self.simulation_reporter.filename}" - ) + self.key = None def displace_positions( self, positions: jnp.array, displacement_sigma=1.0 * unit.nanometer @@ -576,17 +584,18 @@ def run( ): from tqdm import tqdm from loguru import logger as log + from jax import random + + self.key = sampler_state.new_PRNG_key for trials in ( tqdm(range(self.nr_of_moves)) if progress_bar else range(self.nr_of_moves) ): - self.apply( - thermodynamic_state, sampler_state, self.simulation_reporter, nbr_list - ) + self.apply(thermodynamic_state, sampler_state, nbr_list) if trials % 100 == 0: log.debug(f"Acceptance rate: {self.n_accepted / self.n_proposed}") - if self.simulation_reporter is not None: - self.simulation_reporter.report( + if self.reporter is not None: + self.reporter.report( { "Acceptance rate": self.n_accepted / self.n_proposed, "step": self.n_proposed, diff --git a/chiron/multistate.py b/chiron/multistate.py index 6a0ba75..b9100ec 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -4,43 +4,65 @@ from openmm import unit import numpy as np from chiron.mcmc import MCMCMove +from chiron.reporters import MultistateReporter class MultiStateSampler: """ - Base class for samplers that sample multiple thermodynamic states using - one or more replicas. + A sampler for simulating multiple thermodynamic states using replicas. - This base class provides a general simulation facility for multistate from multiple - thermodynamic states, allowing any set of thermodynamic states to be specified. - If instantiated on its own, the thermodynamic state indices associated with each - state are specified and replica mixing does not change any thermodynamic states, + This class provides a general simulation facility for sampling from multiple + thermodynamic states. It allows specifying any set of thermodynamic states. + If instantiated on its own, the thermodynamic state indices associated with + each state are specified, and replica mixing does not change any thermodynamic states, meaning that each replica remains in its original thermodynamic state. + + Attributes + ---------- + n_states : int + Number of thermodynamic states (read-only). + n_replicas : int + Number of replicas (read-only). + iteration : int + Current iteration of the simulation (read-only). + mcmc_moves : List[MCMCMove] + MCMC moves used to propagate the simulation. + sampler_states : List[SamplerState] + Sampler states list at the current iteration. + is_periodic : bool + True if system is periodic, False if not, None if not initialized. + is_completed : bool + Check if the sampler has reached any stop target criteria (read-only). + + Methods + ------- + create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd) + Creates a new multistate sampler simulation. + minimize(tolerance: unit.Quantity = 1.0 * unit.kilojoules_per_mole / unit.nanometers, max_iterations: int = 1000) + Minimizes all replicas in the sampler. + run(n_iterations: int = 10) + Executes the replica-exchange simulation for a specified number of iterations. + """ def __init__( self, - mcmc_moves=Union[MCMCMove, List[MCMCMove]], - online_analysis_interval=5, + mcmc_moves: Union[MCMCMove, List[MCMCMove]], + reporter: MultistateReporter, ): """ - Parameters - ---------- - mcmc_moves : MCMCMove or list of MCMCMove - The MCMCMove used to propagate the thermodynamic states. If a list of MCMCMoves, - they will be assigned to the correspondent thermodynamic state on - creation. + Initialize the MultiStateSampler. - Attributes + Parameters ---------- - n_replicas - n_states - mcmc_moves - sampler_states - is_completed + mcmc_moves : Union[MCMCMove, List[MCMCMove]] + The MCMCMove or list of MCMCMoves used to propagate the thermodynamic states. + reporter : MultistateReporter + The reporter used to store the simulation data. """ + import copy - from openmm import unit + from chiron.analysis import MBAREstimator # These will be set on initialization. See function # create() for explanation of single variables. @@ -50,33 +72,40 @@ def __init__( self._replica_thermodynamic_states = None self._iteration = None self._energy_thermodynamic_states = None - self._energy_thermodynamic_states_for_each_iteration = None self._neighborhoods = None - self._energy_unsampled_states = None self._n_accepted_matrix = None self._n_proposed_matrix = None - self._reporter = None + self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None - self._timing_data = dict() - self.free_energy_estimator = None - self._traj = None - self._mcmc_moves = copy.deepcopy(mcmc_moves) - - self._last_mbar_f_k = None - self._last_err_free_energy = None + self._online_estimator = None + self._offline_estimator = MBAREstimator() @property - def n_states(self): - """The integer number of thermodynamic states (read-only).""" + def n_states(self) -> int: + """ + Get the number of thermodynamic states in the sampler. + + Returns + ------- + int + The number of thermodynamic states. + """ if self._thermodynamic_states is None: return 0 else: return len(self._thermodynamic_states) @property - def n_replicas(self): - """The integer number of replicas (read-only).""" + def n_replicas(self) -> int: + """ + Get the number of replicas in the sampler. + + Returns + ------- + int + The number of replicas. + """ if self._sampler_states is None: return 0 else: @@ -103,29 +132,37 @@ def mcmc_moves(self): return copy.deepcopy(self._mcmc_moves) @property - def sampler_states(self): - """A copy of the sampler states list at the current iteration. + def sampler_states(self) -> Optional[List[SamplerState]]: + """ + Get a copy of the sampler states list at the current iteration. + + This property can only be set before running the simulation. - This can be set only before running. + Returns + ------- + Optional[List[SamplerState]] + The list of sampler states at the current iteration, or None if not set. """ + if self._sampler_states is None: + return None import copy return copy.deepcopy(self._sampler_states) @property def is_periodic(self): - """Return True if system is periodic, False if not, and None if not initialized""" + """ + Determine if the system is periodic. + + Returns + ------- + Optional[bool] + True if the system is periodic, False if not, and None if not initialized. + """ if self._sampler_states is None: return None return self._thermodynamic_states[0].is_periodic - @property - def metadata(self): - """A copy of the metadata dictionary passed on creation (read-only).""" - import copy - - return copy.deepcopy(self._metadata) - @property def is_completed(self): """Check if we have reached any of the stop target criteria (read-only)""" @@ -133,76 +170,65 @@ def is_completed(self): def _compute_replica_energies(self, replica_id: int) -> np.ndarray: """ - Compute the energy for the replica in every ThermodynamicState. + Compute the energy of a replica across all thermodynamic states. Parameters ---------- replica_id : int - The ID of the replica to compute energies for. + The index of the replica for which to compute energies. Returns ------- np.ndarray - Array of energies for the specified replica across all thermodynamic states. + An array of energies for the replica across all thermodynamic states. """ - import jax.numpy as jnp from chiron.states import calculate_reduced_potential_at_states - # Only compute energies of the sampled states over neighborhoods. - thermodynamic_states = [ - self._thermodynamic_states[n] for n in range(self.n_states) - ] # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] # Compute energy for all thermodynamic states. - return calculate_reduced_potential_at_states( - sampler_state, thermodynamic_states, self.nbr_list + energies = calculate_reduced_potential_at_states( + sampler_state, self._thermodynamic_states, self.nbr_list ) + return energies def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd, - metadata: Optional[dict] = None, ): - """Create new multistate sampler simulation. + """ + Create a new multistate sampler simulation. + Parameters + ---------- thermodynamic_states : List[ThermodynamicState] - List of ThermodynamicStates to simulate, with one replica allocated per state. + List of ThermodynamicStates to simulate, with one replica per state. sampler_states : List[SamplerState] - List of initial SamplerStates. The number of replicas is taken to be the number - of sampler states provided. + List of initial SamplerStates. The number of states is the number of replicas. nbr_list : NeighborListNsqrd - Neighbor list object to be used in the simulation. - metadata : dict, optional - Optional simulation metadata to be stored in the file. + Neighbor list object for the simulation. Raises ------ RuntimeError - If the lengths of thermodynamic_states and sampler_states are not equal. + If the lengths of `thermodynamic_states` and `sampler_states` are not equal. """ - # TODO: initialize reporter here - # TODO: consider unsampled thermodynamic states for reweighting schemes - self._online_estimator = None - from chiron.analysis import MBAREstimator - - n_thermodynamic_states = len(thermodynamic_states) - n_sampler_states = len(sampler_states) + self._online_estimator = None - self._offline_estimator = MBAREstimator(N_u=n_thermodynamic_states) + from chiron.reporters import MultistateReporter # Ensure the number of thermodynamic states matches the number of sampler states - if n_thermodynamic_states != n_sampler_states: + if len(thermodynamic_states) != len(sampler_states): raise RuntimeError( "Number of thermodynamic states and sampler states must be equal." ) - self._allocate_variables(thermodynamic_states, sampler_states) self.nbr_list = nbr_list - self._reporter = None + self._allocate_variables(thermodynamic_states, sampler_states) + self._reporter = MultistateReporter() def _allocate_variables( self, @@ -218,10 +244,6 @@ def _allocate_variables( A list of ThermodynamicState objects to be used in the sampler. sampler_states : List[SamplerState] A list of SamplerState objects for initializing the sampler. - unsampled_thermodynamic_states : Optional[List[ThermodynamicState]], optional - A list of additional ThermodynamicState objects that are not directly sampled but - for which energies will be computed for reweighting schemes. Defaults to None, - meaning no unsampled states are considered. Raises ------ @@ -231,40 +253,21 @@ def _allocate_variables( import copy import numpy as np - # Save thermodynamic states. This sets n_replicas. - self._thermodynamic_states = [ - copy.deepcopy(thermodynamic_state) - for thermodynamic_state in thermodynamic_states - ] - - # Deep copy sampler states. - self._sampler_states = [ - copy.deepcopy(sampler_state) for sampler_state in sampler_states - ] - + self._thermodynamic_states = copy.deepcopy(thermodynamic_states) + self._sampler_states = sampler_states assert len(self._thermodynamic_states) == len(self._sampler_states) - # Set initial thermodynamic state indices - initial_thermodynamic_states = np.arange( - len(self._thermodynamic_states), dtype=int - ) - self._replica_thermodynamic_states = np.array( - initial_thermodynamic_states, np.int64 + self._replica_thermodynamic_states = np.arange( + len(thermodynamic_states), dtype=int ) - # Reset statistics. - - # _n_accepted_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. - # _n_proposed_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. - # Allocate memory for energy matrix. energy_thermodynamic_states[k][l] - # is the reduced potential computed at the positions of SamplerState sampler_states[k] - # and ThermodynamicState thermodynamic_states[l]. - + # Initialize matrices for tracking acceptance and proposal statistics. self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64) self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64) self._energy_thermodynamic_states = np.zeros( [self.n_replicas, self.n_states], np.float64 ) self._traj = [[] for _ in range(self.n_replicas)] + # Ensure there is an MCMCMove for each thermodynamic state. from chiron.mcmc import MCMCMove @@ -294,10 +297,9 @@ def _minimize_replica( replica_id : int The index of the replica to minimize. tolerance : unit.Quantity, optional - The energy tolerance to which the system should be minimized. - Defaults to 1.0 kilojoules/mole/nanometers. + The energy tolerance for minimization (default: 1.0 kJ/mol/nm). max_iterations : int, optional - The maximum number of minimization iterations. Defaults to 1000. + Maximum number of minimization iterations (default: 1000). Notes ----- @@ -307,7 +309,6 @@ def _minimize_replica( from chiron.minimze import minimize_energy from loguru import logger as log - # Retrieve thermodynamic and sampler states. thermodynamic_state = self._thermodynamic_states[ self._replica_thermodynamic_states[replica_id] ] @@ -377,94 +378,90 @@ def minimize( def _propagate_replica(self, replica_id: int): """ - Propagate the state of a single replica. - - This method applies the MCMC move to the replica to change its state - according to the specified thermodynamic state. + Propagate the state of a single replica using its assigned MCMC move. Parameters ---------- replica_id : int The index of the replica to propagate. + Raises ------ RuntimeError If an error occurs during the propagation of the replica. """ - # Retrieve thermodynamic, sampler states, and MCMC move of this replica. + thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] - thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] mcmc_move = self._mcmc_moves[thermodynamic_state_id] - # Apply MCMC move. + # Apply the MCMC move to the replica. mcmc_move.run(sampler_state, thermodynamic_state) + # Append the new state to the trajectory for analysis. self._traj[replica_id].append(sampler_state.x0) def _perform_swap_proposals(self): """ Perform swap proposals between replicas. - Placeholder method for replica swapping logic. Subclasses should - override this method with specific swapping algorithms. + This method should be overridden by subclasses to implement specific swapping algorithms. Returns ------- np.ndarray An array of updated thermodynamic state indices for each replica. """ - # Placeholder implementation, should be overridden by subclasses # For this example, we'll just return the current state indices return self._replica_thermodynamic_states def _mix_replicas(self) -> np.ndarray: """ - Propose and execute swaps between replicas. + Propose and execute swaps between replicas to enhance sampling efficiency. - This method is responsible for enhancing sampling efficiency by proposing - swaps between different thermodynamic states of the replicas. The actual - swapping algorithm depends on the specific subclass implementation. + This method handles the logic for proposing swaps between different thermodynamic states + of the replicas. The specifics of the swapping algorithm depend on subclass implementations. Returns ------- np.ndarray - An array of updated thermodynamic state indices for each replica. + An array of updated thermodynamic state indices for each replica after swapping. """ from loguru import logger as log log.debug("Mixing replicas (does nothing for MultiStateSampler)...") - # Reset storage to keep track of swap attempts this iteration. + # Reset swap attempt counters for this iteration. self._n_accepted_matrix[:, :] = 0 self._n_proposed_matrix[:, :] = 0 - # Perform replica mixing (swap proposals and acceptances) - # The actual swapping logic would depend on subclass implementations - # Here, we assume a placeholder implementation + # Perform the swap proposals and acceptances. new_replica_states = self._perform_swap_proposals() # Calculate swap acceptance statistics n_swaps_proposed = self._n_proposed_matrix.sum() n_swaps_accepted = self._n_accepted_matrix.sum() swap_fraction_accepted = 0.0 - if n_swaps_proposed > 0: - swap_fraction_accepted = n_swaps_accepted / n_swaps_proposed + swap_fraction_accepted = ( + n_swaps_accepted / n_swaps_proposed if n_swaps_proposed > 0 else 0.0 + ) log.debug( f"Accepted {n_swaps_accepted}/{n_swaps_proposed} attempted swaps ({swap_fraction_accepted * 100.0:.1f}%)" ) + return new_replica_states def _propagate_replicas(self) -> None: """ Propagate all replicas through their respective MCMC moves. - This method iterates over all replicas and applies the corresponding MCMC move - to each one, based on its current thermodynamic state. + This method applies the corresponding MCMC move to each replica based on its + current thermodynamic state, thus advancing the state of each replica. """ from loguru import logger as log log.debug("Propagating all replicas...") + # Iterate over all replicas and propagate each one. for replica_id in range(self.n_replicas): self._propagate_replica(replica_id) @@ -472,9 +469,8 @@ def _compute_energies(self) -> None: """ Compute the energies of all replicas at all thermodynamic states. - This method calculates the energy for each replica in every thermodynamic state, - considering the defined neighborhoods to optimize the computation. The energies - are stored in the internal energy matrix of the sampler. + This method calculates the energy for each replica in every thermodynamic state. + The energies are stored in the internal energy matrix of the sampler. """ from loguru import logger as log @@ -482,9 +478,8 @@ def _compute_energies(self) -> None: # Initialize the energy matrix and neighborhoods self._energy_thermodynamic_states = np.zeros((self.n_replicas, self.n_states)) - # Calculate energies for each replica + # Calculate and store energies for each replica. for replica_id in range(self.n_replicas): - # Compute and store energies for the neighborhood states self._energy_thermodynamic_states[ replica_id, : ] = self._compute_replica_energies(replica_id) @@ -493,20 +488,18 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: """ Determine if the sampling process has met its completion criteria. - This method checks if the simulation has reached a specified iteration limit - or any other predefined stopping condition. + Checks if the simulation has reached a specified iteration limit or any other + predefined stopping condition. Parameters ---------- iteration_limit : Optional[int], default=None - An optional iteration limit. If specified, the method checks if the - current iteration number has reached this limit. + An optional iteration limit to check against the current iteration number. Returns ------- bool - True if the simulation has completed based on the stopping criteria, - False otherwise. + True if the simulation has completed based on the stopping criteria, False otherwise. """ from loguru import logger as log @@ -523,105 +516,164 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: def run(self, n_iterations: int = 10) -> None: """ - Execute the replica-exchange simulation. + Execute the replica-exchange simulation for a specified number of iterations. - Run the simulation for a specified number of iterations. If no number is - specified, it runs for the number of iterations set during the initialization - of the sampler. + Runs the simulation, performing replica propagation, mixing, and energy computation + for the specified number of iterations. Parameters ---------- n_iterations : int, default=10 - The number of iterations to run. - - Raises - ------ - RuntimeError - If an error occurs during the computation of energies. + The number of iterations to run the simulation. """ from loguru import logger as log - # If this is the first iteration, compute and store the - # starting energies of the minimized/equilibrated structures. - self.number_of_iterations = n_iterations - log.info("Running simulation...") - self._energy_thermodynamic_states_for_each_iteration_in_run = np.zeros( - [self.n_replicas, self.n_states, n_iterations + 1], np.float64 - ) - # Initialize energies if this is the first iteration + self.number_of_iterations = n_iterations + if self._iteration == 0: + # Initialize energies if this is the first iteration self._compute_energies() - # store energies for mbar analysis - self._energy_thermodynamic_states_for_each_iteration_in_run[ - :, :, self._iteration - ] = self._energy_thermodynamic_states - # TODO report energies + self._report_iteration() # start the sampling loop - log.debug(f"{n_iterations=}") while not self._is_completed(n_iterations): - # Increment iteration counter. self._iteration += 1 - log.info("-" * 80) log.info(f"Iteration {self._iteration}/{n_iterations}") log.info("-" * 80) - # Update thermodynamic states self._mix_replicas() - - # Propagate replicas. self._propagate_replicas() - - # Compute energies of all replicas at all states self._compute_energies() + self._report_iteration() + self._update_analysis() - # Add energies to the energy matrix - self._energy_thermodynamic_states_for_each_iteration_in_run[ - :, :, self._iteration - ] = self._energy_thermodynamic_states - # Write iteration to storage file - # TODO - # self._report_iteration() + self._reporter.flush_buffer() - # Update analysis - self._update_analysis() + def _report_energy_matrix(self): + """ + Report the energy matrix for each thermodynamic state. - def _report_iteration(self): - """Store positions, states, and energies of current iteration.""" + This method logs the energy per thermodynamic state, which is useful for analysis + and debugging purposes. + """ + from loguru import logger as log + + log.debug("Reporting energy per thermodynamic state...") + # NOTE: self._energy_thermodynamic_states is transposed from + # shape (n_replicas, n_states) to (n_states, n_replicas) + return {"u_kn": self._energy_thermodynamic_states.T} + + def _report_positions(self): + """ + Store and report the positions of all replicas at the current iteration. + + This method compiles and reports the position data for each replica, which + is critical for trajectory analysis. + """ + from loguru import logger as log + + log.debug("Reporting positions...") + # numpy array with shape (n_replicas, n_atoms, 3) + xyz = np.zeros((self.n_replicas, self._sampler_states[0].x0.shape[0], 3)) + for replica_id in range(self.n_replicas): + xyz[replica_id] = self._sampler_states[replica_id].x0 + return {"positions": xyz} + + def _report(self, property: str) -> None: + """ + Report a specific property of the simulation. + + Depending on the specified property, this method delegates to the appropriate + internal reporting method. - # TODO: write energies + Parameters + ---------- + property : str + The property to report. Can be 'positions', 'states', 'energies', + 'trajectory', 'mixing_statistics', or 'all'. + """ + from loguru import logger as log - # TODO: write trajectory + log.debug(f"Reporting {property}...") + match property: + case "positions": + return self._report_positions() + case "states": + pass + case "u_kn": + return self._report_energy_matrix() + case "trajectory": + return + case "mixing_statistics": + return - # TODO: write mixing statistics + def _report_iteration(self): + """ + Store and report various properties of the current iteration. + + This method is called at each iteration to report essential simulation data, + such as positions, states, energies, and other properties defined in the reporter. + """ + from loguru import logger as log + + log.debug("Reporting data for current iteration...") + log.debug(self._reporter.properties_to_report) + prop = {} + for property in self._reporter.properties_to_report: + p = self._report(property) + if p: + prop.update(p) + self._reporter.report(prop) def _update_analysis(self): - """Update analysis of free energies""" + """ + Update the analysis of free energies based on the current simulation data. + + This method is responsible for updating the free energy estimates, either using + online or offline estimation methods, as configured in the sampler. + """ from loguru import logger as log + log.debug("Updating free energy analysis...") + # Perform offline free energy estimate if requested if self._offline_estimator: log.debug("Performing offline free energy estimate...") N_k = [self._iteration] * self.n_states + u_kn = self._reporter.get_property("u_kn") self._offline_estimator.initialize( - u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run, + u_kn=u_kn, N_k=N_k, ) + log.debug(self._offline_estimator.f_k) elif self._online_estimator: log.debug("Performing online free energy estimate...") - self._online_estimator.update( - u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run[ - :, :, self._iteration - ] - ) + self._online_estimator.update() else: raise RuntimeError("No free energy estimator provided.") @property - def f_k(self): + def f_k(self) -> np.ndarray: + """ + Get the current free energy estimates. + + Returns the free energy estimates calculated by the sampler's free energy estimator. + The specific estimator used (online or offline) depends on the sampler configuration. + + Returns + ------- + np.ndarray + Array of free energy estimates for each thermodynamic state. + + Raises + ------ + RuntimeError + If no free energy estimator is found. + """ + if self._offline_estimator: return self._offline_estimator.f_k elif self._online_estimator: diff --git a/chiron/reporters.py b/chiron/reporters.py index 6230a44..27457a6 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -5,31 +5,85 @@ import numpy as np from openmm.app import Topology +from typing import List -class SimulationReporter: - def __init__(self, filename: str, topology: Topology, buffer_size: int = 1): +class BaseReporter: + _directory = None + + @classmethod + def set_directory(cls, directory: str): """ - Initialize the SimulationReporter. + Set the base directory for saving reporter files. Parameters ---------- - filename : str - Name of the HDF5 file to write the simulation data. - topology: openmm.Topology - buffer_size : int, optional - Number of data points to buffer before writing to disk (default is 1). + directory : str + The path to the directory where files will be saved. + """ + cls._directory = directory + + @classmethod + def get_directory(cls): + """ + Get the current directory set for saving reporter files. + Returns + ------- + Path + The path to the directory where files will be saved. Defaults to the + current working directory if no directory has been set. """ - self.filename = filename + from pathlib import Path + + if cls._directory is None: + log.debug( + f"No directory set, using current working directory: {Path.cwd()}" + ) + return Path.cwd() + return Path(cls._directory) + + +class _SimulationReporter: + def __init__(self, file_name: str, buffer_size: int = 10): + """ + Initialize the _SimulationReporter class. + + Parameters + ---------- + file_name : str + Name of the HDF5 file for writing simulation data. + buffer_size : int, optional + The size of the buffer before flushing data to disk (default is 10). + """ + workdir = BaseReporter.get_directory() + self.file_path_base = workdir / f"{file_name}" + self.log_file_path = self.file_path_base.with_suffix(".h5") + self.workdir = workdir + self.report_iteration = 0 + import os + + os.makedirs(workdir, exist_ok=True) + + log.info(f"Writing simulation log data to {self.log_file_path}") + self.buffer_size = buffer_size - self.topology = topology self.buffer = {} - self.h5file = h5py.File(filename, "a") - log.info(f"Writing simulation data to {filename}") + + @property + def properties_to_report(self): + return self._default_properties + + @properties_to_report.setter + def properties_to_report(self, properties: List[str]): + self._default_properties = properties def get_available_keys(self): - return self.h5file.keys() + keys = [] + with h5py.File(self.log_file_path, "r") as h5file: + for key in h5file: + keys.append(key) + return keys def report(self, data_dict): """ @@ -40,17 +94,26 @@ def report(self, data_dict): data_dict : dict Dictionary containing data to report. Keys are data labels (e.g., 'energy'), and values are the data points (usually numpy arrays). - """ for key, value in data_dict.items(): if key not in self.buffer: + # new key shouldn't trigger a flush self.buffer[key] = [] self.buffer[key].append(value) - if len(self.buffer[key]) >= self.buffer_size: - self._write_to_disk(key) + self._flush_buffer_if_necessary() - def _write_to_disk(self, key:str): + def _flush_buffer_if_necessary(self): + """ + Flush the buffer to disk if it reaches the specified buffer size. + """ + # NOTE: we assume that every property is updated with the same frequency! + if all(len(self.buffer[key]) > self.buffer_size for key in self.buffer): + # flush and reset the buffer + log.debug(self.buffer) + self.flush_buffer() + + def _write_to_disk(self, key: str): """ Write buffered data of a given key to the HDF5 file. @@ -60,51 +123,222 @@ def _write_to_disk(self, key:str): The key of the data to write to disk. """ - data = np.array(self.buffer[key]) - if key in self.h5file: - dset = self.h5file[key] - dset.resize((dset.shape[0] + data.shape[0],) + data.shape[1:]) - dset[-data.shape[0] :] = data - else: - log.debug(f"Creating {key} in {self.filename}") - self.h5file.create_dataset( - key, data=data, maxshape=(None,) + data.shape[1:], chunks=True + log.debug(f"Writing {key} to file") + if key == "positions" and hasattr(self, "_write_to_trajectory"): + xyz = np.stack(self.buffer[key]) + self._write_to_trajectory( + positions=xyz, ) - self.buffer[key] = [] + with h5py.File(self.log_file_path, "a") as h5file: + if key in h5file: + data = np.array(self.buffer[key]) + dset = h5file[key] + dset.resize((dset.shape[0] + data.shape[0],) + data.shape[1:]) + dset[-data.shape[0] :] = data + else: + data = np.array(self.buffer[key]) + log.debug(f"Creating {key} in {self.log_file_path}") + h5file.create_dataset( + key, data=data, maxshape=(None,) + data.shape[1:], chunks=True + ) + + def reset_reporter_file(self): + # delete the reporter files + import os - def close(self): + # if file exists, delete it + if os.path.exists(self.log_file_path): + log.debug(f"Deleting {self.log_file_path}") + os.remove(self.log_file_path) + + def flush_buffer(self) -> None: """ - Write any remaining data in the buffer to disk and close the HDF5 file. + Write any remaining data in the buffer to disk. """ for key in self.buffer: if self.buffer[key]: self._write_to_disk(key) - self.h5file.close() + self._reset_buffer() + + def _reset_buffer(self) -> None: + """ + Reset the data buffer after writing to disk. + """ + self.buffer = {key: [] for key in self.buffer} - def get_property(self, name: str): + def get_property(self, name: str) -> np.ndarray: """ - Get the property from the HDF5 file. + Retrieve a specific property from the HDF5 file. Parameters ---------- name : str - Name of the property to get. + The name of the property to retrieve. Returns ------- np.ndarray - The property. + The retrieved property data, if available. + """ + if name == "positions" and hasattr(self, "read_from_trajectory"): + return self.read_from_trajectory() + + with h5py.File(self.log_file_path, "r") as h5file: + if name in h5file: + data = np.array(h5file[name]) + elif name in self.buffer and name not in h5file: + data = np.array(self.buffer[name]) + elif name not in h5file: + log.warning(f"{name} not in HDF5 file") + return None + + if name == "u_kn": + return np.transpose( + data, (2, 1, 0) + ) # shape: n_states, n_replicas, n_iterations + + else: + return data + +from typing import Optional +import mdtraj as md + + +class MultistateReporter(_SimulationReporter): + _name = "multistate_reporter" + _default_properties = [ + "positions", + "box_vectors", + "u_kn", + "state_index", + "step", + ] + + def __init__( + self, + file_name: Optional[str] = None, + buffer_size: int = 1, + ) -> None: """ - if name not in self.h5file: - log.debug(f"{name} not in HDF5 file") - return None - else: - return np.array(self.h5file[name]) + Initialize the MultistateReporter class. + + Parameters + ---------- + file_name : Optional[str], optional + Name of the file for storing multistate simulation data. If None, a + default name based on the reporter name is used. + buffer_size : int, optional + The size of the buffer before flushing data to disk (default is 1). + """ + + if file_name is None: + file_name = MultistateReporter.get_name() + + super().__init__(file_name=file_name, buffer_size=buffer_size) + self._replica_reporter = {} + + @classmethod + def get_name(cls): + return cls._name + + def _write_to_trajectory(self, positions: np.ndarray) -> None: + nr_of_frames, n_replicas, n_of_atoms, _ = positions.shape + + for replica_id in range(n_replicas): + # if file does not exist, create it + key = f"replica_{replica_id}" + if self._replica_reporter.get(key) is None: + self._replica_reporter[key] = LangevinDynamicsReporter(key) + + reporter = self._replica_reporter.get(key) + + for frame_id in range(nr_of_frames): + data = {"positions": positions[frame_id, replica_id]} + if self.buffer.get("box_vectors") is not None: + data["box_vectors"] = self.buffer.get("box_vectors")[frame_id] + reporter.report(data) + + def flush_buffer(self): + for reporter in self._replica_reporter.values(): + reporter.flush_buffer() + reporter._write_xtc_file_handle.flush() + + return super().flush_buffer() + + +from typing import Optional + - def get_mdtraj_trajectory(self): +class MCReporter(_SimulationReporter): + _name = "mc_reporter" + + def __init__(self, file_name: Optional[str] = None, buffer_size: int = 1) -> None: + """ + Initialize the MCReporter class for Monte Carlo simulations. + + Parameters + ---------- + file_name : Optional[str], optional + The file name for storing simulation data. + buffer_size : int, optional + The size of the buffer before flushing data to disk. + """ + if file_name is None: + file_name = MCReporter.get_name() + + super().__init__(file_name=file_name, buffer_size=buffer_size) + + @classmethod + def get_name(cls): + return cls._name + + +class LangevinDynamicsReporter(_SimulationReporter): + _name = "langevin_reporter" + _default_properties = ["positions", "box_vectors", "potential_energy", "step"] + + def __init__( + self, + file_name: Optional[str] = None, + buffer_size: int = 1, + topology: Optional[Topology] = None, + ): + """ + Initialize the LangevinDynamicsReporter for Langevin dynamics simulations. + + Parameters + ---------- + file_name : Optional[str], optional + The file name for storing simulation data. + buffer_size : int, optional + The size of the buffer before flushing data to disk. + topology : Optional[Topology], optional + The system topology for generating trajectories. + """ + if file_name is None: + file_name = LangevinDynamicsReporter.get_name() + + super().__init__(file_name=file_name, buffer_size=buffer_size) + self.topology = topology + self._write_xtc_file_handle = None + self.xtc_file_path = f"{self.file_path_base}.xtc" + + @classmethod + def get_name(cls): + return cls._name + + def get_mdtraj_trajectory(self) -> md.Trajectory: + """ + Generate an MDTraj trajectory object from the stored positions. + + Returns + ------- + md.Trajectory + The MDTraj trajectory object created from the stored position data. + """ import mdtraj as md return md.Trajectory( @@ -114,22 +348,85 @@ def get_mdtraj_trajectory(self): unitcell_angles=self.get_property("box_angles"), ) + def _write_to_trajectory(self, positions: np.ndarray) -> None: + """ + Write position data to a trajectory file for molecular dynamics. + Parameters + ---------- + positions : np.ndarray + The positions of particles to be written to the trajectory. + """ + if self._write_xtc_file_handle is None: + log.debug(f"Creating trajectory in {self.xtc_file_path}") + self._write_xtc_file_handle = md.formats.XTCTrajectoryFile( + self.xtc_file_path, mode="w" + ) + + LangevinDynamicsReporter._write_to_xtc( + file_handler=self._write_xtc_file_handle, + positions=positions, + iteration=self.buffer.get("step"), + box_vecotrs=self.buffer.get("box_vectors"), + ) + + def read_from_trajectory(self) -> np.ndarray: + """ + Read position data from a trajectory file. + + Returns + ------- + np.ndarray + The positions read from the trajectory file. + """ + # flush the write buffer + self._write_xtc_file_handle.flush() + with md.formats.XTCTrajectoryFile( + self.xtc_file_path, mode="r" + ) as _read_xtc_file_handle: + return LangevinDynamicsReporter._read_from_xtc(_read_xtc_file_handle) + + @classmethod + def _read_from_xtc(cls, file_handler) -> np.ndarray: + """ + Read data from an XTC file. -class MultistateReporter: - - def __init__(self, path_to_dir:str) -> None: - self.path_to_dir = path_to_dir - - def _write_trajectories(): - pass - - def _write_energies(): - pass - - def _write_states(): - pass - - - - \ No newline at end of file + Parameters + ---------- + file_handler : md.formats.XTCTrajectoryFile + The file handler for reading XTC files. + + Returns + ------- + np.ndarray + The data read from the XTC file. + """ + return file_handler.read() + + @classmethod + def _write_to_xtc( + cls, + file_handler: md.formats.XTCTrajectoryFile, + positions: np.ndarray, + iteration: np.ndarray, + box_vecotrs: Optional[np.ndarray] = None, + ): + """ + Write position data to an XTC file. + + Parameters + ---------- + file_handler : md.formats.XTCTrajectoryFile + The file handler for writing to XTC files. + positions : np.ndarray + The positions to be written. + iteration : np.ndarray + The iteration numbers corresponding to the positions. + box_vectors : Optional[np.ndarray], optional + Box vectors for each position frame. + """ + file_handler.write( + positions, + time=iteration, + box=box_vecotrs, + ) diff --git a/chiron/states.py b/chiron/states.py index 99ee8f9..99459ae 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union from jax import numpy as jnp from .potential import NeuralNetworkPotential +from jax import random class SamplerState: @@ -22,6 +23,7 @@ class SamplerState: def __init__( self, x0: unit.Quantity, + current_PRNG_key: random.PRNGKey, velocities: Optional[unit.Quantity] = None, box_vectors: Optional[unit.Quantity] = None, ) -> None: @@ -61,9 +63,16 @@ def __init__( raise ValueError( f"box_vectors must be a 3x3 array, got {box_vectors.shape} instead." ) + if velocities is not None and x0.shape != velocities.shape: + raise ValueError( + f"x0 and velocities must have the same shape, got {x0.shape} and {velocities.shape} instead." + ) + if current_PRNG_key is None: + raise ValueError(f"random_seed must be set.") self._x0 = x0 self._velocities = velocities + self._current_PRNG_key = current_PRNG_key self._box_vectors = box_vectors self._distance_unit = unit.nanometer @@ -98,6 +107,12 @@ def x0(self, x0: Union[jnp.array, unit.Quantity]) -> None: def distance_unit(self) -> unit.Unit: return self._distance_unit + @property + def new_PRNG_key(self) -> random.PRNGKey: + key, subkey = random.split(self._current_PRNG_key) + self._current_PRNG_key = key + return subkey + def _convert_to_jnp(self, array: unit.Quantity) -> jnp.array: """ Convert the sampler state to jnp arrays. diff --git a/chiron/tests/data/langevin_reporter.h5 b/chiron/tests/data/langevin_reporter.h5 new file mode 100644 index 0000000000000000000000000000000000000000..e6e23c53d68b33389861d3b00fc962070d077a24 GIT binary patch literal 12872 zcmeI1c~lg~x5tNl8Dw+E4FnWiBJLnv-4{h6Dx#pF+`!w%btq ztZq!e;K2&DQnoLjX*Y$Np{8UiH~e${f3*U>gS|#JOgNCv+Mb>(=zT`3iio*WW;T?k z{cGCXj~6Hu;Stlq8tU@>>E2NJ3x}SO?K4g-C|C~B*rdqc5dDEd^WON!uz zfB);p=n?+jdg`I&w?t9vKOa(l?q7ENbDjKL+sh01=kmWjZFt}5DSG|keaHO|XI5mR z5z6|8o8nIoTMd=<%wzM1uD*<3*JVj&^uN0?|6seJ;VPq-pOgPqtZuY#sv)xf#*GPS z`wT{JqC=m@=;iydv2HIJy;);I+78V$1=)YihTC|jI7Y8*T>6%rWAqj}^sSm}1hPL( zbm)T_y`>I)7NfV)p|551^82N+Ztg8KME0LP4vonbBBQ6bU!(d8Mo%wyqk8L>@A^Y8 zN2B^cMo*{Ts6L(1Q}-Ix*D!ioZB*|@U)1t?%J1#QIz%yg9ovASl+oKYW|+2{QiHNT zZFJ~IGWxbU^zPj>0@;2$9r_4H-(H74kI{F~p|4}~_B!-XOE$4Q4ikkLEq(5Ex{PCE28jJ~rDy;~2BKweK59r`Fn@2W#z%ILf3(3`rw z>rYo5`jL#D-cF77N@eu)ayP2K!|3VqHmY~&sX_93(&;y<4`cLnxJLCwjGk5-)vH6T zcMe%E#~pHfqbqKC$H+gGe^%iCXa&v}rS!8fIpKb0a%BHYdZJL;aZUfszTJyr6$+>{ zo>;Wh!MSMr8v)E?4c%jFI=h#?i3X?FGmG}dp5 zxDZu5;KOb4D(8+0D7>Aes`2?$rMe{mSJ%h$P5)$1pZn49_2>IkJB;7 zuB(gN58Te3e6MT-a50-qWwkILRbj0ZYT8q|`!J>C0RJ zVgmbd`MUx+izCqx`AwqqVen3AYg#|P8~qTPR;}_#d(@Pe&3@teZN+lIBen&%toa(L z=Pm^t>-B{clsi#M`%-}FsHvRax%HgqzGyh?S|s_6Jta+9lfoxcI~^bVJfPpQ-emHJ zEjZIID};F;jNz`$+#@Z!uK+WbBT~Y%c`t9Yt>C!y;*+g&r(VeJt55&ydEc1gIVVIhW8$LEIyw zAt*x+9rK?{DhFF3kbbbaUES#!T05KU&z;Jpt5yg;JkK~@0`hnZu20v0nsqT#u*PRdL_Lc6e-Bx>1&@VxJ7L?kPH5cxyt)1 zga-?6aE{dq%#F~4-ANvJJoT7l^{D`P54Lk7W(AUnk9^7hn=pNk^H03jBZegCi&dYoWQA$58z;)LGPxt6NQmhc+sq;+C9p;Xc@(n;v z9*wzjhBN=Qg1osLz}+uiDU?UK>f8@T&lQf*d^!jf+X8i3mrH zBQWQ=4<_wb!VT(A&CojT1FR&+o(|z+4P%AJR%xX0_WPLnu>shG=i;&j*-|zK5S(J= zvCpTREYlZZ)N~PliXMXtTi5{55BNMcYqIXaNz!(zA9r_XtWa|@i{xkcV{nN9sDcV{ z^yYPvTRcFAFIN^{Ogv3EGZ7a0$KuV*N$7L78SJEXj`(yY#q&=RUy}gNLyQ&pVFe^; z>ll38(GYS}rRY3jp%i%pAbgHem9x2;#J3dTvporzxM((-wzdOPnjjcUfLJ7)BpWOJ zxVQtc!mHY&WXUfRu+ug};6FW%VZV%$5}yF9xW7ep+Vv(G<}8B4j1)YxX(?{=c7RZ7 zXH7^D*%VblhOP?XTK^g=^f$UdPW79P`&^A6@8b#__PD2XySo<*8el194*8b2br!*{ z%SKE)zY=5WT;LqFOD&CIwfO=|2+LT5n z`IeD0Z^m#BH^d87DkUHPV=8{q-voXOIF1$}dsKmIz2M;e>(cQmeZG~u2<>NW#anN- zVU3S31k;2|#=TvnWZDVh*J>!|!mSqK>}~n)W7E;`j0v>JIfgmK#?tx?UU17ofhz|Z z@y>z>)=t~8*UAjsc`6XfshzbQ_mSkx3bK3raBk(n)xuftR(y_kCe8~q1uyB?KP1*u1pY4%<90aDss{E zYg2&Jq5()Fq}SiTHc2ysO0h4{N-YX(_2y7zntN0e$s*_v}F4h+Mb%i7=|}t_!fp+GW-F< ztr*^oCdlORtr;G{aC$pvRT9JL<;m!=d!SGHDccKXudHl`{4`sLu z!w)c=&0p0sybIIcnh8hO^IIE4tyz_SyI; znBi>vl*Mp1eyU}-JdV5w?sP+!+u8U@WH=i?RWO{5pR8#>Alql-r$C0Y@l!g(EoH~) zxrX6Z40odehdjPD!=o5ZZwIZal;QMpYPl&5SY-QjxwU*G!`b*LmEm+c+WvPK&c;tJ zG`^9?r*duoFox5rmKQOch8tS0pbfHpHh%JDI2%7DGn|c|wlJKHpRyRv#!pY^hAjKX z#!qf^LzcPyf6eS|DT?82{8Y+tHhwat8@gqQhoaF`iJ-T7?B_%&0ScK&+dAR3(Ar{e1cKGh&qyv_a z)t|?59(FddlR-VCtF6N-;T8^guUrQyyUSN384IRQP`230L+~eqj_-ybYY)dy{ zLv@bW&sHK%zeI8?wk8UTzxU#6jIUwTNOM@a@EC5~)(p3MdV%QbfioPM@aF^7dC=@a zTvw?=_f}5nQ6#8o{-_FaMj*LzL#3!w}aBT$bM%tD1U2q1|BY zT5HMg@D*Y9z(eT5AI9}&Veo+3ITCV$_>bI2f=;<`{bH_l zmWZCs?309Vd?)hPU+9-izGn`vZ7$ENZ}mzk98wpD`rriG%Q^KkAivT4U`5^*F(Rpp%|tT&0|-G^G091Kyy z{?u#O=g~SG_gH{1AraH;n(>0S2(~ZwV4U+M?2v5;oRwx?hw`_Ow&Ey~>-$(`LXw16 z)0glst~M>X8>NJaQ*PqgcH7ZP{d?)~ohjHcusOe1ZD*)cI@-;>ghlT?5Ynlgh~kap z)AC+q=w>@<$CM-?_QU0T^Z6}HA~q;tXw^Nu9G{Juo7KO+m958|+s*mj!$t6JpNbyw zmoYr1TrjlO_>%iYCgt}@5<(s&@&%hZl&mgRLXNTy zA3rF?a7O{olzonO&06yNric*in1CA>R^un_5``*i=haP-d|#Qyt-Jbx)XybJuqjF5 z{j3~IKHIB==Jl^IrnnOMY90&@?Z6p>?D&6$ictQ1B|1Gjk2fcW32ru;d7Z79OK=;{ z#h)K96}3(hi0LMN)NSXIpVE|29#oH~dRODPY#x>k-HYo=+w+6xi|||ZVhoyn2D{t% z2^**#zrd+v^yqz_Mi#TBzJ^J{jlO9-_qcOO`xGUNGV{hS}I5woc&2_VImEr_e}62$7Wqnxmbisb<>j6KSX}P_f79! zV!coa1>+SZD?X~lxo$js*S8wacX8!i7l;sjX&}D88!kNbi1&Ne`V6Mge+1*X|STi#QtYYHRYj~{R4dURX08@QiMBoJ+WY75k9;3n0KOq G`@aDuKjq&5 literal 0 HcmV?d00001 diff --git a/chiron/tests/test_convergence_tests.py b/chiron/tests/test_convergence_tests.py index 86c881c..16cfad3 100644 --- a/chiron/tests/test_convergence_tests.py +++ b/chiron/tests/test_convergence_tests.py @@ -48,11 +48,11 @@ def test_convergence_of_MC_estimator(prep_temp_dir): ) sampler_state = SamplerState(ho.positions) - from chiron.reporters import SimulationReporter + from chiron.reporters import _SimulationReporter id = uuid.uuid4() - simulation_reporter = SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") + simulation_reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") # Initalize the move set (here only LangevinDynamicsMove) from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler @@ -61,7 +61,7 @@ def test_convergence_of_MC_estimator(prep_temp_dir): nr_of_moves=100_000, displacement_sigma=0.5 * unit.angstrom, atom_subset=[0], - simulation_reporter=simulation_reporter, + reporter=simulation_reporter, ) move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) @@ -153,12 +153,12 @@ def test_langevin_dynamics_with_LJ_fluid(prep_temp_dir): potential=lj_potential, temperature=300 * unit.kelvin ) - from chiron.reporters import SimulationReporter + from chiron.reporters import _SimulationReporter id = uuid.uuid4() - reporter = SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") + reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") - integrator = LangevinIntegrator(reporter=reporter, save_frequency=100) + integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) integrator.run( sampler_state, thermodynamic_state, diff --git a/chiron/tests/test_integrators.py b/chiron/tests/test_integrators.py index 203b409..341c987 100644 --- a/chiron/tests/test_integrators.py +++ b/chiron/tests/test_integrators.py @@ -24,20 +24,28 @@ def test_langevin_dynamics(prep_temp_dir, provide_testsystems_and_potentials): # initialize states and integrator from chiron.integrators import LangevinIntegrator from chiron.states import SamplerState, ThermodynamicState + from chiron.utils import PRNG + + PRNG.set_seed(1234) thermodynamic_state = ThermodynamicState( potential=potential, temperature=300 * unit.kelvin ) - sampler_state = SamplerState(testsystem.positions) - from chiron.reporters import SimulationReporter + sampler_state = SamplerState(testsystem.positions, PRNG.get_random_key()) + + from chiron.reporters import LangevinDynamicsReporter + from chiron.reporters import BaseReporter + + # set up reporter directory + BaseReporter.set_directory(prep_temp_dir.join(f"test_{i}")) - reporter = SimulationReporter(f"{prep_temp_dir}/test{i}.h5", None, 1) + reporter = LangevinDynamicsReporter() - integrator = LangevinIntegrator(reporter=reporter) + integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) integrator.run( sampler_state, thermodynamic_state, - n_steps=5, + n_steps=20, ) i = i + 1 diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index 63c3f92..21bf8b6 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -36,40 +36,47 @@ def test_sample_from_harmonic_osciallator(prep_temp_dir): thermodynamic_state = ThermodynamicState( potential=harmonic_potential, temperature=300 * unit.kelvin ) - sampler_state = SamplerState(x0=ho.positions) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + sampler_state = SamplerState( + x0=ho.positions, current_PRNG_key=PRNG.get_random_key() + ) from chiron.integrators import LangevinIntegrator - from chiron.reporters import SimulationReporter + from chiron.reporters import LangevinDynamicsReporter, BaseReporter id = uuid.uuid4() - h5_file = f"test_{id}.h5" - reporter = SimulationReporter(f"{prep_temp_dir}/{h5_file}", 1) + wd = prep_temp_dir.join(f"_test_{id}") + BaseReporter.set_directory(wd) + reporter = LangevinDynamicsReporter() integrator = LangevinIntegrator( - stepsize=0.2 * unit.femtosecond, reporter=reporter, save_frequency=1 + stepsize=2 * unit.femtosecond, reporter=reporter, report_frequency=1 ) - r = integrator.run( + integrator.run( sampler_state, thermodynamic_state, n_steps=5, ) - + integrator.reporter.flush_buffer() import jax.numpy as jnp import h5py - h5 = h5py.File(f"{prep_temp_dir}/{h5_file}", "r") + h5 = h5py.File(f"{wd}/{LangevinDynamicsReporter.get_name()}.h5", "r") keys = h5.keys() - assert "energy" in keys, "Energy not in keys" + assert "potential_energy" in keys, "Energy not in keys" assert "step" in keys, "Step not in keys" - assert "traj" in keys, "Traj not in keys" + assert "traj" not in keys, "Traj is not in keys" - energy = h5["energy"][:] + energy = h5["potential_energy"][:] print(energy) reference_energy = jnp.array( - [0.00019308, 0.00077772, 0.00174247, 0.00307798, 0.00479007] + [0.03551735, 0.1395877, 0.30911613, 0.5495938, 0.85149795] ) assert jnp.allclose(energy, reference_energy) @@ -99,23 +106,23 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics # Initalize the sampler and thermodynamic state from chiron.states import ThermodynamicState, SamplerState + from chiron.utils import PRNG + PRNG.set_seed(1234) thermodynamic_state = ThermodynamicState( harmonic_potential, temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3), ) - sampler_state = SamplerState(ho.positions) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) # Initalize the move set (here only LangevinDynamicsMove) and reporter - from chiron.reporters import SimulationReporter + from chiron.reporters import LangevinDynamicsReporter, BaseReporter - simulation_reporter = SimulationReporter( - f"{prep_temp_dir}/test_{uuid.uuid4()}.h5", None, 1 - ) - langevin_move = LangevinDynamicsMove( - nr_of_steps=10, seed=1234, simulation_reporter=simulation_reporter - ) + BaseReporter.set_directory(prep_temp_dir) + + simulation_reporter = LangevinDynamicsReporter(1) + langevin_move = LangevinDynamicsMove(nr_of_steps=10, reporter=simulation_reporter) move_set = MoveSchedule([("LangevinMove", langevin_move)]) @@ -157,20 +164,23 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3), ) - sampler_state = SamplerState(ho.positions) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) # Initalize the move set and reporter - from chiron.reporters import SimulationReporter + from chiron.reporters import MCReporter, BaseReporter - simulation_reporter = SimulationReporter( - f"{prep_temp_dir}/test_{uuid.uuid4()}.h5", 1 - ) + wd = prep_temp_dir.join(f"_test_{uuid.uuid4()}") + BaseReporter.set_directory(wd) + simulation_reporter = MCReporter(1) mc_displacement_move = MetropolisDisplacementMove( nr_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=[0], - simulation_reporter=simulation_reporter, + reporter=simulation_reporter, ) move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) @@ -212,20 +222,25 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3), ) - sampler_state = SamplerState(ho.positions) + + from chiron.utils import PRNG + + PRNG.set_seed(1234) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) # Initalize the move set and reporter - from chiron.reporters import SimulationReporter + from chiron.reporters import MCReporter, BaseReporter - simulation_reporter = SimulationReporter( - f"{prep_temp_dir}/test_{uuid.uuid4()}.h5", 1 - ) + wd = prep_temp_dir.join(f"_test_{uuid.uuid4()}") + BaseReporter.set_directory(wd) + + simulation_reporter = MCReporter(1) mc_displacement_move = MetropolisDisplacementMove( nr_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=None, - simulation_reporter=simulation_reporter, + reporter=simulation_reporter, ) move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index 3d89854..cf0cf4e 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -16,8 +16,13 @@ def test_minimization(): cutoff = unit.Quantity(1.0, unit.nanometer) lj_potential = LJPotential(lj_fluid.topology, cutoff=cutoff) + from chiron.utils import PRNG + + PRNG.set_seed(1234) sampler_state = SamplerState( - lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() + lj_fluid.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # use parilist nbr_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) @@ -92,10 +97,14 @@ def test_minimize_two_particles(): lj_potential = LJPotential(None, sigma=sigma, epsilon=epsilon, cutoff=cutoff) coordinates = jnp.array([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) # define the sampler state sampler_state = SamplerState( x0=coordinates * unit.nanometer, + current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) * unit.nanometer, ) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index c316046..ef31a84 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -14,6 +14,7 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: from openmm import unit from chiron.mcmc import LangevinDynamicsMove from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + from chiron.reporters import MultistateReporter, BaseReporter sigma = 0.34 * unit.nanometer cutoff = 3.0 * sigma @@ -23,9 +24,12 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=500) + move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) + BaseReporter.set_directory("multistate_test") + reporter = MultistateReporter() + reporter.reset_reporter_file() - multistate_sampler = MultiStateSampler(mcmc_moves=move) + multistate_sampler = MultiStateSampler(mcmc_moves=move, reporter=reporter) return nbr_list, multistate_sampler @@ -59,7 +63,11 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: ) for x0 in x0s ] - sampler_state = [SamplerState(ho.positions) for _ in x0s] + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] nbr_list, multistate_sampler = setup_sampler() multistate_sampler.create( thermodynamic_states=thermodynamic_states, @@ -102,8 +110,14 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: from loguru import logger as log log.info(f"Initialize harmonic oscillator with {n_states} states and ks {Ks}") + from chiron.utils import PRNG - sampler_state = [SamplerState(ho.positions) for _ in sigmas] + PRNG.set_seed(1234) + + sampler_state = [ + SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) + for _ in sigmas + ] import numpy as np f_i = np.array( @@ -198,7 +212,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") - n_iteratinos = 25 + n_iteratinos = 250 ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states @@ -207,8 +221,11 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): assert ho_sampler.n_replicas == 4 assert ho_sampler.n_states == 4 + u_kn = ho_sampler.reporter.get_property("u_kn") + assert u_kn.shape == (n_iteratinos, 4, 4) # check that the free energies are correct print(ho_sampler.analytical_f_i) + # [ 0. , -0.28593054, -0.54696467, -0.78709279] print(ho_sampler.delta_f_ij_analytical) print(ho_sampler.f_k) diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index 4802bb6..fb2bf2c 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -95,8 +95,13 @@ def test_neighborlist_pair(): coordinates = jnp.array([[0, 0, 0], [1, 0, 0]]) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) @@ -203,8 +208,14 @@ def test_inputs(): nbr_list.build_from_state(123) coordinates = jnp.array([[1, 2, 3], [0, 0, 0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), box_vectors=None + x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), + box_vectors=None, ) # check that boxvectors are defined in the state @@ -271,8 +282,13 @@ def test_neighborlist_pair_multiple_particles(): coordinates = jnp.stack(coord_mesh.reshape(3, -1), axis=1, dtype=jnp.float32) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) @@ -341,8 +357,13 @@ def test_pairlist_pair(): coordinates = jnp.array([[0, 0, 0], [1, 0, 0]]) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) @@ -390,8 +411,13 @@ def test_pair_list_multiple_particles(): coordinates = jnp.stack(coord_mesh.reshape(3, -1), axis=1, dtype=jnp.float32) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) diff --git a/chiron/tests/test_potential.py b/chiron/tests/test_potential.py index 8eb97fd..230fa1c 100644 --- a/chiron/tests/test_potential.py +++ b/chiron/tests/test_potential.py @@ -33,7 +33,9 @@ def test_neural_network_pairlist(): cutoffs = [0.2, 0.1] expected_pairs = [(1, 1), (0, 0)] for cutoff, expected in zip(cutoffs, expected_pairs): - distances, displacement_vectors, pairlist = nn_potential.compute_pairlist(positions, cutoff) + distances, displacement_vectors, pairlist = nn_potential.compute_pairlist( + positions, cutoff + ) assert pairlist[0].size == expected[0] and pairlist[1].size == expected[1] # Test with ethanol molecule @@ -44,7 +46,9 @@ def test_neural_network_pairlist(): # Test compute_pairlist method cutoff = 0.2 - distances, displacement_vectors, pairlist = nn_potential.compute_pairlist(positions, cutoff) + distances, displacement_vectors, pairlist = nn_potential.compute_pairlist( + positions, cutoff + ) print(pairlist) assert ( pairlist[0].size == 12 and pairlist[1].size == 12 @@ -92,8 +96,9 @@ def test_harmonic_oscillator_potential(): forces = harmonic_potential.compute_force(positions_without_unit) assert forces.shape == positions_without_unit.shape, "Forces shape mismatch." + def test_harmonic_oscillator_input_checking(): - #topology check + # topology check with pytest.raises(TypeError): HarmonicOscillatorPotential(1) with pytest.raises(TypeError): @@ -104,14 +109,15 @@ def test_harmonic_oscillator_input_checking(): HarmonicOscillatorPotential(None, U0=1.0) with pytest.raises(ValueError): - HarmonicOscillatorPotential(None, k=1.0*unit.nanometer) + HarmonicOscillatorPotential(None, k=1.0 * unit.nanometer) with pytest.raises(ValueError): - HarmonicOscillatorPotential(None, x0=1.0*unit.kilocalories_per_mole) + HarmonicOscillatorPotential(None, x0=1.0 * unit.kilocalories_per_mole) with pytest.raises(ValueError): - HarmonicOscillatorPotential(None, U0=1.0*unit.nanometer) + HarmonicOscillatorPotential(None, U0=1.0 * unit.nanometer) + def test_lj_input_checking(): - #topology check + # topology check with pytest.raises(TypeError): LJPotential(1) with pytest.raises(TypeError): @@ -122,28 +128,30 @@ def test_lj_input_checking(): LJPotential(None, cutoff=1.0) with pytest.raises(ValueError): - LJPotential(None, sigma=1.0*unit.kilocalories_per_mole) + LJPotential(None, sigma=1.0 * unit.kilocalories_per_mole) with pytest.raises(ValueError): - LJPotential(None, epsilon=1.0*unit.nanometer) + LJPotential(None, epsilon=1.0 * unit.nanometer) with pytest.raises(ValueError): - LJPotential(None, cutoff=1.0*unit.kilocalories_per_mole) + LJPotential(None, cutoff=1.0 * unit.kilocalories_per_mole) from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + positions = jnp.array([[0, 0, 0], [1, 0, 0]]) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - lj = LJPotential(None, sigma=1.0*unit.nanometer) - nbr_list = NeighborListNsqrd(OrthogonalPeriodicSpace(), cutoff=2.0*unit.nanometer) + lj = LJPotential(None, sigma=1.0 * unit.nanometer) + nbr_list = NeighborListNsqrd(OrthogonalPeriodicSpace(), cutoff=2.0 * unit.nanometer) - #capture the error associated with not building the neighborlist + # capture the error associated with not building the neighborlist with pytest.raises(ValueError): lj.compute_energy(positions, nbr_list) nbr_list.build(positions, box_vectors) - #capture the error associated cutoffs not matching + # capture the error associated cutoffs not matching with pytest.raises(ValueError): lj.compute_energy(positions, nbr_list) + def test_lennard_jones(): # This will evaluate two LJ particles to ensure the energy and force are correct from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace @@ -156,27 +164,46 @@ def test_lennard_jones(): box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) space = OrthogonalPeriodicSpace() - lj_pot = LJPotential(None, unit.Quantity(sigma, unit.nanometer), unit.Quantity(epsilon, unit.kilojoules_per_mole), - unit.Quantity(cutoff, unit.nanometer)) + lj_pot = LJPotential( + None, + unit.Quantity(sigma, unit.nanometer), + unit.Quantity(epsilon, unit.kilojoules_per_mole), + unit.Quantity(cutoff, unit.nanometer), + ) + from chiron.utils import PRNG + + PRNG.set_seed(1234) for i in range(1, 11): positions = jnp.array([[0, 0, 0], [i * 0.25 * 2 ** (1 / 6), 0, 0]]) - state = SamplerState(x0=unit.Quantity(positions, unit.nanometer), box_vectors = unit.Quantity(box_vectors, - unit.nanometer)) - nbr_list = NeighborListNsqrd(space, cutoff = unit.Quantity(cutoff, unit.nanometer), skin=unit.Quantity(skin, unit.nanometer), n_max_neighbors=5) + state = SamplerState( + x0=unit.Quantity(positions, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), + box_vectors=unit.Quantity(box_vectors, unit.nanometer), + ) + nbr_list = NeighborListNsqrd( + space, + cutoff=unit.Quantity(cutoff, unit.nanometer), + skin=unit.Quantity(skin, unit.nanometer), + n_max_neighbors=5, + ) nbr_list.build_from_state(state) # first use the pairlist energy_chiron = lj_pot.compute_energy(positions) energy_chiron_nbr = lj_pot.compute_energy(positions, nbr_list) - displacement_vector = positions[0]-positions[1] + displacement_vector = positions[0] - positions[1] dist = jnp.linalg.norm(displacement_vector) - energy_analytical = 4.0*epsilon*((sigma/dist)**12-(sigma/dist)**6) + energy_analytical = 4.0 * epsilon * ((sigma / dist) ** 12 - (sigma / dist) ** 6) - assert jnp.isclose(energy_chiron, energy_analytical), "Energy from chiron using a pair list does not match the analytical energy calculation" - assert jnp.isclose(energy_chiron_nbr, energy_analytical), "Energy from chiron using a neighbor list does not match the analytical energy calculation" + assert jnp.isclose( + energy_chiron, energy_analytical + ), "Energy from chiron using a pair list does not match the analytical energy calculation" + assert jnp.isclose( + energy_chiron_nbr, energy_analytical + ), "Energy from chiron using a neighbor list does not match the analytical energy calculation" force_chiron = lj_pot.compute_force(positions) force_chiron_nbr = lj_pot.compute_force(positions, nbr_list) @@ -185,14 +212,19 @@ def test_lennard_jones(): force_chiron_analytical = lj_pot.compute_force_analytical(positions) force = ( - 24 - * (epsilon / (dist * dist)) - * (2 * (sigma / dist) ** 12 - (sigma / dist) ** 6) - ) * displacement_vector + 24 + * (epsilon / (dist * dist)) + * (2 * (sigma / dist) ** 12 - (sigma / dist) ** 6) + ) * displacement_vector forces_analytical = jnp.array([force, -force]) - assert jnp.allclose(force_chiron, forces_analytical, atol=1e-5), "Force from chiron using pair list does not match analytical force" - assert jnp.allclose(force_chiron_nbr, forces_analytical, atol=1e-5), "Force from chiron using neighbor list does not match analytical force" - assert jnp.allclose(force_chiron_analytical, forces_analytical, atol=1e-5), "Force from chiron analytical using pair list does not match analytical force" - + assert jnp.allclose( + force_chiron, forces_analytical, atol=1e-5 + ), "Force from chiron using pair list does not match analytical force" + assert jnp.allclose( + force_chiron_nbr, forces_analytical, atol=1e-5 + ), "Force from chiron using neighbor list does not match analytical force" + assert jnp.allclose( + force_chiron_analytical, forces_analytical, atol=1e-5 + ), "Force from chiron analytical using pair list does not match analytical force" diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index d94869a..4499f18 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -24,8 +24,11 @@ def test_initialize_state(): assert state.pressure is None assert state.volume == 30 * (unit.angstrom**3) assert state.nr_of_particles == 1 + from chiron.utils import PRNG - sampler_state = SamplerState(ho.positions) + PRNG.set_seed(1234) + + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) assert jnp.allclose( sampler_state.x0, @@ -41,9 +44,13 @@ def test_sampler_state_conversion(): from chiron.states import SamplerState from openmm import unit import jax.numpy as jnp + from chiron.utils import PRNG + + PRNG.set_seed(1234) sampler_state = SamplerState( - unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.nanometer) + unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), ) assert jnp.allclose( @@ -52,7 +59,8 @@ def test_sampler_state_conversion(): ) sampler_state = SamplerState( - unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.angstrom) + unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.angstrom), + current_PRNG_key=PRNG.get_random_key(), ) assert jnp.allclose( @@ -66,6 +74,9 @@ def test_sampler_state_inputs(): from openmm import unit import jax.numpy as jnp import pytest + from chiron.utils import PRNG + + PRNG.set_seed(1234) # test input of positions # should have units @@ -73,19 +84,24 @@ def test_sampler_state_inputs(): SamplerState(x0=jnp.array([1, 2, 3])) # throw and error because of incompatible units with pytest.raises(ValueError): - SamplerState(x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians)) + SamplerState( + x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians), + current_PRNG_key=PRNG.get_random_key(), + ) # test input of velocities # velocities should have units with pytest.raises(TypeError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), velocities=jnp.array([1, 2, 3]), ) # velocities should have units of distance/time with pytest.raises(ValueError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), velocities=unit.Quantity(jnp.array([1, 2, 3]), unit.nanometers), ) @@ -94,12 +110,14 @@ def test_sampler_state_inputs(): with pytest.raises(TypeError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([1, 2, 3]), ) # box_vectors should have units of distance with pytest.raises(ValueError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), unit.radians ), @@ -108,6 +126,7 @@ def test_sampler_state_inputs(): with pytest.raises(ValueError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0]]), unit.nanometers ), @@ -122,6 +141,7 @@ def test_sampler_state_inputs(): # check openmm_box conversion: state = SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=openmm_box, ) assert jnp.allclose( @@ -135,7 +155,9 @@ def test_sampler_state_inputs(): # openmm box vectors end up as a list with contents; check to make sure we capture an error if we pass a bad list with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), box_vectors=[123] + x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), + box_vectors=[123], ) @@ -146,6 +168,7 @@ def test_reduced_potential(): from chiron.potential import HarmonicOscillatorPotential import jax.numpy as jnp from openmmtools.testsystems import HarmonicOscillator + from chiron.utils import PRNG ho = HarmonicOscillator() potential = HarmonicOscillatorPotential(topology=ho.topology, k=ho.K, U0=ho.U0) @@ -153,7 +176,7 @@ def test_reduced_potential(): state = ThermodynamicState( potential, temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3) ) - sampler_state = SamplerState(ho.positions) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) reduced_e = state.get_reduced_potential(sampler_state) assert reduced_e == 0.0 diff --git a/chiron/tests/test_testsystems.py b/chiron/tests/test_testsystems.py index 5f14269..3506b77 100644 --- a/chiron/tests/test_testsystems.py +++ b/chiron/tests/test_testsystems.py @@ -182,8 +182,12 @@ def test_LJ_fluid(): dispersion_correction=False, shift=False, ) + from chiron.utils import PRNG + + PRNG.set_seed(1234) state = SamplerState( x0=lj_openmm.positions, + current_PRNG_key=PRNG.get_random_key(), box_vectors=lj_openmm.system.getDefaultPeriodicBoxVectors(), ) diff --git a/chiron/tests/test_utils.py b/chiron/tests/test_utils.py index 2cef5b3..59f6d10 100644 --- a/chiron/tests/test_utils.py +++ b/chiron/tests/test_utils.py @@ -22,48 +22,113 @@ def test_get_list_of_mass(): assert np.isclose(c, expected[0]), "Incorrect masses returned" -def test_reporter(): - """Read in a reporter file and check its contend.""" - import h5py - import numpy as np - from chiron.utils import get_data_file_path +import pytest +from .test_multistate import ho_multistate_sampler_multiple_ks + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmpdir_factory): + """Create a temporary directory for the test.""" + tmpdir = tmpdir_factory.mktemp("test_reporter") + return tmpdir - h5_file = "test_md.h5" - h5_test_file = get_data_file_path(h5_file) - print(h5_test_file) - # Read the h5 file manually and check values - h5 = h5py.File(h5_test_file, "r") - keys = h5.keys() +def test_reporter(prep_temp_dir, ho_multistate_sampler_multiple_ks): + from chiron.integrators import LangevinIntegrator + from chiron.potential import HarmonicOscillatorPotential + from openmm import unit + + from openmmtools.testsystems import HarmonicOscillator - assert "energy" in keys, "Energy not in keys" - assert "step" in keys, "Step not in keys" - assert "traj" in keys, "Traj not in keys" + ho = HarmonicOscillator() + potential = HarmonicOscillatorPotential(ho.topology) + from chiron.utils import PRNG - energy = h5["energy"][:5] - reference_energy = np.array( - [1.9328993e-06, 2.0289978e-02, 8.3407544e-02, 1.7832418e-01, 2.8428176e-01] + PRNG.set_seed(1234) + + from chiron.states import SamplerState, ThermodynamicState + + thermodynamic_state = ThermodynamicState( + potential=potential, temperature=300 * unit.kelvin ) - assert np.allclose( - energy, - reference_energy, - ), "Energy not correct" - h5.close() + sampler_state = SamplerState(ho.positions, PRNG.get_random_key()) + + from chiron.reporters import LangevinDynamicsReporter + from chiron.reporters import BaseReporter + + # set up reporter directory + BaseReporter.set_directory(prep_temp_dir) + + # test langevin reporter + reporter = LangevinDynamicsReporter("langevin_test") + reporter.reset_reporter_file() + + integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) + integrator.run( + sampler_state, + thermodynamic_state, + n_steps=20, + ) + import numpy as np + + reporter.flush_buffer() + + # test for available keys + assert "potential_energy" in reporter.get_available_keys() + assert "step" in reporter.get_available_keys() + + # test for property + pot_energy = reporter.get_property("potential_energy") + np.allclose( + pot_energy, + np.array( + [ + 8.8336921e-05, + 3.5010747e-04, + 7.8302569e-04, + 1.4021739e-03, + 2.1981772e-03, + 3.1483083e-03, + 4.2442558e-03, + 5.4960307e-03, + 6.8922052e-03, + 8.4171966e-03, + 1.0099258e-02, + 1.1929392e-02, + 1.3859766e-02, + 1.5893064e-02, + 1.8023632e-02, + 2.0219875e-02, + 2.2491256e-02, + 2.4893485e-02, + 2.7451182e-02, + 3.0140089e-02, + ], + dtype=np.float32, + ), + ) - # Use the reporter class and check values - from chiron.reporters import SimulationReporter + # test that xtc and log file is written + import os - reporter = SimulationReporter(h5_test_file, None, 1) - assert np.allclose(reference_energy, reporter.get_property("energy")[:5]) - reporter.close() - # test the topology - from openmmtools.testsystems import HarmonicOscillatorArray + assert os.path.exists(reporter.xtc_file_path) + assert os.path.exists(reporter.log_file_path) - ho = HarmonicOscillatorArray() - topology = ho.topology - reporter = SimulationReporter(h5_test_file, topology, 1) - traj = reporter.get_mdtraj_trajectory() - import mdtraj as md + # test multistate reporter + ho_sampler = ho_multistate_sampler_multiple_ks + ho_sampler._reporter.reset_reporter_file() + ho_sampler.run(5) - assert isinstance(traj, md.Trajectory), "Trajectory not correct type" + assert len(ho_sampler._reporter._replica_reporter.keys()) == 4 + assert ho_sampler._reporter._replica_reporter.get("replica_0") + assert ho_sampler._reporter._default_properties == [ + "positions", + "box_vectors", + "u_kn", + "state_index", + "step", + ] + u_kn = ho_sampler._reporter.get_property("u_kn") + assert u_kn.shape == (4, 4, 6) + assert os.path.exists(ho_sampler._reporter.log_file_path) diff --git a/chiron/utils.py b/chiron/utils.py index abe2ec9..a9a7da0 100644 --- a/chiron/utils.py +++ b/chiron/utils.py @@ -1,5 +1,40 @@ from openmm.app import Topology from openmm import unit +from jax import random + + +class PRNG: + _key: random.PRNGKey + _seed: int + + def __init__(self) -> None: + """ + A PRNG class that can be used to generate random numbers in JAX. + The intended use case is to initialize new PRN streams in the `SamplerState` class. + + Example: + -------- + from chiron.utils import PRNG + from chiron.states import SamplerState + from openmmtools.testsystems import HarmonicOscillator + + ho = HarmonicOscillator() + PRNG.set_seed(1234) + sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] + + """ + + pass + @classmethod + def set_seed(cls, seed: int) -> None: + cls._seed = seed + cls._key = random.PRNGKey(seed) + + @classmethod + def get_random_key(cls) -> int: + key, subkey = random.split(cls._key) + cls._key = key + return subkey def get_data_file_path(relative_path: str) -> str: From ccdc4cb963fd3108dd101616b46d3d364858a935 Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Thu, 8 Feb 2024 10:08:11 +0100 Subject: [PATCH 47/55] Using MCMCSampler instead of MCMCMove in multistate sampler (#22) * Refactor MCMC sampler and MultiStateSampler classes * Refactor MCMC and MultiStateSampler classes --- chiron/mcmc.py | 22 ++++++++++------- chiron/multistate.py | 44 ++++++++++++++++----------------- chiron/tests/test_mcmc.py | 18 +++++++++----- chiron/tests/test_multistate.py | 13 +++++++--- 4 files changed, 56 insertions(+), 41 deletions(-) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 8063012..285cb34 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -272,18 +272,18 @@ class MCMCSampler: def __init__( self, move_set: MoveSchedule, - sampler_state: SamplerState, - thermodynamic_state: ThermodynamicState, ): - from copy import deepcopy from loguru import logger as log - log.info("Initializing Gibbs sampler") + log.info("Initializing MCMC sampler") self.move = move_set - self.sampler_state = deepcopy(sampler_state) - self.thermodynamic_state = deepcopy(thermodynamic_state) - def run(self, n_iterations: int = 1): + def run( + self, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + n_iterations: int = 1, + ): """ Run the sampler for a specified number of iterations. @@ -293,6 +293,10 @@ def run(self, n_iterations: int = 1): Number of iterations of the sampler to run. """ from loguru import logger as log + from copy import deepcopy + + sampler_state = deepcopy(sampler_state) + thermodynamic_state = deepcopy(thermodynamic_state) log.info("Running MCMC sampler") log.info(f"move_schedule = {self.move.move_schedule}") @@ -300,15 +304,15 @@ def run(self, n_iterations: int = 1): log.info(f"Iteration {iteration + 1}/{n_iterations}") for move_name, move in self.move.move_schedule: log.debug(f"Performing: {move_name}") - move.run(self.sampler_state, self.thermodynamic_state) + move.run(sampler_state, thermodynamic_state) log.info("Finished running MCMC sampler") log.debug("Closing reporter") for _, move in self.move.move_schedule: if move.reporter is not None: move.reporter.flush_buffer() - # TODO: flush reporter log.debug(f"Closed reporter {move.reporter.log_file_path}") + return sampler_state from .neighbors import PairsBase diff --git a/chiron/multistate.py b/chiron/multistate.py index b9100ec..26f3ea3 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -3,7 +3,7 @@ from chiron.neighbors import NeighborListNsqrd from openmm import unit import numpy as np -from chiron.mcmc import MCMCMove +from chiron.mcmc import MCMCMove, MCMCSampler from chiron.reporters import MultistateReporter @@ -25,8 +25,8 @@ class MultiStateSampler: Number of replicas (read-only). iteration : int Current iteration of the simulation (read-only). - mcmc_moves : List[MCMCMove] - MCMC moves used to propagate the simulation. + mcmc_sampler : MCMCSampler + MCMC sampler used to propagate the simulation. sampler_states : List[SamplerState] Sampler states list at the current iteration. is_periodic : bool @@ -47,7 +47,7 @@ class MultiStateSampler: def __init__( self, - mcmc_moves: Union[MCMCMove, List[MCMCMove]], + mcmc_sampler: MCMCSampler, reporter: MultistateReporter, ): """ @@ -55,8 +55,8 @@ def __init__( Parameters ---------- - mcmc_moves : Union[MCMCMove, List[MCMCMove]] - The MCMCMove or list of MCMCMoves used to propagate the thermodynamic states. + mcmc_sampler : MCMCSampler + The MCMCSampler used to propagate the thermodynamic states. reporter : MultistateReporter The reporter used to store the simulation data. """ @@ -75,9 +75,9 @@ def __init__( self._neighborhoods = None self._n_accepted_matrix = None self._n_proposed_matrix = None - self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead + self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None - self._mcmc_moves = copy.deepcopy(mcmc_moves) + self._mcmc_sampler = copy.deepcopy(mcmc_sampler) self._online_estimator = None self._offline_estimator = MBAREstimator() @@ -121,15 +121,15 @@ def iteration(self): return self._iteration @property - def mcmc_moves(self): - """A copy of the MCMCMoves list used to propagate the simulation. + def mcmc_sampler(self): + """A copy of the MCMCSampler used to propagate the simulation. This can be set only before creation. """ import copy - return copy.deepcopy(self._mcmc_moves) + return copy.deepcopy(self._mcmc_sampler) @property def sampler_states(self) -> Optional[List[SamplerState]]: @@ -268,16 +268,16 @@ def _allocate_variables( ) self._traj = [[] for _ in range(self.n_replicas)] - # Ensure there is an MCMCMove for each thermodynamic state. - from chiron.mcmc import MCMCMove + # Ensure there is an MCMCSampler for each thermodynamic state. + from chiron.mcmc import MCMCSampler - if isinstance(self._mcmc_moves, MCMCMove): - self._mcmc_moves = [ - copy.deepcopy(self._mcmc_moves) for _ in range(self.n_states) + if isinstance(self._mcmc_sampler, MCMCSampler): + self._mcmc_sampler = [ + copy.deepcopy(self._mcmc_sampler) for _ in range(self.n_states) ] - elif len(self._mcmc_moves) != self.n_states: + elif len(self._mcmc_sampler) != self.n_states: raise RuntimeError( - f"The number of MCMCMoves ({len(self._mcmc_moves)}) and ThermodynamicStates ({self.n_states}) must be the same." + f"The number of MCMCMoves ({len(self._mcmc_sampler)}) and ThermodynamicStates ({self.n_states}) must be the same." ) # Reset iteration counter. @@ -394,11 +394,11 @@ def _propagate_replica(self, replica_id: int): thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] - mcmc_move = self._mcmc_moves[thermodynamic_state_id] - # Apply the MCMC move to the replica. - mcmc_move.run(sampler_state, thermodynamic_state) + mcmc_sampler = self._mcmc_sampler[thermodynamic_state_id] + # Propagate using the mcmc sampler + self._sampler_states[replica_id] = mcmc_sampler.run(sampler_state, thermodynamic_state) # Append the new state to the trajectory for analysis. - self._traj[replica_id].append(sampler_state.x0) + self._traj[replica_id].append(self._sampler_states[replica_id].x0) def _perform_swap_proposals(self): """ diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index 21bf8b6..82ef9a0 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -127,10 +127,12 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics move_set = MoveSchedule([("LangevinMove", langevin_move)]) # Initalize the sampler - sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) + sampler = MCMCSampler(move_set) # Run the sampler with the thermodynamic state and sampler state and return the sampler state - sampler.run(n_iterations=2) # how many times to repeat + sampler.run( + sampler_state, thermodynamic_state, n_iterations=2 + ) # how many times to repeat def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDisplacementMove( @@ -186,10 +188,12 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) # Initalize the sampler - sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) + sampler = MCMCSampler(move_set) # Run the sampler with the thermodynamic state and sampler state and return the sampler state - sampler.run(n_iterations=2) # how many times to repeat + sampler.run( + sampler_state, thermodynamic_state, n_iterations=2 + ) # how many times to repeat def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_MetropolisDisplacementMove( @@ -246,10 +250,12 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) # Initalize the sampler - sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) + sampler = MCMCSampler(move_set) # Run the sampler with the thermodynamic state and sampler state and return the sampler state - sampler.run(n_iterations=2) # how many times to repeat + sampler.run( + sampler_state, thermodynamic_state, n_iterations=2 + ) # how many times to repeat def test_thermodynamic_state_inputs(): diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index ef31a84..0f153be 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -15,6 +15,7 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: from chiron.mcmc import LangevinDynamicsMove from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace from chiron.reporters import MultistateReporter, BaseReporter + from chiron.mcmc import MCMCSampler, MoveSchedule sigma = 0.34 * unit.nanometer cutoff = 3.0 * sigma @@ -24,12 +25,16 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) + lang_move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) BaseReporter.set_directory("multistate_test") reporter = MultistateReporter() reporter.reset_reporter_file() + move_schedule = MoveSchedule([("LangevinDynamicsMove", lang_move)]) + mcmc_sampler = MCMCSampler( + move_schedule, + ) - multistate_sampler = MultiStateSampler(mcmc_moves=move, reporter=reporter) + multistate_sampler = MultiStateSampler(mcmc_sampler=mcmc_sampler, reporter=reporter) return nbr_list, multistate_sampler @@ -212,7 +217,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") - n_iteratinos = 250 + n_iteratinos = 25 ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states @@ -221,7 +226,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): assert ho_sampler.n_replicas == 4 assert ho_sampler.n_states == 4 - u_kn = ho_sampler.reporter.get_property("u_kn") + u_kn = ho_sampler._reporter.get_property("u_kn") assert u_kn.shape == (n_iteratinos, 4, 4) # check that the free energies are correct print(ho_sampler.analytical_f_i) From 124e519d729cfa043b839a9cdee98cb610fba839 Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Tue, 27 Feb 2024 19:09:31 +0000 Subject: [PATCH 48/55] Refactoring of MC moves (#21) * Fix u_kn shape in test_multistate_run * Refactor MCMCMove subclasses and update method names * Started MC refactor, merging from multistage branch that includes new reporters and random number scheme. This primarily sketches out the new MCMove class (updates to the actual moves is forthcoming). * Small refactor to Langevin integrator. Small changes: run command now takes variable to allow for velocity initialization. The Langevin ntegrator now returns a sampler state rather tha updating. States now has a velocity setter. * created initialize_velocities funtion in utilities; langevin code now uses this if users request velocities to be regenerated each time. The reaosn to split was to make it easy to define velocities outside of the langevin integrator. the option flag to generate each time we call the integrator because in a hybrid workflow, there would effectively be no connection between subsequent langevin moves separated by many MC transformations. furthermore, MC moves that changes thermodynamic states would require regeneration. * Finished fixing up langevin integrator/move * Implemented Metropolis displacement move in new scheme. * Added refactored barostat move and ideal gas test case. * Abstractmethod for _reporting wrapper (this will make it easier for each move to have custom/appropriate logging for what is being changed). * Added _reporter to the displacement moves and the ability to only move a subset of atoms: tests need to be updated/implemented still. * Added in stepsize updater to allow move_parameters to be updated on the fly. * Continued reworking to ensure expected behavior of loggers and stepsize adjustments. * Fixed test_integrator and atom_subsetting in displacement move. * Added in flag to initialize velocities the first time langevin is called. This is different than reinitialize which will call reinit every time run is called. The velocity init function was spun off into the utils file. velocities can be set in the sampler state; code will throw and error if initialize_velocities or reinitialize_velocities are false, and velocities are not set in the sampler state. * test_multistate still is failing an assertion, but I think that is fine, since this PR was focused on revamping the MC moves. That can be tackled separately in the multistate PR. * Modified the routines to ensure that passing the nbr_list fits in more with the functional programming model (passing and return a neighborlist). * Added ideal gas test written in the other PR. updated logic in langevin to take into account the iteration not just the current step; reporters now also log the elapsed_step. * Added ideal gas test written in the other PR. updated logic in langevin to take into account the iteration not just the current step; reporters now also log the elapsed_step. Updated the examples * Added in additional tests for barostat * Fixed convergence test syntax: these were missed on my part because they were marked to be skipped because they take too long. * Fixed convergence test syntax: these were missed on my part because they were marked to be skipped because they take too long. for Harmonic oscillator we do not seem to need to run as long as set initially to pass and get convergence. * Updated CI.yaml to enabled CI to run for branches commiting to the multistage branch. Thanks Mike Henry! * Added in skip to the multistate testing, as this still needs to be worked on in the multistate branch this is being commited to" * match/case statements don't exist in python 3.9. I commented this out and just added in if/elif statements. * fixture was missing in test_testsystems * fixture was missing in test_testsystems * weirdly wrong syntax in the test ideal gas test. * Updating ideal gas example. * Update Examples/Idealgas.py with descriptive assert statement Co-authored-by: Marcus Wieder <31651017+wiederm@users.noreply.github.com> * Update Examples/Idealgas.py with descriptive assert statement Co-authored-by: Marcus Wieder <31651017+wiederm@users.noreply.github.com> * Updating ideal gas example. * Updating ideal gas example. * removed velocity initialization flag in langevin * updated various functions in reponse to marcus' comments. * Merge failed to correctly merge, and broke MCMCSampler; fixed now. multistate reporter giving an error. * Working through comments from Jchodera. * Fixed error with multistate systems. Various other updates based on comments. * Changed `coordinates` to `positions` in neighbor/pairlist routines to be consistent with samplerstate variable name change. * Changed PairListNsqrd to allow setting the cutoff to None which will use no cutoff (i.e. all pairs interact). Revamped some internal tooling regarding how units are treated internally. Space was revampped such that it takes box_vectors as a argument rather than storing them internally (fewer copies of this floating around), and will help to ensure treating the class as static will not mean using the wrong box vectors accidentally. * mised an instance where the reporter called space.box_vectors. fixed. * name refactoring in MCMC * further addressing comments. * Added accumulated for number_of_attemps_made instead of elapsed_steps. (I missed this comment in the PR comments) --------- Co-authored-by: chrisiacovella --- .github/workflows/CI.yaml | 1 + Examples/Idealgas.py | 150 +++ Examples/LJ_MCMC.py | 165 ++++ Examples/LJ_langevin.py | 47 +- Examples/LJ_mcmove.py | 57 +- Examples/methane_coords.npy | Bin 0 -> 26528 bytes chiron/integrators.py | 155 ++- chiron/mcmc.py | 1262 +++++++++++++++++------- chiron/multistate.py | 53 +- chiron/neighbors.py | 758 +++++++++----- chiron/potential.py | 86 +- chiron/reporters.py | 6 +- chiron/states.py | 94 +- chiron/tests/conftest.py | 6 +- chiron/tests/test_convergence_tests.py | 161 ++- chiron/tests/test_integrators.py | 8 +- chiron/tests/test_mcmc.py | 180 +++- chiron/tests/test_minization.py | 22 +- chiron/tests/test_multistate.py | 18 +- chiron/tests/test_pairs.py | 180 ++-- chiron/tests/test_potential.py | 2 +- chiron/tests/test_states.py | 26 +- chiron/tests/test_testsystems.py | 14 +- chiron/tests/test_utils.py | 7 +- chiron/utils.py | 61 +- 25 files changed, 2637 insertions(+), 882 deletions(-) create mode 100644 Examples/Idealgas.py create mode 100644 Examples/LJ_MCMC.py create mode 100644 Examples/methane_coords.npy diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 0e307f5..2917564 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -9,6 +9,7 @@ on: pull_request: branches: - "main" + - "multistage" schedule: # Weekly tests run on main by default: # Scheduled workflows run on the latest commit on the default or base branch. diff --git a/Examples/Idealgas.py b/Examples/Idealgas.py new file mode 100644 index 0000000..32ef6f0 --- /dev/null +++ b/Examples/Idealgas.py @@ -0,0 +1,150 @@ +from openmmtools.testsystems import IdealGas +from openmm import unit + +""" +This example explore an ideal gas system, where the particles are non-interacting. +This will use the MonteCarloBarostatMove to sample the volume of the system and +MonteCarloDisplacementMove to sample the particle positions. + +This utilizes the IdealGas example from openmmtools to initialize particle positions and topology. + +""" + +# Use the IdealGas example from openmmtools to initialize particle positions and topology +# For this example, the topology provides the masses for the particles + +n_particles = 216 +temperature = 298 * unit.kelvin +pressure = 1 * unit.atmosphere + +ideal_gas = IdealGas(nparticles=n_particles, temperature=temperature, pressure=pressure) + + +from chiron.potential import IdealGasPotential +from chiron.utils import PRNG, get_list_of_mass +import jax.numpy as jnp + +# particles are non interacting +cutoff = 0.0 * unit.nanometer +ideal_gas_potential = IdealGasPotential(ideal_gas.topology) + +from chiron.states import SamplerState, ThermodynamicState + +# define the thermodynamic state +thermodynamic_state = ThermodynamicState( + potential=ideal_gas_potential, + temperature=temperature, + pressure=pressure, +) + +PRNG.set_seed(1234) + + +# define the sampler state +sampler_state = SamplerState( + positions=ideal_gas.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=ideal_gas.system.getDefaultPeriodicBoxVectors(), +) + +from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + +# define the pair list for an orthogonal periodic space +# since particles are non-interacting, this will not really do much +# but will be used to appropriately wrap particles in space +nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) +nbr_list.build_from_state(sampler_state) + +from chiron.reporters import MCReporter + +# initialize a reporter to save the simulation data +filename = "test_mc_ideal_gas.h5" +import os + +if os.path.isfile(filename): + os.remove(filename) +reporter = MCReporter(filename, 100) + + +from chiron.mcmc import ( + MonteCarloDisplacementMove, + MonteCarloBarostatMove, + MoveSchedule, + MCMCSampler, +) + +# initialize the displacement move +mc_barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.2, + number_of_moves=10, + reporter=reporter, + autotune=True, + autotune_interval=100, +) + +# initialize the barostat move and the move schedule +metropolis_displacement_move = MonteCarloDisplacementMove( + displacement_sigma=0.1 * unit.nanometer, + number_of_moves=100, + autotune=True, + autotune_interval=100, +) + +# define the move schedule +move_set = MoveSchedule( + [ + ("MonteCarloDisplacementMove", metropolis_displacement_move), + ("MonteCarloBarostatMove", mc_barostat_move), + ] +) + +sampler = MCMCSampler(move_set) +sampler.run( + sampler_state, thermodynamic_state, n_iterations=10, nbr_list=nbr_list +) # how many times to repeat + +# get the volume from the reporter +volume = reporter.get_property("volume") +step = reporter.get_property("elapsed_step") + + +import matplotlib.pyplot as plt + +plt.plot(step, volume) +plt.show() + +# get expectations +ideal_volume = ideal_gas.get_volume_expectation(thermodynamic_state) +ideal_volume_std = ideal_gas.get_volume_standard_deviation(thermodynamic_state) + +print("ideal volume and standard deviation: ", ideal_volume, ideal_volume_std) + + +volume_mean = jnp.mean(jnp.array(volume)) * unit.nanometer**3 +volume_std = jnp.std(jnp.array(volume)) * unit.nanometer**3 + + +print("measured volume and standard deviation: ", volume_mean, volume_std) + +# get the masses of particles from the topology +masses = get_list_of_mass(ideal_gas.topology) + +sum_of_masses = jnp.sum(jnp.array(masses.value_in_unit(unit.amu))) * unit.amu + +ideal_density = sum_of_masses / unit.AVOGADRO_CONSTANT_NA / ideal_volume +measured_density = sum_of_masses / unit.AVOGADRO_CONSTANT_NA / volume_mean + +assert jnp.isclose( + ideal_density.value_in_unit(unit.kilogram / unit.meter**3), + measured_density.value_in_unit(unit.kilogram / unit.meter**3), + atol=1e-1, +) +# see if within 5% of ideal volume +assert ( + abs(ideal_volume - volume_mean) / ideal_volume < 0.05 +), f"Warning: {abs(ideal_volume - volume_mean) / ideal_volume} exceeds the 5% threshold" + +# see if within 10% of the ideal standard deviation of the volume +assert ( + abs(ideal_volume_std - volume_std) / ideal_volume_std < 0.1 +), f"Warning: {abs(ideal_volume_std - volume_std) / ideal_volume_std} exceeds the 10% threshold" diff --git a/Examples/LJ_MCMC.py b/Examples/LJ_MCMC.py new file mode 100644 index 0000000..b3cb2e6 --- /dev/null +++ b/Examples/LJ_MCMC.py @@ -0,0 +1,165 @@ +from openmm import unit +from openmm import app + +""" +This example explore a Lennard-Jones system, where a single bead represents a united atom methane molecule, +modeled with the UA-TraPPE force field. + + +""" +n_particles = 1100 +temperature = 140 * unit.kelvin +pressure = 13.00765 * unit.atmosphere +mass = unit.Quantity(16.04, unit.gram / unit.mole) + +# create the topology +lj_topology = app.Topology() +element = app.Element(1000, "CH4", "CH4", mass) +chain = lj_topology.addChain() +for i in range(n_particles): + residue = lj_topology.addResidue("CH4", chain) + lj_topology.addAtom("CH4", element, residue) + +import jax.numpy as jnp + +# these were generated in Mbuild using fill_box which wraps packmol +# a minimum spacing of 0.4 nm was used during construction. + +from chiron.utils import get_full_path + +positions = jnp.load(get_full_path("Examples/methane_coords.npy")) * unit.nanometer + +box_vectors = ( + jnp.array( + [ + [4.275021399280942, 0.0, 0.0], + [0.0, 4.275021399280942, 0.0], + [0.0, 0.0, 4.275021399280942], + ] + ) + * unit.nanometer +) + +from chiron.potential import LJPotential +from chiron.utils import PRNG +import jax.numpy as jnp + +# + +# initialize the LennardJones potential for UA-TraPPE methane +# +sigma = 0.373 * unit.nanometer +epsilon = 0.2941 * unit.kilocalories_per_mole +cutoff = 1.4 * unit.nanometer + +lj_potential = LJPotential(lj_topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff) + +from chiron.states import SamplerState, ThermodynamicState + +# define the thermodynamic state +thermodynamic_state = ThermodynamicState( + potential=lj_potential, + temperature=temperature, + pressure=pressure, +) + +PRNG.set_seed(1234) + + +# define the sampler state +sampler_state = SamplerState( + positions=positions, current_PRNG_key=PRNG.get_random_key(), box_vectors=box_vectors +) + + +from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + +# define the pair list for an orthogonal periodic space +# since particles are non-interacting, this will not really do much +# but will appropriately wrap particles in space +nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) +nbr_list.build_from_state(sampler_state) + +# CRI: minimizer is not working correctly on my mac +# from chiron.minimze import minimize_energy +# +# results = minimize_energy( +# sampler_state.positions, lj_potential.compute_energy, nbr_list, maxiter=100 +# ) +# +# min_x = results.params +# +# sampler_state.positions = min_x + +from chiron.reporters import MCReporter + +# initialize a reporter to save the simulation data +import os + + +filename_displacement = "test_mc_lj_disp.h5" + +if os.path.isfile(filename_displacement): + os.remove(filename_displacement) +reporter_displacement = MCReporter(filename_displacement, 10) + +from chiron.mcmc import MonteCarloDisplacementMove + +mc_displacement_move = MonteCarloDisplacementMove( + displacement_sigma=0.001 * unit.nanometer, + number_of_moves=100, + reporter=reporter_displacement, + report_interval=10, + autotune=True, + autotune_interval=100, +) + +filename_barostat = "test_mc_lj_barostat.h5" +if os.path.isfile(filename_barostat): + os.remove(filename_barostat) +reporter_barostat = MCReporter(filename_barostat, 1) + + +from chiron.mcmc import MonteCarloBarostatMove + +mc_barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.1, + number_of_moves=10, + reporter=reporter_barostat, + report_interval=1, + autotune=True, + autotune_interval=50, +) + +from chiron.reporters import LangevinDynamicsReporter + +filename_langevin = "test_mc_lj_langevin.h5" + +if os.path.isfile(filename_langevin): + os.remove(filename_langevin) +reporter_langevin = LangevinDynamicsReporter(filename_langevin, 10) + +from chiron.mcmc import LangevinDynamicsMove + +langevin_dynamics_move = LangevinDynamicsMove( + timestep=1.0 * unit.femtoseconds, + collision_rate=1.0 / unit.picoseconds, + number_of_steps=1000, + reporter=reporter_langevin, + report_interval=10, +) + +from chiron.mcmc import MoveSchedule + +move_set = MoveSchedule( + [ + ("LangevinDynamicsMove", langevin_dynamics_move), + ("MonteCarloDisplacementMove", mc_displacement_move), + ("MonteCarloBarostatMove", mc_barostat_move), + ] +) + +from chiron.mcmc import MCMCSampler + +sampler = MCMCSampler(move_set) +sampler.run(sampler_state, thermodynamic_state, n_iterations=100, nbr_list=nbr_list) diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py index d769b1a..8bb38d1 100644 --- a/Examples/LJ_langevin.py +++ b/Examples/LJ_langevin.py @@ -19,31 +19,47 @@ lj_fluid.topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff ) + +from chiron.utils import PRNG + +PRNG.set_seed(1234) + from chiron.states import SamplerState, ThermodynamicState # define the sampler state sampler_state = SamplerState( - x0=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() + positions=lj_fluid.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # define the thermodynamic state thermodynamic_state = ThermodynamicState( - potential=lj_potential, temperature=300 * unit.kelvin + potential=lj_potential, + temperature=300 * unit.kelvin, ) + from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace -# define the neighbor list for an orthogonal periodic space +# Set up a neighbor list for an orthogonal periodic box with a cutoff of 3.0 * sigma and skin of 0.5 * sigma, +# where sigma = 0.34 nm. +# The class we instantiate, NeighborListNsqrd, uses an O(N^2) calculation to build the neighbor list, +# but uses a buffer (i.e., the skin) to avoid needing to perform the O(N^2) calculation at every step. +# With this routine, the calculation at each step between builds is O(N*n_max_neighbors). +# For the conditions considered here, n_max_neighbors is set to 180 (note this will increase if necessary) +# and thus there is ~5 reduction in computational cost compared to a brute force approach (i.e., PairListNsqrd). + skin = 0.5 * unit.nanometer nbr_list = NeighborListNsqrd( OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) -# build the neighbor list from the sampler state +# perform the initial build of the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import _SimulationReporter +from chiron.reporters import LangevinDynamicsReporter # initialize a reporter to save the simulation data filename = "test_lj.h5" @@ -51,18 +67,25 @@ if os.path.isfile(filename): os.remove(filename) -reporter = _SimulationReporter("test_lj.h5", lj_fluid.topology, 1) +reporter = LangevinDynamicsReporter( + "test_lj.h5", + 1, + lj_fluid.topology, +) from chiron.integrators import LangevinIntegrator # initialize the Langevin integrator -integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) -print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list)) +integrator = LangevinIntegrator(reporter=reporter, report_interval=100) +print("init_energy: ", lj_potential.compute_energy(sampler_state.positions, nbr_list)) -integrator.run( +# run the simulation +# note, typically we will not be calling the integrator directly, +# but instead using the LangevinDynamics Move in the MCMC Sampler. +updated_sampler_state, updated_nbr_list = integrator.run( sampler_state, thermodynamic_state, - n_steps=5000, + number_of_steps=1000, nbr_list=nbr_list, progress_bar=True, ) @@ -71,9 +94,11 @@ # read the data from the reporter with h5py.File("test_lj.h5", "r") as f: - energies = f["energy"][:] + energies = f["potential_energy"][:] steps = f["step"][:] +energies = reporter.get_property("potential_energy") +steps = reporter.get_property("step") # plot the energy import matplotlib.pyplot as plt diff --git a/Examples/LJ_mcmove.py b/Examples/LJ_mcmove.py index bc673f6..09fa6fa 100644 --- a/Examples/LJ_mcmove.py +++ b/Examples/LJ_mcmove.py @@ -20,10 +20,15 @@ ) from chiron.states import SamplerState, ThermodynamicState +from chiron.utils import PRNG + +PRNG.set_seed(1234) # define the sampler state sampler_state = SamplerState( - x0=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() + positions=lj_fluid.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # define the thermodynamic state @@ -39,29 +44,61 @@ nbr_list = NeighborListNsqrd( OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) -from chiron.neighbors import PairList +from chiron.neighbors import PairListNsqrd # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import _SimulationReporter +from chiron.reporters import MCReporter # initialize a reporter to save the simulation data -filename = "test_lj.h5" +filename = "test_mc_lj.h5" import os if os.path.isfile(filename): os.remove(filename) -reporter = _SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1) +reporter = MCReporter(filename, 1) -from chiron.mcmc import MetropolisDisplacementMove +from chiron.mcmc import MonteCarloDisplacementMove -mc_move = MetropolisDisplacementMove( - seed=1234, +mc_move = MonteCarloDisplacementMove( displacement_sigma=0.01 * unit.nanometer, - nr_of_moves=1000, + number_of_moves=5000, reporter=reporter, + report_interval=1, + autotune=True, + autotune_interval=100, ) -mc_move.run(sampler_state, thermodynamic_state, nbr_list, True) +mc_move.update(sampler_state, thermodynamic_state, nbr_list) + +stats = mc_move.statistics +print(stats["n_accepted"] / stats["n_proposed"]) + + +acceptance_probability = reporter.get_property("acceptance_probability") +displacement_sigma = reporter.get_property("displacement_sigma") +potential_energy = reporter.get_property("potential_energy") +step = reporter.get_property("step") + +# plot the energy +import matplotlib.pyplot as plt + +plt.subplot(3, 1, 1) + +plt.plot(step, displacement_sigma) +plt.ylabel("displacement_sigma (nm)") + +plt.subplot(3, 1, 2) + +plt.plot(step, acceptance_probability) +plt.ylabel("acceptance_probability") + + +plt.subplot(3, 1, 3) + +plt.plot(step, potential_energy) +plt.xlabel("Step") +plt.ylabel("potential_energy (kj/mol)") +plt.show() diff --git a/Examples/methane_coords.npy b/Examples/methane_coords.npy new file mode 100644 index 0000000000000000000000000000000000000000..b769491b98b3900c02c416ef91ee3d2a4d04a594 GIT binary patch literal 26528 zcmbTec|4Tw_dh-}X3Q97NQ<;+BZ{KZLh2k_$x^Zul{OWLNGg<+Xc3Vmm91#8R4Pf5 zQi;f(vSi=)eP6$q*ZY%xfB*dB;dak`U&}e?I_LR(p7GsosH1C4rJSXlmppRxw5hG6 z;zCKKV;dx87D^towzac8bn<|;?UAFC*S8$9Jbjd0KW%=<<|z4FYSk(^IhloPmdPx1 zTKNC(hpqCQD=~4Gg0^zFRrfOo=dEBmMp$t%E0g?a3t`7>^XqooYN2=_z?YZ9n!2u= zxV<=(&BXGwv!?aus5tKj!+u_8A6!o>&PiH42%^u!qxcp#fyA%;^(nDTyhMLvT70+% zEc|M%R~ho4uAgH3r(%A*e{b?99qhGilV3cc_0KcNxZjP_>Ql*c z_eb~aKe3j9{rxKY{NB_+=KNM3XKp`?zs&V^<6tQjCz7|0g_WY;xo*od;ICbr=tpTedM_mB8$^W) z)9_frw8Xt1YvJ#@1#b(CIxk)gmfD+T+&%l@o)yi^RhT(7{=fHCZm8IN=o1}vlBOFy zChN7q&7Zw|c|NFrZm6Ez$J*$h>lavcLqW~+$m(Yt91}DQc8nYV-Iax{K3^(9_Q;%W zyXloMJLYTkW#K-MxyMSb4dlT!GTzN@tt@={K_G5lYa6TyO;Y34r%tWczOluxjlEcS zDIQ&_cXKgVyL9<=986R0k zIw7zs`V0SV9t{2}}7*;pB<|>=f(BCsDE04@$lcRacsmJW8`ObDQp3Xm70jkYs z`~A&2A;-vTmr%$AtY24om%XbAb{KB?Ugz2g-D}?~XgPPoOTVMnO*i#Jf%nru#~LPf zUJIQN`;Z9{j5^EC6b?pxO9HS8Cuv0wjsOD9+PYXw%~n}fjb-=9OrHO^|T+-D2a}wP4^{C#qiA0sE{3 zJ+kihg5s!iebNY-SEy@^Dp|h)n$Zoh6)ZG)u#M@xj)Eto1IsOZD99X14wW9E;DYDh zXz$(I;M0Lw;%SmC5In+rxpUtT$gIlI+;fPDvIFTmFEv(A+0{9Y+QI%(VsG0dsv79s z(Eng0Ye5JDXUhF(ye-5+qpne-bE|n!Tx7sePq`J2hl&1Lv%Lyrf7V?wz(3%Wz@RoB z8;6F0t;<%6Pe50<-3zJ5JUCbxr!{w+gXf0MXS*=isOx1VYSYg`HfvypCWV9c^Aimh z?cm4OffPlFfs>_W2!#Z6<2JYBn(z8RIfidw?+!3P=Xv|Yxa!a}o&IP}QzJe~8_eqP+ z*v7^Wx-ptXSwmCv5BYjRC-M|8+SPwG(mKnF=BAfIGs!wPc%>}$l%Zl+YrZu5U?#lp z%(|398HLf1F!?%HF8;hh*Uuv3Neg^vvV0>4nLD4`t#~;Bn)_9vCa>R3|Is(?iXfUh zWvnt5;78U1%{R4+SolC{PMm;CD-@mcaw_R=h1I1Q`<5@{BE2~)VyJ`<-*-wWt`{By zFGYW|gxyrseP=Gr8^e#A-1bG9Rk2axz!8)6HVo``4&%(#6+r7H1_!j8Mxee+@a>wX z4Uo1p#7Mh}hL4x>N4C0EfU`rqd36dyjV7KqnQ~dfR_Z zTQvr1hcllR>G9xhtxlz7d-*_Z)Z6(}|y>1(;z>DQ~hmNn0>w-Bi?+1^pB5}Jo zfBf*Y0Z{DCyj&1e1sm2sk(B9TqMxv2kjv#hFkE)+;IZ@pXudIc*LWrsk7_;0&rfCG zw?fYq@A|;JGTKT;zq@0bL4UT-AT(@gimRYyHpkl&ek74JR4Iq-B-Tj?h zmpgwk$&9ZJnDS{6i?A22J3T%5)3p_N`2w{PU54Nc>)e+o2{b%?Cw!XIJ6`OJ@F+2& zGx4ay>;C+Xe{ucBktHY0#fISe;I$Q-$UYB0Z-`sMI}8gJZlCphF%Rm^61CrEM&h%@ z;3=<#-N2FC6g?~K57@j40hlAr8lA;pg|;f${{9`K`l?Hc(ZDnFj=%-Lmql8Kg8=1!L$ zu`p`$!^R(Eo(ze^>yOV3g64gA5*OJ9$A1Nv35gbg^Us5yChh-P-|BTbes7mHeOoY2w2QPYPatTw1jC%P{;fukCR+>xF$i z_ripA82Bvvo8n7J20l3E6IKzyGZoK+HF+;BZ?6Y_-yT8BExb4;neH07dl(qe&2{I< zx|uO@6>fx4(Y>T~+>EDp>Ks$uRFq^xs{w-<21Ba${jiYb{m^th6))&UtEK}B$7Jp- z9?B5FFiokVk~;zz8!kEbn^YwzD!F$3d^ZGorVYZesx80_xTTO&)Cy;+MR>fOM?mgj z*@cU1+Q3jJF7X+66q48#uWk|l|M_}C%alGz5YRJq?y!Yl9r+t7_h&BT%?&YtxElyi@DBK2v(N zzXKiLxTcz9|E6R9%xM(L%~wx9V~kj0rX#K#Q0Xi#X6m0 zwa*Qm;4*#grb)ZCG&Vh>dy9@0S1z0x-ZBm|eKMQp-{PBkPxo%&<~3baP&_xaUAvot z`qEGDmJjv9ozvO9@yl4~`C{G)@9F$Fr}^00=_(WyxhL`DWdxLFlP$5GKMJ=>R;gOy)j;0;lJlnNBdq+h3FP;Bq+ zi?Y3N;K!LG{a0AXifZZey;BQ}a34i#5+4qJITCfSfP(4zBcfMg>4@TaPKym!Gk;e=yzOZlZ@})v(XEC2Y`6wL)9cG$F zW$>eYrJk$%!CDY}k&&}|4juW#6z@i)Q*cdqioX0@8gAS^cfq-h{b1-L9UZ!&38Ykq zeO71Z!wNmyyAB?_SU0m~H9o?#d1)_5MlUWs@`j57DZ49kkr#VA zZZ+Im#*a6H6(76}sD%R*>AdjM&7eDB+*2bch>rT}c6K_9f}7}(QnBnVu$DgT6w$=N zrd1Mc4>xqeC9~hphdf$9J#*dVHH#^D`f1MbQxCf!H30l?>>LCc#WvS&5}#sJ*Gd=% z@Z-Ei5ubRj_kg5nszebV4JUTG3mnqo$9?6;Hg+Zs!*IbX%_efcR@B?T$`&f_v6f9y zSKwmvqEAb5=Jdm!OwmBIU0o0#w|wQAX*3LXE>pX{g@Uqf%&ZrGxmJk5sy(DTq=|(&O4g!F`9>hE zO7gg|Fa^KNa2LJ$W&pH3*vhfGT*NZGe6E@g_d1<1eDIr!DG!BLw_NH0={czzN|b1L z;NZx~1J8KSdYMPQ=95my^}b`hIERBHrf~)e=0gxH{xo?yy=uxHZgDm|aXCZ9HCM~8 zNfH0}TmKk*SNt4aO~>Fj@9RS?+Mv~U{oJef`7sEn)}@;`=v!>D@aa%LWN8T+6iN?6 zY4sT+rOpZHU%py#VR|nx_FkIiWi5z$fn^hxHY1>spdO{rJP0$qF{>m#9pqXaV?}Oq zaMj&UD)oE#(8f!1SMzojK4-eTy-D)WPTHl9*SmV*U1`+Xa>-7ZdGUboeETM-2>HF_ z>%o6@N9(i6fJy4Yo#~oo2S(tUMu0}_*ESe!J{T*!tOm|FIi|V#_CT##RC%uZD2NUQ z-kL|um&nBB0+*gT9}vVi z>KeuKq`n(}|1((f;t*8wxZPa3h=Kuri8)<|sMtl)(MkWLubQK(nbZz3_b$X{l~Hm2 zIP{daF>rl0wXFMtARegu)bGL=19k(Sq7aLLvZe;6`qw$g>5y_w`9MWMn*k|dlK1I# zH_zkAXX95|OEjki-+P*x(TtS1;>*$$OED!i-x7ewL?${GH*im1Ml- z+^=Kn?(<;XF9qHs&NOu1W>;B1J_0i$9aW{P1h7U>SNe_zsT=plvpo8G;qN+FM7cbp zbq+wol4VVYzK_GNCl79a%wglD=TypD+a4H=b*%P#GY*6MSF6#5|JjN9yoZMKNc|&~ z;rA^zgO0a~HNN*hAa&>w-`y(t?I8Bm6mA}+;nl@uIm&jqXsMMzs;)h#n`WW|au&~uRL&eQ* z6s{?`%w6_}4^MIz4K91v3HcdG$~2 zUBSk9u^fBLGkkb_J5RJ%YbAK=-ZB~U9fo;@a~?k#8>;?_!tlAm4dx_lhA z#=ds?zL$l8iP}7U^O~SAbjKt%tJWdtcD-lYO+l`(gp~;JB;_ zrd2}019P!tJuY5v9^ihD83nH$!u$Fe0vMGdJt3sbhkMqV%B+m!py$|Ji%A@zU?QtN z^A$h#^zL;lZ>@&N18dm6H4GHn&bX|&hk@_H+Y@LHSa|>B_rAwY9F&mrh?+^pmC?8^ zAa6g}5B6qTArjyABz!0j+fV};vs0dZAkS~Mata$=IWcw4ykQ4YR}TK$e_e}$Mn~qq zbLj86qB^52rQ~=6bguro#hAn&kyz^un-T<2TIkx_=j)rGO0D-qi&-UzKfBV4&w8tAW38}W*R7T>k{^$RXi0wG&WqokiKr^%{X563yD5nfD}a|7zoM-s zsJQ!rYE>+$j~-nC&N4F3AX~R{2^zyN{XFN`yQ@R6Oz_xA(UvY~$%|nbuVP^6gZ+vF zEFLU7RWNrVhk>!luUG7EU?MEcz4zvLJ-pA(y|JCt2d6P!>HPF65FZwb*)pvGv{NW6 zt;*U#ZMPGjjo~1~Co&s9DhXh3)ah4TnHI2owD;(~r{!>UQ(I>t*%uG$_fKv=h9E*p zM`0!z=L28AsBmJ>C4(l5r8LT+Nia((VG9ituKV#TE@h&1!>f6P{v14+~j2=ILBP*dQiYbyPwB`u4?dN z3actFb>{>GG+whjU(ZF^vCg|o7jf`+&BtGguX>@R=+C_1Vh*l8y;Qg>aRLO18x4oMX!ql(}4Vxcc6BvaXHTov}J4 zj|U6q3SBVs>VVx%zLz6NJz03AeZ8|e4WoR7Gpj2IUh?vd%GS#?eB*oUt}_{z{jAfe z_2l=%<2U^CSOefKHLfsAT>zgs8tlGj$3XrMzB|00SHeyA4Z%l#&~Rwwq^;}W8JfeCxs7JI8>d^3~7SH z``2|EXK`^3yGb#G;2qYmTyWkYF8Y7p(aR?FY<-m2FT=A;6in=xEq1;Vg5;tw=WMbG&r3snVVemclqiWmbDyUxM5t=$S z4BG|Snt7T$xJR^cZC-UhFlKnky%Ot(m<%T=M||cjiccSr*$IYE}gJY>zB;cMZ-|cN@*E(%ZHpjN2w=o4};Pvh5eW9 z*+@6k2=P`K1b5-hkNpm{!13YRY4#3-V7~B9);ymP=<=`(J{d~KFdx@jW&~#oYv=JU z@#DwLR6pbF36k$k7hZ77bpRq<0)NMsvCu^={r=WB6YwRrr!}Io8EjtgX6aXVL(n7D zF1x`Tc*oE>L-p#1j0PLl%Nl;H-?@=<_!Jcjuif`lJkSiwr!g zKF^5^v-e-y4%R)X+YX%+z$@l@F_aTLxI-iL+cTd&*uATEQ84l6|8QB7kNoy{_3lvz zc~5OYwp?63ls^=Clrw(-toYqNM%`c{(|(@rC1C-)(EU>P1i9b*x@hj_NIFI>z8qb} z6u?7~D=Uxk61x^45v-VjPu@kfCfYJEd3}J)22Wm8bAMt~lFhBbz+uR@a9ZiBCNs%e9a4JYNMne|*cHw7>ZkQeSDJ z96WQP=3do7Djr@Q_5&4p(9a_DRvP)g%^pR$43j<>J2`W~Fu{-h`u#)AEHQI8CT86& zYe$0HKi%y8;q*HSRs{Nozcn3%EXs!{qn-RXH(fUCBYDolx38O?pJk!b_0gujBo1Ea zm?6^IG6WOq(t00AzPpg`K~kt7!S|xVRhUC%o84@+1d}hjp0MbS<4hMJs7z6*2Cw1)-0UwaQ;N%BRXCf^kfeZ`;`lP%y(&4 z3AkoyKDsE?14k=nexBIHMwu6tR({dV;Jz^XE1yRLM1^F+a6cW>)c54X74RaOoohSy zoDcm4m47wH@#CWlK`G}W`=;`QJ5Ial6MMTs(u)@A=tai{E6cNw_7FV2@hmP`+z9(* zMvk0cB8Vpiv~(tMsr@UfH?AW2$EKg1`P~s5Jm?hgTz8@x8pXcLl{t>U^)Z7V9xEmw z=|G}|!7Kr6aq!8xD$)o+YwZvC|Dd3plzq$u$-^C}fwR|Z((%P$)%lVSxnQF3d!Z!3 zkL3BszkjUbL6^G~DJ{RcV1A8kzT6@zx*WJjyE4JXm5#Iz2Y=+j?0ls^;}4iP0pVezJ-`=|}2@|Ks2ES~;fY`YGsI;Ik&dgdYculMhH8_6$zSv*Aj2WfN8kWIwsZ&8_>eff_Ku;Gz9&EK4xHFMU!fm_3S2d+&J4rFNVSV$ zA`P&8bykjfYQxm~{=PTnzKTl61_~1lRY3tNFdu`y;JiBTf%EWmu`x9Pb)+FD+ zsDe=M;P_Xf6zrgn#GKZzhPkuCU&QSmf)aDh(ZEFl_-A_hsK#3gUK43*UPEvSBfBjf zF%ExV#nbIO!C?R_qsoq+QRc^n3uYGIFY=?lVweXly9QckkA_8t|I4q^_iMgunJt6^ z6iJEgM#E4ZL~%H}g6Ds@2C3_GwkSp%e?~#0#w$lHXY-++^I<+AixFVHTe$bmgaF?D zBzX9XDji3bHGNpb;>EHhfzkEGqfn-Q#Zhe+$^ZV3kNi~CTdV8FMvVCxPB*HD!?xp> zC(lJevx{q)Muc~4{^OfxOYmT+mi5!W(os$$bVTI&7$~Lvd2F6s2hSF-7fXw32S4ul zEoqL0@OW!*M(hVZ%&T4%@h605>OB&wO~*<^TVPYiiS!)@1aS1*G@%V!nRw6EPV3Mc z!cR3B#|dqtp&ivB`Rb}hSb|Dao`HHe-yF7D(4q=#t(UO#Pql+F<;AB{-VgF8vX41tvoW8~#{2xoI_5?&hcqJ-2% zkAxqj?w1)bvFxjaX)j+r;k+hwpnJUQ>MLw~e&Wr%#)sLE^-MEOd8Ghus(EF&;T=Ef z&N5{jai!o6zU$lGc{7kA9IzrhYyiT{Vw+DN<;N99y?1@iw!o+3hEIBwDuH$6X!DUB zRCLrTf9xPI0p~AmFn-g_gWhEad7PX^!J_<>`>on4csgH0Dour;CT zc(8Q)-X$)tn3(eHP_F(-77ACrPnuZX2fERXf|GW!_LuL_)`?*lxFoyLQH6>_HsaEA z-maXRsXX5U%V$SUwT|j?HomM@Ei4kHc%d)j;=rECDRlKDhX8iFz zYrL9)6^0f&hqI{op<8fTn{gi;RdKni7}5h5=Jx-d+c^R*i)XEBy+pyJyPr>MlKA}B zZ(o!ib8~va!WnGKW#^Cef$Wm#X)m=qz&%yms`t+@Y^_RpI9V?w+77~C&|)QSbIQXL1H35y3d%Q z+fDMYzx(2>r{Q@-k%B1;hbLUJho|Dj-*N3pCj`S*5|^IMDl&b|z=-7+gX_Y_VTJM* zbPMN4@$=b|O0WOrrGGy+GdlX|$S*GTOLiqpm*K(OS5mi|;)bAM?}FcRAM<0n$G*~* zdsHOZ(?`jXCYVzp5)eh^*BEy(bQZ~uN>1XViObwaLyLIODr07O zoeBeezSoc+jS=w z6$jH%X`G)kd4eW0N-N9)`BFXh5RFVlU<;SAqj~;jT zw?W96!X;vg-C%bz{owg|KVW9#$<+*ko5n58w6IYY!oT@e%;v9n!Ht2_c;3A^Bi#)v z_U$^f2HU}l2EBjF8JWY4be<-<=IE%~{+d9ZoT-KB5IycQOyvyUF4PvuSL+%_Af zKH*1G)l%i#`V8D~tk^gCSP>-k;+Z6_0D3QJ;65|x1Wn`IZ=EhY&(F0AeU!?KByAZdxci7{YdmFTXI&|ed_~y($YN}WaA&60SKaJP-hu5f|$1K z@6xvnT=L?1(@m}rT4xF56@DQ)k^?APnn3cAd2S^V3tJ&P{Ft$Fa5c2cm0E|JF|gLX zeqWs$AL2B*13b*Ssrca$8Fe*wCl5Y7IuVdc_~xh6nyeNtYk~Y{mJ7O5iGHWxxoU|P z3m5qQUiN^jlMdxgkfM4YIL|%&p`OHv|Kd$G$tQcKXP-=BVu2*jPUVhK@No;dGXD&z z-@6s&(PS8S?Qm5IFh)V@^}=%3QU*pPlpc&QW#hgo+8Jx2FY-$W@uY6-(EGRw`CQ zR$cUsD55(E)=Q|JZq)_JZqmEH&Je`WG!;A6yk6*fZ4)kRM47VBAa#Z6l08k};~w|o znjyjeE5&V($Z`!Q z1FyRd-c%Ptqno{3lPq~qWMl47#U&nmV7K_Q!J0yNewXq(6}T8-^0ak_Q!`xG`<+qo zj)QnSVy8xT^VB{XsGe20WZVO$H+^4u?-&H{#TOLD94WXvZ^3+9aG=!W)+Xn}iqbo}(=)uC=_D!RS=bi0n^mvX22o-ZuyfW0j@T_^pSz3b-YUjt-c zLNbo;xWR`R5vx5CyC|54GJy(@dSG@G^OU>(7%)C=4O222h5zQawGS@$Y`Qpk&U^8B z=%$f%crET;vURWpd^Fa*zIu$*`*PQOo4dwA?{)pOwnrS?y>e2iF$(^n1|xQ(OgtE* zFaCZp$pa_(x(75IYd8CXRFW5jRx&+^9nNftc|9$B6fTK2WQM=O3ZabK)| zz$l3m50u`W4!3TqDLT?jMvl) znA%r)@yhxMcCUC92uo`e?YrL&Y4+uXKA}wX;Y_F{6Fm3KUH8V-HiUmEYKxtmCls={ z4jdsAdiD9sHZs1~FT0{<6FgVFBsDa`@&|OyiJBek(gW_#N=nS`_d{fa;SS~FbUf{l zq<-oh4<;{7^?ZAD9LCr+?}F2Ep@xw*cL#~nYtOD+wNaObnpf8Ot#zvcfdrO{`ROiD z?^e84NZuoo`!XqUISu1|%pwAb-X-(+j&ubo(S_vbu8$`3cFq~^m!3}IL{Ngvq@KWR zYgEUmFc+=Y*aXs-voJyQWOBoL#*{zkF3&AHvWJCgTIQv7mSgZo(f2mjy%!=rZW3LQ z*8=B8o~_nj&Vv_1z{-Hcm073!UDAzs@pbsAhzbcVE|(4{+)HQU8PV`LfqI1R+&|}x z=}TTLa?Bm8_xi_k8p&<^b%OAax#61wy-hr(=Y$!#2vR3Cf#Ys*bRbZw13bY^lN zmW_TC-!lNms$UIXJURrwxoeXXTUeM~vFyg}IeaL2uii{PoQuk1^YIM9n<_tYgVxZP z$dzg3PT~mKbcVteO+lQql+XTpUnfjg3~}q3px_?0gyNLVY_zT_IINYy#2C98m!fhS z+Lk}!1}G20HHYTbPmx3L?0A;?YBCPP$17fRiI#(wU((~C15D(HvbssUL2jMtwT6uZ zS1Woe6-(`b9+HlVYA|qgg_hfwt~!`W(m;83ADDe$z0l!u5UzzNcoQDm(&zZaEwLqV z+9!y%!@CO($%$tgAL)fX;bLKV^Qn0A?2i=FwN#WcSlY5wk`HeQ6;x0Q`e2&2GrpSW z1+V>Izf81_LC5u&xqIz8VbPA#jc<-K@W6NZ+9pzuu5_(;og+xamxWQ&ucirN?yBEy zVPfrY_xr(xrs>15<8~u^ZxP`Wq^uI>u!kTh|7CUqjTdbM_Bn{(W1!j$-}T8Jy^tO^ zI#=-aZ+IHq=X$BI6@IFz{r+@`hUd5cx_g4mf1ae9gV*$KaQ*PrM}+VIBubX6Z|{UT zvXbPa;GpVnJ~98UE*N}wsp^S$CCum*aZDK$#2>zK0T-`O(LZx!TWT~D57Sizmdg#Mzxu`2xd z{o5R@{rS9cp(q#2zC?bzqsB!U{vT>@@1{e7g{pUljNFmcNx{tegRQ5DzN*g4 zNcqcEE~YszHJhBzKvcz@;#W9VjX|kr zuThhD9mG8TV`baJ!FUDN%BS934EUH@`!%c_N^h(8ts=bai>I$<=aRTyP*phYI!$P5 z-i*3w>uDsOH_i*l6d-t*nu=za4t)sJc38JecN>F8+h%_V5o02Y8e$wzbdn2JDzy6% zeqH*Ywi=%rsT*Wfz9@T)!ZS7}U zdw=!`;yS;Qo95EIIPu^Hf1nZ*drsz^yeGwjZW3a;>*@xk{Q1J~wR^a48X!Q{LC5$h z!7;L&$5*NLz$5+Ln}5WP0(00f<3d+AsQS4t@_XF{9SdHH#C#lr$8$fhU1t+rlcuSl zM(`+HF1XA*M9y7Xw|SIVBoEBZ3ct4H11~cE0~o2eVeZvtWe+COQ$uxPB}T!3#xzds zCVcIm#OSZ&+_L-gaqHOHVOS|FPTAH?$EQ5w{)(IE=x+MXKf#rTTesVc$SQYDvyjNROo3B`A)X3HXp4nT2poS`2N&QRAV%jOc? zcAuYo$+sSWClmcw_p#ALD(d^{TQy*n+Z}L<%Y*;br_0;nYmWG?mdGwJyV`i7>k*0L z13z49%?VHRW$6XVybU*~vx^KIQ-Y)onLc;|7$eIprC8^c6+@OFC7@(XD^*xN{mo?Z@)-XI0?VFA4PG4r#N7$3F}n$mjX07#N! z)J>%Y>~mY+$Ey&YNi3;qvk&3v{x2V~sHr)5VoM*~%>5YQOZ4w@r3rHRH_IUKW22JY zc^dY~JeTGlVWYu)so?N*UZm^_+u9I61Q{y^rdP~qhZ^~&y*G&-!`#j3j>^a{$dr`6 z8Xe2QjjPt`pO9_@0X_e{6Fr@O_CMBt40Je*g~NcXHTg&AY&R{{+2dbI%8=5`EU>G&N~B>@rJOZqG0j3~ANJ zZso;UgR*g&zj&wowc)+uH12K=zP$W#v+HXPnnhm@s-IpDpW{PSjxYz{lgDhCn}gl( zYuCD@U-AS`lz%z=UASW^56$>IZQH$}Ziqjhu;W(P7`)xM=O)8|iZVr8l#1swQOM?b z;DFiy{D=4N`?oG0(U(?DBmDwIZ7MjuM*s(A#;MnkbMs@0(USvVG*o*3c3)U%D|lBv zymBLhjY8KKU0JGM3@R!LXZwiW$HSzRd3Dn$R1Q7(VrfEji>01Bge{3~ba1IV`&)qx zS_D`8WIBIMpK}+y^V{yc)1(4oEeAfx5ZwDO9u%{+%5la=1`f7F%1zU30*2F?^Ud8X z{GsWqmuo^r%UiRn$E;`=JJYLv3B3-&w8vdGG_!Hj_G@&N&qG9407_4Y&fmQG%*|xN zL;rnlQ24D|(YGo1y;4MHK3V_ezYiQ;!aEMh_4AablfO-Zf?pLlbFi;g_sqpgesq*d zK4u-t!SjE*elW;>D3PpGZw)Wf5{?)7*^K>ff$W8HE&f`~`6mH?`M^S-z4M!pvtKoJpN* z^8Niub#X+eWxNsViz#Tb*3RfE!F|k1WrvRv9n$^+-;i4u7&yOedSElbZ-@4h!j9nA z-ugLZeAB8xWvAB?k27>!?5#D=;XysbiG9?bQAKnO%vOtT`Cceq6dqAz(glzFh{7a( zWNIE3YqqNz9QwyU__zz14?iJ#$3xjpS1vU{GUOc4Ko$OC5Md+_yM)gCBn!DXZOQP4s8kfdy6h6kIh~_pR@TWo?{~a%6ut|SJ@xZMq9egO~fND~}rR+E<_8K3t6 zW5li_cGRKKqJ; zPlShV?0+~2vwg~Doh9{T(w=Qs&-qbMH~;k`A5tf7<5_u;?*|JD{9Z&woEQV;!jG5s zZxF)ZUr#?5>t#X6iw*TxUk<{KuIYzwEUt#DXKKoB-=?GQ_OX&N zSs}}wJC}%lyrT6R@8`+*m21h_yR8c* zgtfgbo9Gyx$$Rezod;(gy<`ZKKG;}TA2*rjL_SWK`GwsJmM8C&+FH}GZ^Zc8_>nUB zx>@VtrtdUl8Xqj46-xBDDuWVDK@TkX?&!&wF)XH7JqfyEr9>aGuTt- zbnX5f%X8-wouT*&t>ri9n2Tqn43wKd==8Q7`-)j`v=009aS?@z1zi) z=H8O+^XD{y#-7S5nYjY^;6t&TLB|j*$z|NPJJr>%Tki>w|r4jgbYp4E#N( zW6e($;_q|cu9);2LCv1UCvS{G+Nq6Rb;Pdrd>8gmSV_SnerEf*BTSreH}X-W0s{qz zT;aIqzx7&v{z-S92?Lkf1*8~LnYiV%O3)_WL3sRFM{dB5i$%M}_Dc|or_M3clf*@`}yM)uY!UV1s#tUcz!19@?W11 zl^<8$)i;^@tqyj`8@ELho;GGB?c+LLeyos=OjmK|$3&kS6zz+hAaGCmipAG4Xkq4R z#?7H%)W}|w$klXI4Lmon_jfnwrYycXe+%iuNK&>)C-3vwklPrmF%By7^l3f)tyBIr z!1Hx~)^8eK$TfQ4B1Ol2C(0$q)(lVmpIH#Nv*#!aX}u%kSwC7KenObW<5LCiuQ*la zNU-ts!E3)|vp87BBl=`-U;)%dhW*@L#6;Wb==tB8h>mRFjYIxm9t4~k+_vY=I9xk) zmT{ZJbMNTZS*7lTuaOCwc)hC`WDLh1Kj0;LDSZivyXq9IYgL(D+w@+Y8sUA=zToF= zgdZoed-*$5Bx)@|!O|g^SLEhsc(@;EW!aqsrNfA%PwIr!Wolk^?vIDPwcgo5Afv!WF@3ZlVe!_s3p>{q1U<*yy2c{MZ%)2o3otQI0p`iQ6O%98|+a6?&BaIoqx(ZQ2_I9Y_Du+>Vy3IZ5@&r3=}^r{8mlB0aQNli7p{JvpUbR9fe1# zU`4!Ofh*BRq-g~1_(J+&eoja|0ACv3{A@llx~~^TEmvx&5#HnP`~&7}sb5F*Z3)!S z1zXJd@R3xNzUDSMQYIS&PchIfUqmJQ1Q#_l+J|RtC3yCO2Q!}U>jLw9>Z9BDDfr1u zTxE$p7aPT_&RHmMaPOhv{i+le&b;rugGus26VI*m&BTvI($;6Z#3~qA?%Cs$K|>wu zk_HdcUQl}RPJs52{U{>=6D{#f)ERx|DX`*SUoOzZ;oAOD*l_)~`#NSD zh`(Dml&{N!MQ797-vy*Y?w6P_U81wENqAoK!fpr*c#kjINBrBK#@k-^lOKKdg}YxX z?SK>0c50W~5Is^tM!F;6r@|f)VByFFXv-*u4~P)GXR6lv8<%UQ>}J(#zF)TccrlOK zE~cJO$5w`+afW#nyft~fv3=&h@%dT5O<1+54d#6EzrCQJ)bTrSy?gtX^S}OGYCYuj zm{4-#Nq@nE-sOc4IOse$)bTlxhURIzwckdwP{eN42aUpEFp}|dewoFOTd}hA{o7WU z^L*U&f^Ro0=(}K>K=`Xg%bYF*3w8tZZAbq%O%CdnJE_sM1#t5A{)5HfAURy9)60u) z-gng}^B{w4k@K<_8$s^&yx|S56R^e^xvIM5$dR*;T=$*d{6d;KDLi$)I}dne^UJ6R|* z)4*U%lLwUpX38d#aTjTSEUrH`1iqry(n)6tK98>yQTq*AQLOW}X}J_7Gd zzM5^hUk;KhCmW;*|8T20rb3QBL1+bO5F$G7|JG5z3yPX{ z(w-6i=ESpDp`T>^)XZw~RaAL#_nYz=Bgl_!^JbLG+O&hfWFMw575~GrGrNHmvdZOx zGpQ5yiiSk|CjG+O7OTG?^D>UV>6pKUg6|3ZQBwA=PbwpdwO-V)1 zxa8d*IM`7Q2501BH>Q_CuK-DaZ*uYQ#_T0)^$8z4c7NyHQYOx0^4QV+8$tE7if4n- zFK|t|u}FdVO}@V=$M!W18|9B#z33|gnN2r7p3)$FWe4}(e@prar+t^WrIgzRO69er z)3^@~2M#%!a!8+3Tl9yOS%mjmJRthPfKBRzU={s8gH!%@SK(cb8>xf;>qiwCfp|++ zldDz_1m1f5CftdI2AA8H>XUwL&(KBZEWHSS_sUsSC!K>b@+Vx2C;s_CyrhQZxttNO zbp9f5_KuD!!aJV^^o+uP?eho6z}(tDbJuA$R`06Uqdp_`NBOX!#4I6PN7^8oR4c*r zOk#v2>7$!JyKZy25EHMBlfqY?=s539X2$98`8HnBar!rIb+L04M>%4r{mD*0qY5F{ zOZxulZyY=^(`c^%AE`%kR#psq6Mm+B-{~;>F|hW~HytE>(HaA@cRH&Oy{@9;?z`li zdVaDmsJDFRf4X*O3O;wz(!81S2aM%|g9To6LuwxXiifM{Xj$QtzxdQY+~M!J`Si!( z!#$Td__ch2rXbmGuQlCb0tBDv+GG_XpH0J-kMW469hPBe`lx0Ui%VhPf+cMoS zJEU8c@?!*EdCzE_Gn|LM|_v#tr~E+GByhL5kC zrF27_vF>S|pn6z)McTUWj7Jh0Vjl7qBLss+rlh@YqU}vQy_d!J)Jah6d zd|buFH{DC;Ob|PKd|e>gKD-5X`&>ODO!j~Fx*69S4|T(=y#wllZB+bT>oD^pk04$* zkvOh#mLK0VXP><8(FuCS`*|!$KfA(Y!8J_yZy~kJKaySl`~Htb#9ngqqVC9ZvHZZ@ zr`aE?L07k+hV^v-Y6lCKoM{3`zTLozm++EC6>Qr>w@6<%C#IxDpXgaficU_} z=Uy8hU+IhCqEST2^hy2F2mX$%jNAY4U?o&O-Qe2^D^|HB!5h*)8^8Nz@40`vN!aw|{c79>u(Iu&i@*G=%oir}Kd4L~ncZKqr{V0a@-Le}LRudiZ zM7w?a>@FCI$aPOxTMbHQqgD139rm|Hb2UUro>-tB6Mb@+^v9mkHS=YWzEXG9R+@nz zddw^-pMH?&cng(`g#7q$ex#|Q=w&9hz6k4kLU?FrtFhkDUkv;XMcF}Qy#qgbUZF}3xi4t9+!aiMB zhc~2dExq-=;Xe8PiwE&v7#5u{u;3mycstPxr#{`Ar#uE8scY|P3-!ZjwVwZ5()VJK z;PhZ=F9nrE9lvm%zOciI=&IdfV$P8Mqv(FM)!;+;-LYq?&WqVN_OheEij32B zTUQs~JwALEf5OAryIHxj^S%@3brt)`ON*YGrO#tOiXN5>U(GxhY@tF08 z^z}_&QP8)bZ|Zy|Rpjk{dlQjyXNKdE$~p1 zeJ{6|j_c2lWV;g{+<&e5ACLA4@cX=zbO3e$B~MbPyoDFT>0z1Cf;{-8)gWftRxS?r zsan3&9fmg);qLl4KK!pwB94J0lGBAJ=RMxw*lJJu$eo-Oj^}O}h1cJY`fTl|p}fkC z%JkV>RGYu+U0ux>Ov;lJ;=5t%6Hj?s^gp}Xt&`e7=OOVtpxDEn@YjAs_H+OAzrHm? zKkBZrmpmw4@_9vy7x}&ix9fG4;Y}c~Z54Ut8R_HXW~isyjX>C{_CxOE`+#gG0WE@m zeE4XQe4d=IMYi^OHvg}vGmnS5`{Fob8D=b%ib|zLv`MH$b#6sU3ZbNstXYy|DQl6G zJVZ#@Qt^bchDxhEBt=r;AxrjM3^N#hXL^3}r`KyRbKQHs=YG!T^VUo!U-;*{QomD( zfdFGiyaFc6b8{njPEyE~Ox5rnjb6y*Q`z@0MS#S5F3mPv%1^GKtSW=&YMV~-wr9Hh zB+h^R=bR}fG#-|3TvQ`WMyzks^H48vJE3{nq^%LWSNe{4n;^%>VD}YEi6;1^*_+@H zIszLOG1LA*e^S|)tarOYFTgC@iYB`m{?Mz}RGN7243~uZA`F}}*|7D;m!QxS{A9>7 zTZvnz8Tt(>j?&RDfGO$X>v;c-V-|6>r1KKC?gqU+tQ)pD^n!g@R|Hw=j|qNV@cPb5 ztt-4_)~DpZ3{*skl9*8Y!c{{M@8^EH+a38HRSR!k7!V=CD95gIL?7f16_-}UpKyJ? zdd^$?{e`n1rP!UVVf2!&O&?eV2qFPwGgM>Nh9k zuPmk{KKB`d!>Z>sJm<|KbzhdIKCm5z$DDZQ13sc3%%fkwx9`vXT|Mw{G~JquY>7@@ zoqw8_)G13OtmqhoT^ygrS>zNAt9ky`7o{dst2>(^g?p>`ramUTUdHOTN)#pJ{g=O8 zJ!k6S`E36F)Jh(rYHbj~@m(AV*BVm9bGPIBp*!IU{N#!1eF-VdZ%jWIA8(jbdsdE{ zXpdapV})q>x#s`@xcr)XSyHGYTa+5HzGnbf(185;m>bU^>A+t79bBE z>Vq_%yR=8urmpGyWLKMGyY(wB@|)AOi3NL8jVg6N^~1jBpNl=^xruR2%6q0{FWAZ3 zcs^675ZY2M8IJDdpU$GZlj9>71A6PfF2nUZbW$OsbLRa{zYis$Q*jPcKjBqR$NN1S zry%b`i7hYkvP$0MhO82pf&5pw_z6eKP4JGj*jf=Xv(STYri2a@4c;mx*Y-F9Q(MiSlsk0;y87ovG^u z$=2D1Hy!q4Jw{xBvwS~1v~D}?Ij0p=E#}U9o{Z=5wYi*WEZ(#KfA2{cquvl=bwZ=i zULh@aKJq$xsC_f;v!i}7F?&q0&Z0|{%oQP_`}TZ#yQCjJHr_NLH+x~1S*od8ZWlzE zZ7H6s(GBw=Y*xlz=OH;)(pK7_?pER8E-!(3fS#;_`%juq!Pih@krTOnP_+Nz{No(o ziKhP3+w+CVhTHycx*rZfz5pt@#=T%-Ip;R#a|X+X>FHE;Zrm4bezme>x{E!f|vKfbn#f6F)s z^F$smT0WyYEfZDjn)Me0q!JFVtLYhnR<-A})jmHV@aTE5;09q*c8(=A4|AJq^0R3p zrRe8-`L;d>{YJn2N_fsn3a_j9ucZ=WGkVY+%m+-bXOo4ildIJz+|#@Jp;evfqen1E&n#_{pT{!nNO8xk)p>r?&(*AFtr+ zDmmAI_xbOR7frnnagkDIucacD6JQjQ`MyG&kNh~iG^Y^P!|X`kiKT`(Cuo{d>_`1E zhj#(@Jds*R@9m9jIf1_L-fDBNIAJ0fS@Vq#^^!4#wCx%jc*)=AGA7H-z5wOLhlEY| zxkgIlm2VzHPJ!K!kpK;FlScu#2xVgM;n@V_C}Q9q^B@k~ zcNqh%&G010B7U}O51g#@4?8_TBPz-F)-(G={>$0B;Qg#Kt z^9QjG6<-kDW`6${aMd@xW+?PQ?DfV?aV9(@Vpu*$g@wKoj!;+6OBO%KYcZ~+lQ|I+ zJ1TR>q4m?C>*@=`ux<1f8fqi`NTn_yq-X>h6WL zVveI?1J~}Nevn$#v}*~@ImW1)^*H|U5@|)hSZX<}Kjo3aA}q*jLIWVbAYoXYFI@gy zlqg?~Szx%l?@ygsk1>UxILB+_vNGq?)#>%6{3^Z8<{$d=I!UyDe?LSeLPAIDyPhIH z?d7)zdpgFzg>Un&+xEgF|I+bWs{80fEabwDr*+Mc@AKNE0J$6|elhIs-fn}2CRZCn zNdeNXs4l$_U-y0R;BY$jwQdH7ui9Lo5vK)p9>!RAN?r6y?JUQjTHRGMEt5tJTE9)* zyxaf+r}t$q!TfVWc=Ha9{$-vk7&F@zxd7-c1=MMJ4hX&Hocp@Z`?&VHjRI$Er8G<@ zPS+II`o{78(alc24ofxELSNOyJMLl3S+AL^D0Iq1khs<8U$V7fLYdm^LwSl{;TiRN zn3}XO*&6iGOBUy*T)v0y?TH$NuxHH8hK0y?3}%&ne1krL`3UZF#@ww5sajq<0_g>? z)l#qzp4cQwwOmA>+UY~<=EyNY#wNT_hK+d+OUKpibo2otHt6;mVg2%lVd}COZf1!U8d=xj z+$>yB@MKCioCuTt)VgyVayKN1_ibasG^cX9&Is?TtgVTefE#%248-Tqqu7ty^PTBr z3t75oOX@f@6kPHiSUUw>7{hFD@BY(=yyE|B-kx_ex>bTphqmiq6aV>d(Wkz9h2mMK zbo4We-7Wrs{=Q^OBi*gX`(Z_A%i%-lANMts))+wE!}Z46mi<2`!R3hSn?{--X|vxi zH|wt<(8xWQr=f}S1hjd@Q0V0I?vkqfIqhJS-LO&|bHQc7icyv|RI>M|i;%Z28{+aZ zYi@M$66N4~vH@6sHkh*>HhvQ&Zh=14mG_vCj5a;JAKj2~;q^8PTLx@bd3E*^o-ft> zHDG$G27dR=Cy+M1bHlWV60j3xMOf29Hi5@~*?AO!W&7SzvL| z@|Y(RGCu4-+Vq(P_Lq2`IHN!53dY#|Gmr-vcj5jUhZ+9D^m?zp?sh}(0p{bgEjr`Q z41+KCrx<7#Cbx`tw(Pja0^fxKVxb8%5;^bu-uAPDP;0f|7RR6bNQM8EO;-bCa_rL{ zBk+@Ltmb=*hp6l>^A+i%6YliOqW>cQZ~wu%#<9Z`viI%FShcxaBnGikoLu0lPR6Aa zrCI0$AJB2c{Yv?`(N@4+h!pVU8I=vs=pQ|T$8DH8{6xgp_*V;Z$grxs39g#aS3Q-2 z_#1BX{;3ZG<>OCM_po3`pkB<;osGcU%&rNUTLlXh3cKUWCSmNXK}gpc%t4v0x>JDu zq54Z)=Z0+95Tle;;hfO``dOpf)jI{rigBqAmyo-czQrSRXXw z#O4YS4UTR79Ro~b&B@5Z0a&%&+UgC*f4(&GtU2;5Y(3|dhw#x!l~br)oCzDkbuA^_ zc$#7M!&BlM-ISt{tXJraoHLo{(Vt~dhYk*OUV{3vGXLXkVd&r6eoXgfL@1T;C~?tR zWQXDYS+OTd$hi}7p*-tnpZ)Z4pSDmQ??V|_2i+N3fon3sV?-U=&sBaGN zuNr|A>zF39O?Z!Ql3G^Uj@&N0y9-+#$3dNzewzO&ofJvY*Ztef0twOntr@E*xfG&LfD|w{HalMP>59hi1c<(8wG)b z#gD?(s(>lJ(|py94!FNQwrFRy5J}2CAJ|=woPsxOU+M@C@pxvMvO^Mc-2d#xFMsH!_OHDoeA&ZBciw!I za2)+M>)uCvEuW8{=l`75$$*wucdoLadV$Jk!RsPqf`4RXYk&dZG5*J`{|N1O|7Wgh4Z*>I=^NV z)HwC1e+&kc)ius`;@oDwcw!u$e?z)PH&2TQdBtRpRN&XspOZl|@L2wI67pWuhe@KY za{TPYGbu(wB-nhVhdAPEo|D2DUtM!4(8j~Q|no_vvXD7U#Fn#5eEkLA~PXr|(PlWo` zwEiuA-?p5Z#Q7~0QaSN$sCFaz9JeeDX88ybzg8-VhK_)wo?*XMm~!9T#81U}vT!HsYnK0Ik zvN4byQ=pDUkHTeDvoAM)U_SF(vmO`v<$1egPudG{leIcQZ#aEwX@ZYTC1nb(O1Jho z3-OY3I_|GtpXDNRuOJB8ScK^8+f|&cGX=bUQvyq#^?}X@(c`V>`H7aHtK@C;JKVVK zx{=dA40c(Ca_)5*D3etuGa%7FI@gD99M(7w1$Ru2z(9pnkdaP3+-AIJil#F_!`rYl zI2P-^t7F2VO)aoF8HKPsUPAw-eCIF1gc4)_)ujyNgGCB2iyasNYl}^kNbUw;Z7H%C&)Hm0@KFAMyYbK&0NQCxCyHIeSPWN%9+m$>^7@mPMyvW1$SE~&ef+v+mSB@^h=P-xM<;PDN5|w+*Rrm=*?nlVK zm^%$|Dk{=t(aCHl=QRtYD5PQ4mR3vjqlJ{Xt#(;WC;C_phl=@$4&RnXj>z$oHAx?H z!1E-G*n@YIjQ{Q{$X)FheL&TI%S}{AZJz7d&2WSC-#>^vG%*398%>5%J_(VpZ!)s% zW8Q=PdYOqwdwI!04la^13|Bl{X*{@J7pOK#D@X7X^@{DGW@7^|r>}wkAN2DZzFZt7 z6EMR;`r4$BKWdBT=^bs(W-8`o7Buzl+}rgZSG<>(#K+%C{kk5{XLeDoKk^TQ26ak} zwl>2G{xhe#H_vb$T2|$Z1U64X=mSw>omUK)994DX@GO3FDV>o2AuGIzh8)K|9tEKy z+D+hp4?@yLX{36qcqp|%h}^0VzjC>hkK`VD_(dUiZ5Kx_R+GmFAJ=Gsb`-=X-^u5V~o3H+Wx8BiPPw#4kuLu1M zd*<@hg??X@#E1+aujksY|r!8$GV|l?~!=36C;24 zvT4-^HAhc%!uop}yJj!rAtthXmmZW0kQ&4{J=}x2=FIWV+xOXUq7V&=xmb5*y-K^i z6MaxFee2|FMMy>pPmxkJ&L4aIuDmmBf<#vfT`!?J_*?$i5vL7o;Kag>eWu{IPt2G` z5*kmc3-gbHp*IHC=#$Vq5%t8xYY;A;_2v7zpNpK?p1ydaBj)eAO0Ca(kAtOpp@c@a z2)Q3>akxLD6?%di97T;=pvYa;@{y+qdF$^w#ERu7BjWN~IC-w3)HH`{#f=bO>E<9^ zOCeF)cGss_jlqgsz6c85CvGbFXxZuUkf^(j68XDm#Et&`G<_}RGteF@{tpZEGB>8L zkVE~S!%sdx4icx+pKqw20PF&__R4b+*Q+7JBA&?QV3mx1``iPT`^SUSnG?{Ky-l?5 zG0s1mOO+V&8o-yCelmGdm_&zT!928 ztZ%_lVn=87^OC@GSCv$e6EMB6r|TvMJ9lkz4A$ixHH){L0(Vc3?+F_iP>uG_e9Zrk zZ6E(Py$t76H^Su-5B0*+cKzU|xu~B@L{tgx9r}M>gNu9%Iq++R9Off27RQr>TxX+h z`42o3pnb)3Tx`Aw(Xx;a&BXpx?9?KkMC9_dvED9RyoaA8+tukFifjbk!a250!EJEb z!X_YYauA*sP1IG~WW(3Oruo}ZH&1d7$yLRE#z>%`zV7M>WLcM6+E$8?q{u_tzR3TA z-#qUp$fN#n&vl{N#GgK8%Nc6dP0S}9F*Tm|_k$U~ir4C>Wu2HK51F+g`kW;bGJpI% zFUheTE-5HD+#ei=n6>@d2Qe4#n1lf|E*2Ol*9)^nFt0k#;8i8|<1Q9=XvwPR$87!( z&(Q-+V_&jt(%5jv@9=!piVk?|;xY9V>$|V0$@83Lt+3raY{6Xg3;)(_C(-{MOpCEc z-mSFC8Rv3sJ`xcl=3#@}p*+dBxiJ}oP%mPm_D)~`O6Y%Wi90t0CFiG>%Gm#c-}e@r z7ph#C_>qKjg#$&)T)AnO(>`)iUjYB!F8`K08)%9Un>9=8&Cwrr>!#B#4$o~h+U*+_ z`~+)G{^SnU&4pPXxQ}E0?f?9g>q{4yk)kP>cQGMs_17xckY^h_itm5LPiicue^lA+ zRD5R}FWE6RZ`oSRBL~n<<(ka!Nk``$UG?!KgUaGnQEDw(cP5R-NFQu%{X7#za^P!$hxkCKPFu@>m@ z`>dPoiQI`j;Xxab!?n(Io9xFAZSZl<;~S?V`k=+FZa%pEb4$FB{$@H9IF~+*O z=;-H;xtsqJ^9J88@!1bM%YK!xZtsFhN9p{uOTJ7g^g~%-Xr0X30WVFflaSQgvd^>L_yJ*jFa5$LL zU;BiKQc$Y2`y_H2IwP`#GW$TXNcLr@S~dK(`6Tla={J_9g~D{=fAq-nI#&j)bawUR z_?Sy~N9bg@v_iG7fKzKv7y4iCwk*Z`MyPO~gbn%&Y_!~hwf4{nWp;CEO)BQLjI|W& zQGa-?AiejP9r6K%I`@|{o1iQ)BPg-B6Q9Sn^9vJjKCaiUJBa)t&ynLvt1gO=BQn;q zedrJ9l7m$O!fkNt=b4?e`SH25h$SA=gZ{iFA>|8^ck-K;Q#uBV-JjG&BiHJ8o(*&C z{5alH!8)kHhFj9e@3fUcWBN4#Vs2+|>3$M<_YyhXTdQV#!X5pZTjVfjJ6(4+Eq`%f z<@PBUiXfUj&jiQ|3C9W^-Zm(&ons<{`#4A5;6ffxI~Ztbwz2(3z+646r_V~@Pk-xs zOj4zJ9*unJc^JgmSB!x1l1KY0z>5k_TptBVo$+DuB+Gu_$?~Ld`dB}y5vDWN*(e9C z;W$wd5;8^E=^G)$4ou> zw|et8Q9FJjTgi^OgxtFVu|$O&F7&gQ+;~`qz6Z(6E3+@iAg|hAMPXZ-rwtkI!5JO?j4wuZm2etJHbnOGvvg&;)IBgos8zLOG0E%p2-)#T^-n%`C)HS|tWx>8)OU8urIFv%YG6$GlrvR6$aj z-54;#JpZNS&-k4xR885PI$z-k{#OcFP2?2KK`KfXn8qZS!#s7Ol+VP{4qjF5%8E@^L None: @@ -35,12 +36,14 @@ def __init__( Parameters ---------- - stepsize : unit.Quantity, optional + timestep : unit.Quantity, optional Time step of integration with units of time. Default is 1.0 * unit.femtoseconds. collision_rate : unit.Quantity, optional Collision rate for the Langevin dynamics, with units 1/time. Default is 1.0 / unit.picoseconds. - report_frequency : int, optional - Frequency of saving the simulation data. Default is 100. + refresh_velocities : bool, optional + Flag indicating whether to reinitialize the velocities each time the run function is called. Default is False. + report_interval : int, optional + Interval between saving the simulation data. Default is 100. reporter : SimulationReporter, optional Reporter object for saving the simulation data. Default is None. save_traj_in_memory: bool @@ -50,11 +53,11 @@ def __init__( from loguru import logger as log self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA - log.info(f"stepsize = {stepsize}") + log.info(f"timestep = {timestep}") log.info(f"collision_rate = {collision_rate}") - log.info(f"report_frequency = {report_frequency}") + log.info(f"report_interval = {report_interval}") - self.stepsize = stepsize + self.timestep = timestep self.collision_rate = collision_rate if reporter: log.info( @@ -62,30 +65,21 @@ def __init__( ) log.info(f"and logging to {reporter.log_file_path}") self.reporter = reporter - self.report_frequency = report_frequency + self.report_interval = report_interval self.velocities = None self.save_traj_in_memory = save_traj_in_memory self.traj = [] - - def set_velocities(self, vel: unit.Quantity) -> None: - """ - Set the initial velocities for the Langevin Integrator. - - Parameters - ---------- - vel : unit.Quantity - Velocities to be set for the integrator. - """ - self.velocities = vel + self.refresh_velocities = refresh_velocities + self._move_iteration = 0 def run( self, sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, - n_steps: int = 5_000, + number_of_steps: int = 5_000, nbr_list: Optional[PairsBase] = None, progress_bar=False, - ): + ) -> Tuple[SamplerState, PairsBase]: """ Run the integrator to perform Langevin dynamics molecular dynamics simulation. @@ -95,13 +89,19 @@ def run( The initial state of the simulation, including positions. thermodynamic_state : ThermodynamicState The thermodynamic state of the system, including temperature and potential. - n_steps : int, optional + number_of_steps : int, optional Number of simulation steps to perform. nbr_list : PairBase, optional Neighbor list for the system. progress_bar : bool, optional Flag indicating whether to display a progress bar during integration. + Returns + ------- + sampler_state : SamplerState + The final state of the simulation, including positions, velocities, and current PRNG key. + nbr_list : PairBase + The neighbor list for the final state of the simulation. If the NeighborList object is None, the function returns None. """ from .utils import get_list_of_mass from tqdm import tqdm @@ -114,10 +114,10 @@ def run( self.box_vectors = sampler_state.box_vectors self.progress_bar = progress_bar temperature = thermodynamic_state.temperature - x0 = sampler_state.x0 + x0 = sampler_state.positions log.debug("Running Langevin dynamics") - log.debug(f"n_steps = {n_steps}") + log.debug(f"number_of_steps = {number_of_steps}") log.debug(f"temperature = {temperature}") # Initialize the random number generator @@ -129,18 +129,38 @@ def run( :, None ] sigma_v = jnp.sqrt(kbT_unitless / mass_unitless) - stepsize_unitless = self.stepsize.value_in_unit_system(unit.md_unit_system) + timestep_unitless = self.timestep.value_in_unit_system(unit.md_unit_system) collision_rate_unitless = self.collision_rate.value_in_unit_system( unit.md_unit_system ) - a = jnp.exp((-collision_rate_unitless * stepsize_unitless)) - b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless)) + a = jnp.exp((-collision_rate_unitless * timestep_unitless)) + b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * timestep_unitless)) # Initialize velocities - if self.velocities is None: - v0 = sigma_v * random.normal(key, x0.shape) - else: - v0 = self.velocities.value_in_unit_system(unit.md_unit_system) + if self.refresh_velocities: + # v0 = sigma_v * random.normal(key, positions.shape) + from .utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + temperature, potential.topology, key + ) + + elif sampler_state._velocities is None: + # v0 = sigma_v * random.normal(key, positions.shape) + from .utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + temperature, potential.topology, key + ) + elif sampler_state._velocities.shape[0] != sampler_state.positions.shape[0]: + from .utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + temperature, potential.topology, key + ) + + # extract the velocities from the sampler state + v0 = sampler_state.velocities x = x0 v = v0 @@ -151,56 +171,76 @@ def run( F = potential.compute_force(x, nbr_list) # propagation loop - for step in tqdm(range(n_steps)) if self.progress_bar else range(n_steps): + for step in ( + tqdm(range(number_of_steps)) + if self.progress_bar + else range(number_of_steps) + ): key, subkey = random.split(key) # v - v += (stepsize_unitless * 0.5) * F / mass_unitless + v += (timestep_unitless * 0.5) * F / mass_unitless # r - x += (stepsize_unitless * 0.5) * v + x += (timestep_unitless * 0.5) * v - if nbr_list is not None: - x = self._wrap_and_rebuild_neighborlist(x, nbr_list) - # o random_noise_v = random.normal(subkey, x.shape) v = (a * v) + (b * sigma_v * random_noise_v) - x += (stepsize_unitless * 0.5) * v + x += (timestep_unitless * 0.5) * v + if nbr_list is not None: - x = self._wrap_and_rebuild_neighborlist(x, nbr_list) + x, nbr_list = self._wrap_and_rebuild_neighborlist(x, nbr_list) F = potential.compute_force(x, nbr_list) # v - v += (stepsize_unitless * 0.5) * F / mass_unitless + v += (timestep_unitless * 0.5) * F / mass_unitless - if step % self.report_frequency == 0: + elapsed_step = step + self._move_iteration * number_of_steps + if (elapsed_step) % self.report_interval == 0: if hasattr(self, "reporter") and self.reporter is not None: - self._report(x, potential, nbr_list, step) + self._report( + x, potential, nbr_list, step, self._move_iteration, elapsed_step + ) if self.save_traj_in_memory: self.traj.append(x) log.debug("Finished running Langevin dynamics") - # save the final state of the simulation in the sampler_state object - sampler_state.x0 = x - sampler_state.v0 = v + + # return the final state of the simulation as a sampler_state object + import copy + + updated_sampler_state = copy.deepcopy(sampler_state) + + updated_sampler_state.positions = x + updated_sampler_state.velocities = v + updated_sampler_state.current_PRNG_key = key + + return updated_sampler_state, nbr_list def _wrap_and_rebuild_neighborlist(self, x: jnp.array, nbr_list: PairsBase): """ - Wrap the coordinates and rebuild the neighborlist if necessary. + Wrap the positions and rebuild the neighborlist if necessary. Parameters ---------- x: jnp.array - The coordinates of the particles. + The positions of the particles. nbr_list: PairsBsse The neighborlist object. + + Returns + ------- + x: jnp.array + The wrapped positions. + nbr_list: PairsBase + The neighborlist object; this may or may not have been rebuilt. """ - x = nbr_list.space.wrap(x) + x = nbr_list.space.wrap(x, self.box_vectors) # check if we need to rebuild the neighborlist after moving the particles if nbr_list.check(x): nbr_list.build(x, self.box_vectors) - return x + return x, nbr_list def _report( self, @@ -208,6 +248,8 @@ def _report( potential: NeuralNetworkPotential, nbr_list: PairsBase, step: int, + iteration: int, + elapsed_step: int, ): """ Reports the trajectory, energy, step, and box vectors (if available) to the reporter. @@ -221,7 +263,12 @@ def _report( nbr_list: PairsBase The neighbor list step: int - The current time step. + The current step in the move; this resets each iteration. + iteration: int + The number iterations the move has been called. + elapsed_step: int, + The total number of steps that have been taken in the simulation move. + Returns: None @@ -230,8 +277,10 @@ def _report( "positions": x, "potential_energy": potential.compute_energy(x, nbr_list), "step": step, + "iteration": iteration, + "elapsed_step": elapsed_step, } if nbr_list is not None: - d["box_vectors"] = nbr_list.space.box_vectors + d["box_vectors"] = nbr_list.box_vectors self.reporter.report(d) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 285cb34..64d71d2 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -2,49 +2,101 @@ from openmm import unit from typing import Tuple, List, Optional import jax.numpy as jnp -from chiron.reporters import LangevinDynamicsReporter, _SimulationReporter +from chiron.reporters import LangevinDynamicsReporter, _SimulationReporter, MCReporter +from .neighbors import PairsBase + +from abc import ABC, abstractmethod class MCMCMove: def __init__( self, - nr_of_moves: int, + number_of_moves: int, reporter: Optional[_SimulationReporter] = None, - report_frequency: Optional[int] = 100, + report_interval: Optional[int] = 100, ): """ Initialize a move within the molecular system. Parameters ---------- - nr_of_moves : int + number_of_moves : int Number of moves to be applied. reporter : _SimulationReporter, optional Reporter object for saving the simulation data. Default is None. - report_frequency : int, optional + report_interval : int, optional + Interval for saving the simulation data in the reporter. + Default is 100. + """ - self.nr_of_moves = nr_of_moves + self.number_of_moves = number_of_moves self.reporter = reporter - self.report_frequency = report_frequency + self.report_interval = report_interval + + # we need to keep track of which iteration we are on + self._move_iteration = 0 + + # we also need to keep track of attempts made (i.e., total elapsed steps), in case the number_of_moves is changed + self._number_of_attempts_made = 0 + from loguru import logger as log if self.reporter is not None: log.info( f"Using reporter {self.reporter} saving to {self.reporter.workdir}" ) - assert self.report_frequency is not None + assert self.report_interval is not None + + @abstractmethod + def update( + self, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: + """ + Update the state of the system. + + Parameters + ---------- + sampler_state : SamplerState + The sampler state to run the integrator on. + thermodynamic_state : ThermodynamicState + The thermodynamic state to run the integrator on. + nbr_list : PairsBase, optional + The neighbor list to use for the simulation. + + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state. + nbr_list: PairsBase + The updated neighbor/pair list. If no nbr_list is passed, this will be None. + + """ + pass + + @property + def number_of_attemps_made(self): + """ + Return the total number of steps that have been attempted in the move. + """ + return self._number_of_attempts_made class LangevinDynamicsMove(MCMCMove): def __init__( self, - stepsize=1.0 * unit.femtoseconds, - collision_rate=1.0 / unit.picoseconds, + timestep: unit.Quantity = 1.0 * unit.femtoseconds, + collision_rate: unit.Quantity = 1.0 / unit.picoseconds, + refresh_velocities: bool = False, reporter: Optional[LangevinDynamicsReporter] = None, - report_frequency: int = 100, - nr_of_steps=1_000, + report_interval: int = 100, + number_of_steps: int = 1_000, save_traj_in_memory: bool = False, ): """ @@ -52,17 +104,20 @@ def __init__( Parameters ---------- - stepsize : unit.Quantity + timestep : unit.Quantity Time step size for the integration. collision_rate : unit.Quantity Collision rate for the Langevin dynamics. + refresh_velocities : bool, optional + Whether to reinitialize the velocities each time the run function is called. + Default is False. reporter : LangevinDynamicsReporter, optional Reporter object for saving the simulation data. Default is None. - report_frequency : int - Frequency of saving the simulation data. + report_interval : int + Interval for saving the simulation data. Default is 100. - nr_of_steps : int, optional + number_of_steps : int, optional Number of steps to run the integrator for. Default is 1_000. save_traj_in_memory: bool @@ -70,30 +125,32 @@ def __init__( Default is False. NOTE: Only for debugging purposes. """ super().__init__( - nr_of_moves=nr_of_steps, + number_of_moves=number_of_steps, reporter=reporter, - report_frequency=report_frequency, + report_interval=report_interval, ) - self.stepsize = stepsize + self.timestep = timestep self.collision_rate = collision_rate self.save_traj_in_memory = save_traj_in_memory self.traj = [] from chiron.integrators import LangevinIntegrator self.integrator = LangevinIntegrator( - stepsize=self.stepsize, + timestep=self.timestep, collision_rate=self.collision_rate, - report_frequency=report_frequency, + refresh_velocities=refresh_velocities, + report_interval=report_interval, reporter=reporter, save_traj_in_memory=save_traj_in_memory, ) - def run( + def update( self, sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, - ): + nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: """ Run the integrator to perform molecular dynamics simulation. @@ -103,6 +160,17 @@ def run( The sampler state to run the integrator on. thermodynamic_state : ThermodynamicState The thermodynamic state to run the integrator on. + nbr_list : PairsBase, optional + The neighbor list to use for the simulation. + + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The thermodynamic state; note this is not modified by the Langevin dynamics algorithm. + nbr_list: PairsBase + The updated neighbor/pair list. If a nbr_list is not set, this will be None. """ assert isinstance( @@ -112,87 +180,837 @@ def run( thermodynamic_state, ThermodynamicState ), f"Thermodynamic state must be ThermodynamicState, not {type(thermodynamic_state)}" - self.integrator.run( + updated_sampler_state, updated_nbr_list = self.integrator.run( thermodynamic_state=thermodynamic_state, sampler_state=sampler_state, - n_steps=self.nr_of_moves, + number_of_steps=self.number_of_moves, + nbr_list=nbr_list, ) + # update the elapsed steps + self._number_of_attempts_made += self.number_of_moves if self.save_traj_in_memory: self.traj.append(self.integrator.traj) self.integrator.traj = [] + self._move_iteration += 1 + + # The thermodynamic_state will not change for the langevin move + return updated_sampler_state, thermodynamic_state, updated_nbr_list + class MCMove(MCMCMove): def __init__( - self, nr_of_moves: int, reporter: Optional[_SimulationReporter] + self, + number_of_moves: int, + reporter: Optional[_SimulationReporter], + report_interval: int = 1, + autotune: bool = False, + autotune_interval: int = 100, + acceptance_method: str = "Metropolis-Hastings", ) -> None: - super().__init__(nr_of_moves, reporter=reporter) + """ + Initialize the move. - def apply_move(self): + Parameters + ---------- + number_of_moves + Number of moves to be attempted in each call to update. + reporter + Reporter object for saving the simulation step data. + report_interval + Interval for saving the simulation data. + autotune + Whether to automatically tune the parameters of the MC move to achieve a target acceptance ratio. + For example, for a simple displacement move this would update the displacement_sigma. + autotune_interval + Frequency of autotuning the MC move parameters to achieve a target acceptance ratio. + acceptance_method + Methodology to use for accepting or rejecting the proposed state. + Default is "Metropolis-Hastings". """ - Apply a Monte Carlo move to the system. + super().__init__( + number_of_moves=number_of_moves, + reporter=reporter, + report_interval=report_interval, + ) + self.acceptance_method = acceptance_method # I think we should pass a class/function instead of a string, like space. - This method should be overridden by subclasses to define specific types of moves. + self.reset_statistics() + self.autotune = autotune + self.autotune_interval = autotune_interval - Raises - ------ - NotImplementedError - If the method is not implemented in subclasses. + def update( + self, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: """ + Perform the defined move and update the state. - raise NotImplementedError("apply_move() must be implemented in subclasses") + Parameters + ---------- + sampler_state : SamplerState + The initial state of the simulation, including positions. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system, including temperature and potential. + nbr_list : PairBase, optional + Neighbor list for the system. - def compute_acceptance_probability( + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state. + nbr_list: PairsBase + The updated neighbor/pair list. If a nbr_list is not set, this will be None. + """ + + self._current_reduced_potential = None + for i in range(self.number_of_moves): + sampler_state, thermodynamic_state, nbr_list = self._step( + sampler_state, + thermodynamic_state, + nbr_list, + ) + self._number_of_attempts_made += 1 + + # We should use self._number_of_attempts_made as the "step" otherwise, if we just used i, instances where + # self.report_interval > self.number_of_moves would only report on the + # first step, which might actually be more frequent than we specify + + if hasattr(self, "reporter"): + if self.reporter is not None: + if self._number_of_attempts_made % self.report_interval == 0: + self._report( + i, + self._move_iteration, + self._number_of_attempts_made, + self.n_accepted / self.n_proposed, + sampler_state, + thermodynamic_state, + nbr_list, + ) + if self.autotune: + # if we only used i, we might never actually update the parameters if we have a move that is called infrequently + if ( + self._number_of_attempts_made % self.autotune_interval == 0 + and self._number_of_attempts_made > 0 + ): + self._autotune() + # keep track of how many times this function has been called + self._move_iteration += 1 + + return sampler_state, thermodynamic_state, nbr_list + + @abstractmethod + def _report( self, - old_state: SamplerState, - new_state: SamplerState, + step: int, + iteration: int, + number_of_attempts_made: int, + acceptance_probability: float, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, ): """ - Compute the acceptance probability for a move from an old state to a new state. + Report the current state of the MC move. + + Since different moves will be modifying different quantities, + this needs to be defined for each move. Parameters ---------- - old_state : object - The state of the system before the move. - new_state : object - The state of the system after the move. + step : int + The current step of the simulation move. + iteration : int + The current iteration of the move sequence (i.e., how many times has this been called thus far). + number_of_attempts_made : int + The total number of steps that have been taken in the simulation move. step+ nr_moves*iteration + acceptance_probability : float + The acceptance probability of the move. + sampler_state : SamplerState + The sampler state of the system. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. + nbr_list : Optional[PairBase]=None + The neighbor list or pair list for evaluating interactions in the system, default None + """ + pass + + @abstractmethod + def _autotune(self): + """ + This will autotune the move parameters to reach a target acceptance probability. + This will be specific to the type of move, e.g., a displacement_sigma for a displacement move + or a maximum volume change factor for a Monte Carlo barostat move. + + Since different moves will be modifying different quantities, this needs to be defined for each move. + + Note this will modify the class parameters in place. + """ + pass + + def _step( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: + """ + Performs an individual MC step. + + This will call the _propose function which will be specific to the type of move. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_nbr_list : Optional[PairsBase] + Neighbor list associated with the current state. Returns ------- - float - Acceptance probability as a float. + sampler_state : SamplerState + The updated sampler state; if a move is rejected this will be unchanged. + Note, if the proposed move is rejected, the current PRNG key will be updated to ensure + that we are using a different random number for the next iteration. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state; if a move is rejected this will be unchanged. + Note, many MC moves will not modify the thermodynamic state regardless of acceptance of the move. + nbr_list: PairsBase, optional + The updated neighbor/pair list. If a nbr_list is not set, this will be None. + If the move is rejected, this will correspond to the neighbor + + """ + + # if this is the first time we are calling this function during this iteration + # we will need to calculate the reduced potential for the current state + # this is toggled by the calculate_current_reduced_potential flag + # otherwise, we can use the one that was saved from the last step, for efficiency + if self._current_reduced_potential is None: + current_reduced_potential = ( + current_thermodynamic_state.get_reduced_potential( + current_sampler_state, current_nbr_list + ) + ) + # save the current_reduced_potential so we don't have to recalculate + # it on the next iteration if the move is rejected + self._current_reduced_potential = current_reduced_potential + else: + current_reduced_potential = self._current_reduced_potential + + # propose a new state and calculate the log proposal ratio + # this will be specific to the type of move + # in addition to the sampler_state, this will require/return the thermodynamic state + # for systems that e.g., make changes to particle identity. + # For efficiency, we will also return a copy of the nbr_list associated with the proposed state + # because if the move is rejected, we can move back the original state without having to rebuild the nbr_list + # if it were modified due to the proposed state. + ( + proposed_sampler_state, + proposed_thermodynamic_state, + proposed_reduced_potential, + log_proposal_ratio, + proposed_nbr_list, + ) = self._propose( + current_sampler_state, + current_thermodynamic_state, + current_reduced_potential, + current_nbr_list, + ) + + if jnp.isnan(proposed_reduced_potential): + decision = False + else: + # accept or reject the proposed state + decision = self._accept_or_reject( + log_proposal_ratio, + proposed_sampler_state.new_PRNG_key, + acceptance_method=self.acceptance_method, + ) + # a function that will update the statistics for the move + + self._update_statistics(decision) + + if decision: + # save the reduced potential of the accepted state so + # we don't have to recalculate it the next iteration + self._current_reduced_potential = proposed_reduced_potential + + # replace the current state with the proposed state + # not sure this needs to be a separate function but for simplicity in outlining the code it is fine + # or should this return the new sampler_state and thermodynamic_state? + + return ( + proposed_sampler_state, + proposed_thermodynamic_state, + proposed_nbr_list, + ) + else: + # if we reject the move, we need to update the current_PRNG key to ensure that + # we are using a different random number for the next iteration + # this is needed because the _step function returns a SamplerState instead of updating it in place + current_sampler_state._current_PRNG_key = ( + proposed_sampler_state._current_PRNG_key + ) + + return current_sampler_state, current_thermodynamic_state, current_nbr_list + + def _update_statistics(self, decision): + """ + Update the statistics for the move. + """ + if decision: + self.n_accepted += 1 + self.n_proposed += 1 + + @property + def statistics(self): + """The acceptance statistics as a dictionary.""" + return dict(n_accepted=self.n_accepted, n_proposed=self.n_proposed) + + @statistics.setter + def statistics(self, value): + self.n_accepted = value["n_accepted"] + self.n_proposed = value["n_proposed"] + + def reset_statistics(self): + """Reset the acceptance statistics.""" + self.n_accepted = 0 + self.n_proposed = 0 + + @abstractmethod + def _propose( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_reduced_potential: float, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, float, float, Optional[PairsBase]]: """ - self._check_state_compatiblity(old_state, new_state) - old_system = self.system(old_state) - new_system = self.system(new_state) + Propose a new state and calculate the log proposal ratio. + + This will accept the relevant quantities for the current state, returning the proposed state quantities + and the log proposal ratio. + + This will need to be defined for each new move. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_reduced_potential : float, required + Current reduced potential. + current_nbr_list : PairsBase, required + Neighbor list associated with the current state. + + Returns + ------- + proposed_sampler_state : SamplerState + Proposed sampler state. + proposed_thermodynamic_state : ThermodynamicState + Proposed thermodynamic state. + proposed_reduced_potential : float + Proposed reduced potential. + log_proposal_ratio : float + Log proposal ratio. + proposed_nbr_list : PairsBase + Proposed neighbor list. If not defined, this will be None. - energy_before_state_change = old_system.compute_energy(old_state.position) - energy_after_state_change = new_system.compute_energy(new_state.position) - # Implement the logic to compute the acceptance probability + """ pass - def accept_or_reject(self, probability): + def _accept_or_reject( + self, + log_proposal_ratio, + key, + acceptance_method, + ): + """ + Accept or reject the proposed state with a given methodology. + """ + # define the acceptance probability + if acceptance_method == "Metropolis-Hastings": + import jax.random as jrandom + + compare_to = jrandom.uniform(key) + if -log_proposal_ratio <= 0.0 or compare_to < jnp.exp(log_proposal_ratio): + return True + else: + return False + + +class MonteCarloDisplacementMove(MCMove): + """ + A Monte Carlo move that randomly displaces particles in the system. + + For each move, all particles will be randomly displaced at once, where the random displacement is drawn from + a normal distribution. The standard deviation of the distribution is defined by the `displacement_sigma` parameter. + + Displacements can be restricted to a subset of particles by defining the `atom_subset` parameter, which is a list of + particle indices that will be allowed to move. If `atom_subset` is not defined, all particles will be displaced. + + Note, the displacement moves are applied on a per-particle basis; this does not support collective moves. + + The value of the `displacement_sigma` can be autotuned to achieve a target acceptance ratio between 0.4 and 0.6, + by setting the autotune parameter to True. The frequency of autotuning is defined by setting `autotune_interval`. + + + """ + + def __init__( + self, + displacement_sigma=1.0 * unit.nanometer, + number_of_moves: int = 100, + atom_subset: Optional[List[int]] = None, + report_interval: int = 1, + reporter: Optional[MCReporter] = None, + autotune: bool = False, + autotune_interval: int = 100, + acceptance_method="Metropolis-Hastings", + ): """ - Decide whether to accept or reject the move based on the acceptance probability. + Initialize the Displacement Move class. Parameters ---------- - probability : float - Acceptance probability. + displacement_sigma : float or unit.Quantity, optional + The standard deviation of the displacement for each move. Default is 1.0 nm. + number_of_moves : int, optional + The number of move attempts to perform. Default is 100. + For a given move, all particles will be randomly displaced at once (unless atom_subset is), + rather than moving each particle one at a time. + atom_subset : list of int, optional + A list of particle indices that represent a subset of all particles. + If defined, only those particles in the list will have their positions random displaced. + Default is None. + reporter : SimulationReporter, optional + The reporter to write the data to. Default is None. + autotune : bool, optional + Whether to autotune the displacement_sigma of the move to achieve an acceptance ratio between 0.4 and 0.6. + Default is False. + autotune_interval : int, optional + Frequency of autotuning displacement_sigma of the move. Default is 100. + acceptance_method : str, optional + Methodology to use for accepting or rejecting the proposed state. + Default is "Metropolis-Hastings". Returns ------- - bool - Boolean indicating if the move is accepted. + None """ - import jax.numpy as jnp + super().__init__( + number_of_moves=number_of_moves, + reporter=reporter, + report_interval=report_interval, + autotune=autotune, + autotune_interval=autotune_interval, + acceptance_method=acceptance_method, + ) + self.displacement_sigma = displacement_sigma + + self.atom_subset = atom_subset + self.atom_subset_mask = None + + def _report( + self, + step: int, + iteration: int, + number_of_attempts_made: int, + acceptance_probability: float, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ): + """ + Report the current state of the MC displacement move. + + Parameters + ---------- + step : int + The current step of the simulation move. + iteration : int + The current iteration of the move sequence (i.e., how many times has this been called thus far). + number_of_attempts_made : int + The total number of steps that have been taken in the simulation move. step+ nr_moves*iteration + acceptance_probability : float + The acceptance probability of the move. + sampler_state : SamplerState + The sampler state of the system. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. + nbr_list : Optional[PairBase]=None + The neighbor list or pair list for evaluating interactions in the system, default None - return jnp.random.rand() < probability + """ + potential = thermodynamic_state.potential.compute_energy( + sampler_state.positions, nbr_list + ) + self.reporter.report( + { + "step": step, + "iteration": iteration, + "number_of_attempts_made": number_of_attempts_made, + "potential_energy": potential, + "displacement_sigma": self.displacement_sigma.value_in_unit_system( + unit.md_unit_system + ), + "acceptance_probability": acceptance_probability, + } + ) + + def _autotune(self): + """ + Update the displacement_sigma to reach a target acceptance probability between 0.4 and 0.6. + """ + acceptance_ratio = self.n_accepted / self.n_proposed + if acceptance_ratio > 0.6: + self.displacement_sigma *= 1.1 + elif acceptance_ratio < 0.4: + self.displacement_sigma /= 1.1 + + def _propose( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_reduced_potential: float, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, float, float, Optional[PairsBase]]: + """ + Implements the logic specific to displacement moves. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_reduced_potential : float, required + Current reduced potential. + current_nbr_list : Optional[PairsBase] + Neighbor list associated with the current state. + + Returns + ------- + proposed_sampler_state : SamplerState + Proposed sampler state. + proposed_thermodynamic_state : ThermodynamicState + Proposed thermodynamic state. + proposed_reduced_potential : float + Proposed reduced potential. + log_proposal_ratio : float + Log proposal ratio. + proposed_nbr_list : PairsBase + Proposed neighbor list. If not defined, this will be None. + """ + + # create a mask for the atom subset: if a value of the mask is 0 + # the particle won't move; if 1 the particle will be moved + if self.atom_subset is not None and self.atom_subset_mask is None: + import jax.numpy as jnp + + self.atom_subset_mask = jnp.zeros(current_sampler_state.n_particles) + for atom in self.atom_subset: + self.atom_subset_mask = self.atom_subset_mask.at[atom].set(1) + + key = current_sampler_state.new_PRNG_key + + nr_of_atoms = current_sampler_state.n_particles + + unitless_displacement_sigma = self.displacement_sigma.value_in_unit_system( + unit.md_unit_system + ) + import jax.random as jrandom + + scaled_displacement_vector = ( + jrandom.normal(key, shape=(nr_of_atoms, 3)) * unitless_displacement_sigma + ) + import copy + + proposed_sampler_state = copy.deepcopy(current_sampler_state) + + if self.atom_subset is not None: + proposed_sampler_state.positions = ( + proposed_sampler_state.positions + + scaled_displacement_vector * self.atom_subset_mask + ) + else: + proposed_sampler_state.positions = ( + proposed_sampler_state.positions + scaled_displacement_vector + ) + + # after proposing a move we need to wrap particles and see if we need to rebuild the neighborlist + if current_nbr_list is not None: + proposed_sampler_state.positions = current_nbr_list.space.wrap( + proposed_sampler_state.positions, + proposed_sampler_state.box_vectors, + ) + + # if we need to rebuild the neighbor the neighborlist + # we will make a copy and then build + if current_nbr_list.check(proposed_sampler_state.positions): + import copy + + proposed_nbr_list = copy.deepcopy(current_nbr_list) + + proposed_nbr_list.build( + proposed_sampler_state.positions, proposed_sampler_state.box_vectors + ) + # if we don't need to update the neighborlist, just make a new variable that refers to the original + else: + proposed_nbr_list = current_nbr_list + else: + proposed_nbr_list = None + + proposed_reduced_potential = current_thermodynamic_state.get_reduced_potential( + proposed_sampler_state, proposed_nbr_list + ) + + log_proposal_ratio = -proposed_reduced_potential + current_reduced_potential + + # since do not change the thermodynamic state we can return + # 'current_thermodynamic_state' rather than making a copy + return ( + proposed_sampler_state, + current_thermodynamic_state, + proposed_reduced_potential, + log_proposal_ratio, + proposed_nbr_list, + ) + + +class MonteCarloBarostatMove(MCMove): + """ + A Monte Carlo move that randomly changes the volume of the system. + + The volume change is drawn from a normal distribution with a mean of 0 and a standard deviation defined + by the product of the `volume_max_scale` parameter and the current volume. Particle positions are scaled + proportionately with the change in volume. This routine operates on a per-particle basis and does not support + collective moves (i.e., it is an "atomic" barostat move where particle center-of-mass positions are scaled; + it is not aware of "molecules" which would be scaled by the molecule center-of-mass). + + The `volume_max_scale` parameter can be autotuned to achieve a target acceptance ratio between 0.25 and 0.75, + by setting the autotune parameter to True. The frequency of autotuning is defined by setting `autotune_interval`. + Note, the maximum value of `volume_max_scale` is capped at 0.3 in the auto-tuning process. + + + """ + + def __init__( + self, + volume_max_scale=0.01, + number_of_moves: int = 100, + report_interval: int = 1, + reporter: Optional[LangevinDynamicsReporter] = None, + autotune: bool = False, + autotune_interval: int = 100, + acceptance_method="Metropolis-Hastings", + ): + """ + Initialize the Monte Carlo Barostat Move class. + + Parameters + ---------- + volume_max_scale : float, optional + The scaling factor multiplied by volume to set the maximum volume change allowed. + number_of_moves : int, optional + The number of volume update moves attempts to perform. Default is 100. + reporter : SimulationReporter, optional + The reporter to write the data to. Default is None. + autotune : bool, optional + Whether to autotune the volume_max_scale value of the move to achieve a target probability + between 0.25 and 0.75. Default is False. volume_max_scale is capped at 0.3 + autotune_interval : int, optional + Frequency of autotuning the volume_max_scale of the move. Default is 100. + acceptance_method : str, optional + Methodology to use for accepting or rejecting the proposed state. + Default is "Metropolis-Hastings". + + Returns + ------- + None + """ + super().__init__( + number_of_moves=number_of_moves, + reporter=reporter, + report_interval=report_interval, + autotune=autotune, + autotune_interval=autotune_interval, + acceptance_method=acceptance_method, + ) + self.volume_max_scale = volume_max_scale + + def _report( + self, + step: int, + iteration: int, + number_of_attempts_made: int, + acceptance_probability: float, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ): + """ + + Parameters + ---------- + step : int + The current step of the simulation move. + iteration : int + The current iteration of the move sequence (i.e., how many times has this been called thus far). + number_of_attempts_made : int + The total number of steps that have been taken in the simulation move. step+ nr_moves*iteration + acceptance_probability : float + The acceptance probability of the move. + sampler_state : SamplerState + The sampler state of the system. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. + nbr_list : Optional[PairBase]=None + The neighbor list or pair list for evaluating interactions in the system, default None + """ + + potential = thermodynamic_state.potential.compute_energy( + sampler_state.positions, nbr_list + ) + volume = ( + sampler_state.box_vectors[0][0] + * sampler_state.box_vectors[1][1] + * sampler_state.box_vectors[2][2] + ) + self.reporter.report( + { + "step": step, + "iteration": iteration, + "number_of_attempts_made": number_of_attempts_made, + "potential_energy": potential, + "volume": volume, + "box_vectors": sampler_state.box_vectors, + "max_volume_scale": self.volume_max_scale, + "acceptance_probability": acceptance_probability, + } + ) + + def _autotune(self): + """ + Update the volume_max_scale parameter to ensure our acceptance probability is within the range of 0.25 to 0.75. + The maximum volume_max_scale will be capped at 0.3. + """ + acceptance_ratio = self.n_accepted / self.n_proposed + if acceptance_ratio < 0.25: + self.volume_max_scale /= 1.1 + elif acceptance_ratio > 0.75: + self.volume_max_scale = min(self.volume_max_scale * 1.1, 0.3) + + def _propose( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_reduced_potential: float, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, float, float, Optional[PairsBase]]: + """ + Implement the logic specific to displacement changes. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_reduced_potential : float, required + Current reduced potential. + current_nbr_list : PairsBase, optional + Neighbor list associated with the current state. + + Returns + ------- + proposed_sampler_state : SamplerState + Proposed sampler state. + proposed_thermodynamic_state : ThermodynamicState + Proposed thermodynamic state. + proposed_reduced_potential : float + Proposed reduced potential. + log_proposal_ratio : float + Log proposal ratio. + proposed_nbr_list : PairsBase + Proposed neighbor list. If not defined, this will be None. + + """ + from loguru import logger as log + + key = current_sampler_state.new_PRNG_key + + import jax.random as jrandom + + nr_of_atoms = current_sampler_state.n_particles + + initial_volume = ( + current_sampler_state.box_vectors[0][0] + * current_sampler_state.box_vectors[1][1] + * current_sampler_state.box_vectors[2][2] + ) + + # Calculate the maximum amount the volume can change by + delta_volume_max = self.volume_max_scale * initial_volume + + # Calculate the volume change by generating a random number between -1 and 1 + # and multiplying by the maximum allowed volume change, delta_volume_max + delta_volume = jrandom.uniform(key, minval=-1, maxval=1) * delta_volume_max + # calculate the new volume + proposed_volume = initial_volume + delta_volume + + # calculate the length scale factor for particle positions and box vectors + length_scaling_factor = jnp.power(proposed_volume / initial_volume, 1.0 / 3.0) + + import copy + + proposed_sampler_state = copy.deepcopy(current_sampler_state) + proposed_sampler_state.positions = ( + current_sampler_state.positions * length_scaling_factor + ) + + proposed_sampler_state.box_vectors = ( + current_sampler_state.box_vectors * length_scaling_factor + ) + + if current_nbr_list is not None: + proposed_nbr_list = copy.deepcopy(current_nbr_list) + # after scaling the box vectors and positions we should always rebuild the neighborlist + proposed_nbr_list.build( + proposed_sampler_state.positions, proposed_sampler_state.box_vectors + ) + + proposed_reduced_potential = current_thermodynamic_state.get_reduced_potential( + proposed_sampler_state, proposed_nbr_list + ) + # NPT acceptance criteria was originally defined in McDonald 1972, https://doi.org/10.1080/00268977200100031 + # (see equation 9). The acceptance probability is given by: + # ⎡−β (ΔU + PΔV ) + N ln(V new /V old )⎤ + log_proposal_ratio = -( + proposed_reduced_potential - current_reduced_potential + ) + nr_of_atoms * jnp.log(proposed_volume / initial_volume) + + # we do not change the thermodynamic state so we can return 'current_thermodynamic_state' + return ( + proposed_sampler_state, + current_thermodynamic_state, + proposed_reduced_potential, + log_proposal_ratio, + proposed_nbr_list, + ) class RotamerMove(MCMove): - def apply_move(self): + def _propose(self): """ Implement the logic specific to rotamer changes. """ @@ -200,7 +1018,7 @@ def apply_move(self): class ProtonationStateMove(MCMove): - def apply_move(self): + def _propose(self): """ Implement the logic specific to protonation state changes. """ @@ -208,7 +1026,7 @@ def apply_move(self): class TautomericStateMove(MCMove): - def apply_move(self): + def _propose(self): """ Implement the logic specific to tautomeric state changes. """ @@ -283,20 +1101,39 @@ def run( sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, n_iterations: int = 1, + nbr_list: Optional[PairsBase] = None, ): """ Run the sampler for a specified number of iterations. Parameters ---------- + sampler_state : SamplerState + The initial state of the sampler. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. n_iterations : int, optional Number of iterations of the sampler to run. + Default is 1. + nbr_list : PairsBase, optional + The neighbor list to use for the simulation. + + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state. + nbr_list: PairsBase + The updated neighbor/pair list. If a nbr_list is not set, this will be None. + """ from loguru import logger as log from copy import deepcopy sampler_state = deepcopy(sampler_state) thermodynamic_state = deepcopy(thermodynamic_state) + nbr_list = deepcopy(nbr_list) log.info("Running MCMC sampler") log.info(f"move_schedule = {self.move.move_schedule}") @@ -304,7 +1141,10 @@ def run( log.info(f"Iteration {iteration + 1}/{n_iterations}") for move_name, move in self.move.move_schedule: log.debug(f"Performing: {move_name}") - move.run(sampler_state, thermodynamic_state) + + sampler_state, thermodynamic_state, nbr_list = move.update( + sampler_state, thermodynamic_state, nbr_list + ) log.info("Finished running MCMC sampler") log.debug("Closing reporter") @@ -312,298 +1152,4 @@ def run( if move.reporter is not None: move.reporter.flush_buffer() log.debug(f"Closed reporter {move.reporter.log_file_path}") - return sampler_state - - -from .neighbors import PairsBase - - -class MetropolizedMove(MCMove): - """A base class for metropolized moves. - - Only the proposal needs to be specified by subclasses through the method - _propose_positions(). - - Parameters - ---------- - atom_subset : slice or list of int, optional - If specified, the move is applied only to those atoms specified by these - indices. If None, the move is applied to all atoms (default is None). - - Attributes - ---------- - n_accepted : int - The number of proposals accepted. - n_proposed : int - The total number of attempted moves. - atom_subset - - Examples - -------- - TBC - """ - - def __init__( - self, - atom_subset: Optional[List[int]] = None, - nr_of_moves: int = 100, - reporter: Optional[_SimulationReporter] = None, - report_frequency: int = 1, - ): - self.n_accepted = 0 - self.n_proposed = 0 - self.atom_subset = atom_subset - super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) - from loguru import logger as log - - self.report_frequency = report_frequency - log.debug(f"Atom subset is {atom_subset}.") - - @property - def statistics(self): - """The acceptance statistics as a dictionary.""" - return dict(n_accepted=self.n_accepted, n_proposed=self.n_proposed) - - @statistics.setter - def statistics(self, value): - self.n_accepted = value["n_accepted"] - self.n_proposed = value["n_proposed"] - - def apply( - self, - thermodynamic_state: ThermodynamicState, - sampler_state: SamplerState, - nbr_list=Optional[PairsBase], - ): - """Apply a metropolized move to the sampler state. - - Total number of acceptances and proposed move are updated. - - Parameters - ---------- - thermodynamic_state : ThermodynamicState - The thermodynamic state to use to apply the move. - sampler_state : SamplerState - The initial sampler state to apply the move to. This is modified. - nbr_list: Neighbor List or Pair List routine, - The routine to use to calculate the interacting atoms. - Default is None and will use an unoptimized pairlist without PBC - """ - import jax.numpy as jnp - from loguru import logger as log - - # Compute initial energy - initial_energy = thermodynamic_state.get_reduced_potential( - sampler_state, nbr_list - ) # NOTE: in kT - log.debug(f"Initial energy is {initial_energy} kT.") - - # Store initial positions of the atoms that are moved. - x0 = sampler_state.x0 - atom_subset = self.atom_subset - if atom_subset is None: - initial_positions = jnp.copy(x0) - else: - initial_positions = jnp.copy(sampler_state.x0[jnp.array(atom_subset)]) - log.debug(f"Initial positions are {initial_positions} nm.") - # Propose perturbed positions. Modifying the reference changes the sampler state. - proposed_positions = self._propose_positions(initial_positions) - - log.debug(f"Proposed positions are {proposed_positions} nm.") - # Compute the energy of the proposed positions. - if atom_subset is None: - sampler_state.x0 = proposed_positions - else: - sampler_state.x0 = sampler_state.x0.at[jnp.array(atom_subset)].set( - proposed_positions - ) - if nbr_list is not None: - if nbr_list.check(sampler_state.x0): - nbr_list.build(sampler_state.x0, sampler_state.box_vectors) - - proposed_energy = thermodynamic_state.get_reduced_potential( - sampler_state, nbr_list - ) # NOTE: in kT - # Accept or reject with Metropolis criteria. - delta_energy = proposed_energy - initial_energy - log.debug(f"Delta energy is {delta_energy} kT.") - import jax.random as jrandom - - self.key, subkey = jrandom.split(self.key) - - compare_to = jrandom.uniform(subkey) - if not jnp.isnan(proposed_energy) and ( - delta_energy <= 0.0 or compare_to < jnp.exp(-delta_energy) - ): - self.n_accepted += 1 - log.debug(f"Check suceeded: {compare_to=} < {jnp.exp(-delta_energy)}") - log.debug( - f"Move accepted. Energy change: {delta_energy:.3f} kT. Number of accepted moves: {self.n_accepted}." - ) - if self.n_proposed % self.report_frequency == 0: - self.reporter.report( - { - "energy": proposed_energy, # in kT - "step": self.n_proposed, - "traj": sampler_state.x0, - } - ) - else: - # Restore original positions. - if atom_subset is None: - sampler_state.x0 = initial_positions - else: - sampler_state.x0 = sampler_state.x0.at[jnp.array([atom_subset])].set( - initial_positions - ) - log.debug( - f"Move rejected. Energy change: {delta_energy:.3f} kT. Number of rejected moves: {self.n_proposed - self.n_accepted}." - ) - self.n_proposed += 1 - - def _propose_positions(self, positions: jnp.array): - """Return new proposed positions. - - These method must be implemented in subclasses. - - Parameters - ---------- - positions : nx3 jnp.ndarray - The original positions of the subset of atoms that these move - applied to. - - Returns - ------- - proposed_positions : nx3 jnp.ndarray - The new proposed positions. - - """ - raise NotImplementedError( - "This MetropolizedMove does not know how to propose new positions." - ) - - -class MetropolisDisplacementMove(MetropolizedMove): - """A metropolized move that randomly displace a subset of atoms. - - Parameters - ---------- - displacement_sigma : openmm.unit.Quantity - The standard deviation of the normal distribution used to propose the - random displacement (units of length, default is 1.0*nanometer). - atom_subset : slice or list of int, optional - If specified, the move is applied only to those atoms specified by these - indices. If None, the move is applied to all atoms (default is None). - - Attributes - ---------- - n_accepted : int - The number of proposals accepted. - n_proposed : int - The total number of attempted moves. - displacement_sigma - atom_subset - - See Also - -------- - MetropolizedMove - - """ - - def __init__( - self, - displacement_sigma=1.0 * unit.nanometer, - nr_of_moves: int = 100, - atom_subset: Optional[List[int]] = None, - reporter: Optional[LangevinDynamicsReporter] = None, - ): - """ - Initialize the MCMC class. - - Parameters - ---------- - seed : int, optional - The seed for the random number generator. Default is 1234. - displacement_sigma : float or unit.Quantity, optional - The standard deviation of the displacement for each move. Default is 1.0 nm. - nr_of_moves : int, optional - The number of moves to perform. Default is 100. - atom_subset : list of int, optional - A subset of atom indices to consider for the moves. Default is None. - reporter : SimulationReporter, optional - The reporter to write the data to. Default is None. - Returns - ------- - None - """ - super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) - self.displacement_sigma = displacement_sigma - self.atom_subset = atom_subset - self.key = None - - def displace_positions( - self, positions: jnp.array, displacement_sigma=1.0 * unit.nanometer - ): - """Return the positions after applying a random displacement to them. - - Parameters - ---------- - positions : nx3 jnp.array unit.Quantity - The positions to displace. - displacement_sigma : openmm.unit.Quantity - The standard deviation of the normal distribution used to propose - the random displacement (units of length, default is 1.0*nanometer). - - Returns - ------- - rotated_positions : nx3 numpy.ndarray openmm.unit.Quantity - The displaced positions. - - """ - import jax.random as jrandom - - self.key, subkey = jrandom.split(self.key) - nr_of_atoms = positions.shape[0] - unitless_displacement_sigma = displacement_sigma.value_in_unit_system( - unit.md_unit_system - ) - displacement_vector = ( - jrandom.normal(subkey, shape=(nr_of_atoms, 3)) * 0.1 - ) # NOTE: convert from Angstrom to nm - scaled_displacement_vector = displacement_vector * unitless_displacement_sigma - updated_position = positions + scaled_displacement_vector - - return updated_position - - def _propose_positions(self, initial_positions: jnp.array) -> jnp.array: - """Implement MetropolizedMove._propose_positions for apply().""" - return self.displace_positions(initial_positions, self.displacement_sigma) - - def run( - self, - sampler_state: SamplerState, - thermodynamic_state: ThermodynamicState, - nbr_list=None, - progress_bar=True, - ): - from tqdm import tqdm - from loguru import logger as log - from jax import random - - self.key = sampler_state.new_PRNG_key - - for trials in ( - tqdm(range(self.nr_of_moves)) if progress_bar else range(self.nr_of_moves) - ): - self.apply(thermodynamic_state, sampler_state, nbr_list) - if trials % 100 == 0: - log.debug(f"Acceptance rate: {self.n_accepted / self.n_proposed}") - if self.reporter is not None: - self.reporter.report( - { - "Acceptance rate": self.n_accepted / self.n_proposed, - "step": self.n_proposed, - } - ) - - log.info(f"Acceptance rate: {self.n_accepted / self.n_proposed}") + return sampler_state, thermodynamic_state, nbr_list diff --git a/chiron/multistate.py b/chiron/multistate.py index 26f3ea3..e67d7bb 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -75,6 +75,7 @@ def __init__( self._neighborhoods = None self._n_accepted_matrix = None self._n_proposed_matrix = None + self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None self._mcmc_sampler = copy.deepcopy(mcmc_sampler) @@ -322,14 +323,14 @@ def _minimize_replica( # Perform minimization minimized_state = minimize_energy( - sampler_state.x0, + sampler_state.positions, thermodynamic_state.potential.compute_energy, self.nbr_list, maxiter=max_iterations, ) # Update the sampler state - self._sampler_states[replica_id].x0 = minimized_state.params + self._sampler_states[replica_id].positions = minimized_state.params # Compute and log final energy final_energy = thermodynamic_state.get_reduced_potential(sampler_state) @@ -394,11 +395,17 @@ def _propagate_replica(self, replica_id: int): thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + mcmc_sampler = self._mcmc_sampler[thermodynamic_state_id] # Propagate using the mcmc sampler - self._sampler_states[replica_id] = mcmc_sampler.run(sampler_state, thermodynamic_state) + # NOTE this needs to be updated to support neighborlists + ( + self._sampler_states[replica_id], + self._thermodynamic_states[thermodynamic_state_id], + nbr_list, + ) = mcmc_sampler.run(sampler_state, thermodynamic_state) # Append the new state to the trajectory for analysis. - self._traj[replica_id].append(self._sampler_states[replica_id].x0) + self._traj[replica_id].append(self._sampler_states[replica_id].positions) def _perform_swap_proposals(self): """ @@ -577,9 +584,9 @@ def _report_positions(self): log.debug("Reporting positions...") # numpy array with shape (n_replicas, n_atoms, 3) - xyz = np.zeros((self.n_replicas, self._sampler_states[0].x0.shape[0], 3)) + xyz = np.zeros((self.n_replicas, self._sampler_states[0].positions.shape[0], 3)) for replica_id in range(self.n_replicas): - xyz[replica_id] = self._sampler_states[replica_id].x0 + xyz[replica_id] = self._sampler_states[replica_id].positions return {"positions": xyz} def _report(self, property: str) -> None: @@ -598,17 +605,29 @@ def _report(self, property: str) -> None: from loguru import logger as log log.debug(f"Reporting {property}...") - match property: - case "positions": - return self._report_positions() - case "states": - pass - case "u_kn": - return self._report_energy_matrix() - case "trajectory": - return - case "mixing_statistics": - return + if property == "positions": + return self._report_positions() + elif property == "states": + pass + elif property == "u_kn": + return self._report_energy_matrix() + elif property == "trajectory": + return + elif "mixing_statistics": + return + + # match isn't in python 3.9; we can discuss if we want to drop python 3.0 support or just keep the if/else structure + # match property: + # case "positions": + # return self._report_positions() + # case "states": + # pass + # case "u_kn": + # return self._report_energy_matrix() + # case "trajectory": + # return + # case "mixing_statistics": + # return def _report_iteration(self): """ diff --git a/chiron/neighbors.py b/chiron/neighbors.py index 9941e44..be80cf0 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -3,62 +3,36 @@ import jax import jax.numpy as jnp from functools import partial -from typing import Tuple, Union +from typing import Tuple, Union, Optional from .states import SamplerState from openmm import unit -# split out the displacement calculation from the neighborlist for flexibility +# split out the displacement calculation from the neighbor list and pair list for flexibility from abc import ABC, abstractmethod class Space(ABC): - def __init__( - self, box_vectors: Union[jnp.array, unit.Quantity, None] = None - ) -> None: - """ - Abstract base class for defining the simulation space. + """ + Abstract Base Class for different simulation spaces. - Parameters - ---------- - box_vectors: jnp.array, optional - Box vectors for the system. - """ - if box_vectors is not None: - if isinstance(box_vectors, unit.Quantity): - if not box_vectors.unit.is_compatible(unit.nanometer): - raise ValueError( - f"Box vectors require distance unit, not {box_vectors.unit}" - ) - self.box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) - elif isinstance(box_vectors, jnp.ndarray): - if box_vectors.shape != (3, 3): - raise ValueError( - f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" - ) - - self.box_vectors = box_vectors - else: - raise TypeError( - f"box_vectors must be a jnp.array or unit.Quantity, not {type(box_vectors)}" - ) + This class will define two functions: + - displacement, i.e., how to calculate the displacement vector and distance between two points + - wrap, i.e., how to wrap a particle in the box (i.e., apply boundary conditions). - @property - def box_vectors(self) -> jnp.array: - return self._box_vectors + Note, this class does not store the box_vectors; they will need to be passed to each function. - @box_vectors.setter - def box_vectors(self, box_vectors: jnp.array) -> None: - self._box_vectors = box_vectors + + """ @abstractmethod def displacement( - self, xyz_1: jnp.array, xyz_2: jnp.array + self, xyz_1: jnp.array, xyz_2: jnp.array, box_vectors: jnp.array ) -> Tuple[jnp.array, jnp.array]: pass @abstractmethod - def wrap(self, xyz: jnp.array) -> jnp.array: + def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: pass @@ -68,20 +42,9 @@ class OrthogonalPeriodicSpace(Space): """ - @property - def box_vectors(self) -> jnp.array: - return self._box_vectors - - @box_vectors.setter - def box_vectors(self, box_vectors: jnp.array) -> None: - self._box_vectors = box_vectors - self._box_lengths = jnp.array( - [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] - ) - @partial(jax.jit, static_argnums=(0,)) def displacement( - self, xyz_1: jnp.array, xyz_2: jnp.array + self, xyz_1: jnp.array, xyz_2: jnp.array, box_vectors: jnp.array ) -> Tuple[jnp.array, jnp.array]: """ Calculate the periodic distance between two points. @@ -89,9 +52,10 @@ def displacement( Parameters ---------- xyz_1: jnp.array - Coordinates of the first point + Positions of the first point xyz_2: jnp.array - Coordinates of the second point + Positions of the second point + box_vectors: jnp.array Returns ------- @@ -101,36 +65,43 @@ def displacement( Distance between the two points """ - # calculate uncorrect r_ij + # calculate uncorrected r_ij r_ij = xyz_1 - xyz_2 - # calculated corrected displacement vector - r_ij = ( - jnp.mod(r_ij + self._box_lengths * 0.5, self._box_lengths) - - self._box_lengths * 0.5 + box_lengths = jnp.array( + [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] ) + # calculated corrected displacement vector + # using modulus seems faster in JAX + r_ij = jnp.mod(r_ij + box_lengths * 0.5, box_lengths) - box_lengths * 0.5 # calculate the scalar distance dist = jnp.linalg.norm(r_ij, axis=-1) return r_ij, dist @partial(jax.jit, static_argnums=(0,)) - def wrap(self, xyz: jnp.array) -> jnp.array: + def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: """ - Wrap the coordinates of the system. + Wrap the positions of the system. Parameters ---------- xyz: jnp.array - Coordinates of the system + Positions of the system + box_vectors: jnp.array + Box vectors for the system Returns ------- jnp.array - Wrapped coordinates of the system + Wrapped positions of the system """ - xyz = xyz - jnp.floor(xyz / self._box_lengths) * self._box_lengths + box_lengths = jnp.array( + [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] + ) + + xyz = xyz - jnp.floor(xyz / box_lengths) * box_lengths return xyz @@ -141,6 +112,7 @@ def displacement( self, xyz_1: jnp.array, xyz_2: jnp.array, + box_vectors: jnp.array, ) -> Tuple[jnp.array, jnp.array]: """ Calculate the periodic distance between two points. @@ -148,9 +120,11 @@ def displacement( Parameters ---------- xyz_1: jnp.array - Coordinates of the first point + Positions of the first point xyz_2: jnp.array - Coordinates of the second point + Positions of the second point + box_vectors: jnp.array + Box vectors for the system. Returns ------- @@ -169,20 +143,22 @@ def displacement( return r_ij, dist @partial(jax.jit, static_argnums=(0,)) - def wrap(self, xyz: jnp.array) -> jnp.array: + def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: """ - Wrap the coordinates of the system. - For the Non-periodic system, this does not alter the coordinates + Wrap the positions of the system. + For the Non-periodic system, this does not alter the positions Parameters ---------- xyz: jnp.array - Coordinates of the system + Positions of the system + box_vectors: jnp.array + Box vectors for the system Returns ------- jnp.array - Wrapped coordinates of the system + Wrapped positions of the system """ return xyz @@ -190,7 +166,7 @@ def wrap(self, xyz: jnp.array) -> jnp.array: class PairsBase(ABC): """ - Abstract Base Class for different algorithms that determine which particles are interacting. + Abstract Base Class for different algorithms that determine which particle pairs are interacting. Parameters ---------- @@ -207,46 +183,58 @@ class PairsBase(ABC): >>> import jax.numpy as jnp >>> >>> space = OrthogonalPeriodicSpace() # define the simulation space, in this case an orthogonal periodic space - >>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), + >>> sampler_state = SamplerState(positions=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])) >>> >>> pair_list = PairsBase(space, cutoff=2.5*unit.nanometer) # initialize the pair list >>> pair_list.build_from_state(sampler_state) # build the pair list from the sampler state >>> - >>> coordinates = sampler_state.x0 # get the coordinates from the sampler state, without units attached + >>> positions = sampler_state.positions # get the positions from the sampler state, without units attached >>> >>> # the calculate function will produce information used to calculate the energy - >>> n_neighbors, padding_mask, dist, r_ij = pair_list.calculate(coordinates) + >>> n_neighbors, padding_mask, dist, r_ij = pair_list.calculate(positions) >>> """ def __init__( self, space: Space, - cutoff: unit.Quantity = unit.Quantity(1.2, unit.nanometer), + cutoff: Optional[unit.Quantity] = unit.Quantity(1.2, unit.nanometer), ): + """ + Initialize the PairsBase class + + Parameters + ---------- + space: Space + Class that defines how to calculate the displacement between two points and apply the boundary conditions + This should not be changed after initialization. + cutoff: unit.Quantity, default = 1.2 unit.nanometer + Cutoff distance for the neighborlist + + """ if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") if not cutoff.unit.is_compatible(unit.angstrom): raise ValueError( f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" ) - self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) + self.cutoff = cutoff self.space = space @abstractmethod def build( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build list from an array of coordinates and array of box vectors. + Build list from an array of positions and array of box vectors. Parameters ---------- - coordinates: jnp.array or unit.Quantity - Shape[n_particles,3] array of particle coordinates, either with or without units attached. + positions: jnp.array or unit.Quantity + Shape[n_particles,3] array of particle positions, either with or without units attached. If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. box_vectors: jnp.array or unit.Quantity Shape[3,3] array of box vectors for the system, either with or without units attached. @@ -261,24 +249,36 @@ def build( def _validate_build_inputs( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ Validate the inputs to the build function. + + This will raise ValueErrors if the inputs are not of the correct type or shape or compatible units + + Parameters + ---------- + positions: jnp.array or unit.Quantity + Shape[n_particles,3] array of particle positions, either with or without units attached. + If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + box_vectors: jnp.array or unit.Quantity + Shape[3,3] array of box vectors for the system, either with or without units attached. + If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + """ - if isinstance(coordinates, unit.Quantity): - if not coordinates.unit.is_compatible(unit.nanometer): + if isinstance(positions, unit.Quantity): + if not positions.unit.is_compatible(unit.nanometer): raise ValueError( - f"Coordinates require distance units, not {coordinates.unit}" + f"Positions require distance units, not {positions.unit}" ) - self.ref_coordinates = coordinates.value_in_unit_system(unit.md_unit_system) - if isinstance(coordinates, jnp.ndarray): - if coordinates.shape[1] != 3: + self.ref_positions = positions.value_in_unit_system(unit.md_unit_system) + if isinstance(positions, jnp.ndarray): + if positions.shape[1] != 3: raise ValueError( - f"coordinates should be a Nx3 array, shape provided: {coordinates.shape}" + f"positions should be a Nx3 array, shape provided: {positions.shape}" ) - self.ref_coordinates = coordinates + self.ref_positions = positions if isinstance(box_vectors, unit.Quantity): if not box_vectors.unit.is_compatible(unit.nanometer): raise ValueError( @@ -300,7 +300,7 @@ def build_from_state(self, sampler_state: SamplerState): Parameters ---------- sampler_state: SamplerState - SamplerState object containing the coordinates and box vectors + SamplerState object containing the positions and box vectors Returns ------- @@ -309,22 +309,22 @@ def build_from_state(self, sampler_state: SamplerState): if not isinstance(sampler_state, SamplerState): raise TypeError(f"Expected SamplerState, got {type(sampler_state)} instead") - coordinates = sampler_state.x0 + positions = sampler_state.positions if sampler_state.box_vectors is None: raise ValueError(f"SamplerState does not contain box vectors") box_vectors = sampler_state.box_vectors - self.build(coordinates, box_vectors) + self.build(positions, box_vectors) @abstractmethod - def calculate(self, coordinates: jnp.array): + def calculate(self, positions: jnp.array): """ - Calculate the neighbor list for the current state + Calculate the list of interacting particles for the current state Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates + positions: jnp.array + Shape[N,3] array of particle positions Returns ------- @@ -343,15 +343,16 @@ def calculate(self, coordinates: jnp.array): pass @abstractmethod - def check(self, coordinates: jnp.array) -> bool: + def check(self, positions: jnp.array) -> bool: """ - Check if the internal variables need to be reset. E.g., rebuilding a neighborlist - Should do nothing for a simple pairlist. + Check if the internal variables need to be reset. E.g., rebuilding a neighborlist if particles moved to far, + or rebuilding if number of particles changes. + Parameters ---------- - coordinates: jnp.array - Array of particle coordinates + positions: jnp.array + Array of particle positions Returns ------- bool @@ -362,22 +363,69 @@ def check(self, coordinates: jnp.array) -> bool: class NeighborListNsqrd(PairsBase): """ - N^2 neighborlist implementation that returns the particle pair ids, displacement vectors, and distances. + A JAX based neighbor list implementation used to determine which pairs of particles are interacting + (i.e., those particles that fall within the specified cutoff). + + The neighbor list (i.e., list of particles within a distance of cutoff+skin of a given particle) is generated + within the `build` function using an O(N^2) calculation rather than using a spatial partitioning scheme + (e.g., cell-list). The `calculate` function that uses the neighbor list to determine which particle pairs are + interacting and determine the distances and displacement vectors between interacting pairs of particles for + use in the calculation of the interaction energies/forces. The routines are subject to the boundary conditions + specified by the Space class. + + Notes: + This neighbor list not include self-interactions and only includes unique pairs (i.e., no double-counting). + This is sometimes referred to as a "half" neighbor list. E.g. consider the pair of neighboring particles (A, B): + in the "half" neighbor list approach, B is in the neighbor list of A, but A is not in the neighbor list of B + as that pair is already accounted for. + . + The output of the `calculate` function is padded to a fixed size, `n_max_neighbors` (default=100), + to allow for efficient jitted computations in JAX. As such, values need to be masked using the `padding_mask` + array returned by the `calculate` function. The padding mask is an array of 1s and 0s, where 1 indicates an + interacting neighbor and 0 indicates the pair is either non-interacting or simply a padded value. + The `build` function will iteratively increase `n_max_neighbors` by 10 until we can store all neighbors. + + The `check` function, which indicates if the neighbor list should be rebuilt, will return True if: + - the number of particles changes + - any of the particles have moved more than half the skin distance from their reference positions (i.e., the + positions of particles when the neighbor list was last built). - Parameters - ---------- - space: Space - Class that defines how to calculate the displacement between two points and apply the boundary conditions - cutoff: float, default = 2.5 - Cutoff distance for the neighborlist - skin: float, default = 0.4 - Skin distance for the neighborlist - n_max_neighbors: int, default=200 - Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations - This will be checked and dynamically updated during the build stage - Examples - -------- + Parameters + ---------- + space: Space + Class that defines how to calculate the displacement between two points and apply the boundary conditions. + This should not be changed after initialization. + cutoff: unit.Quantity, default = 1.2 unit.nanometer + Cutoff distance for the neighborlist + skin: unit.Quantity, default = 0.4 unit.nanometer + Skin distance, i.e., buffer, for the neighborlist + Larger values of the skin will reduce the frequency of rebuilding the neighbor list, + but will increase the number of neighbors to consider. + n_max_neighbors: int, default=200 + Maximum number of neighbors for each particle. This is used for padding arrays for efficient jax computations + n_max_neighbors will be dynamically updated (in increments of 10) as part of the build function. + Examples + -------- + >>> from openmm import unit + >>> import jax.numpy as jnp + >>> + >>> from chiron.states import SamplerState + >>> sampler_state = SamplerState(positions=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]])*unit.nanometer, + >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])*unit.nanometer) + >>> + >>> from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + >>> nbr_list = NeighborListNsqrd(OrthogonalPeriodicSpace(), cutoff=1.2*unit.nanometer, skin=0.4*unit.nanometer) + >>> + >>> # build the neighborlist + >>> nbr_list.build_from_state(sampler_state) # build the pair list from the sampler state + >>> + >>> # calculate which particles are interacting along with their distances and displacement vectors + >>> n_neighbors, neighbor_list, padding_mask, dist, r_ij = nbr_list.calculate(sampler_state.positions) + >>> + >>> # check the neighborlist + >>> if nbr_list.check(sampler_state.positions): + >>> nbr_list.build_from_state(sampler_state) # rebuild the pair list from the sampler state """ @@ -390,27 +438,64 @@ def __init__( ): if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") - if not cutoff.unit.is_compatible(unit.angstrom): - raise ValueError( - f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" - ) + if not skin.unit.is_compatible(unit.angstrom): raise ValueError( f"cutoff must be a unit.Quantity with units of distance, skin.unit = {skin.unit}" ) - self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) - self.skin = skin.value_in_unit_system(unit.md_unit_system) - self.cutoff_and_skin = self.cutoff + self.skin + self.cutoff = cutoff + self.skin = skin self.n_max_neighbors = n_max_neighbors self.space = space - # set a a simple variable to know if this has at least been built once as opposed to just initialized - # this does not imply that the neighborlist is up to date + # this variable will ensure that `calculate` will fail if we try to call it before building + # note: self.is_built=True does not imply that the neighborlist is up-to-date + self.is_built = False + + @property + def cutoff(self) -> unit.Quantity: + return self._cutoff + + @cutoff.setter + def cutoff(self, cutoff: unit.Quantity) -> None: + if not cutoff.unit.is_compatible(unit.nanometer): + raise ValueError( + f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" + ) + self._cutoff = cutoff + + # if we change the cutoff or skin we need to rebuild + # we will set the variable to ensure that attempts to call the calculate function will fail if + # we have not rebuilt the neighbor list + self.is_built = False + + @property + def skin(self) -> unit.Quantity: + return self._skin + + @skin.setter + def skin(self, skin: unit.Quantity) -> None: + if not skin.unit.is_compatible(unit.nanometer): + raise ValueError( + f"skin must be a unit.Quantity with units of distance, skin.unit = {skin.unit}" + ) + self._skin = skin + + # if we change the cutoff or skin we need to rebuild + # we will set the variable to ensure that attempts to call the calculate function will fail if + # we have not rebuilt the neighbor list self.is_built = False - # note, we need to use the partial decorator in order to use the jit decorate - # so that it knows to ignore the `self` argument + # Note, we need to use the partial decorator and declare self as static in order to JIT a function within a class. + # This approach treats internal variables of the class as static within this function; e.g., if set self.cutoff = 2, + # called the function, then changed it to 3, the value of self.cutoff in this function would still be 2. + # Thus, we need to pass any variables that may change as arguments, rather than referencing self.variable_name. + # While we could create a custom pytree instead of declaring the class as static (allowing us to reference class + # variables directly within the JITTED function), any changes to those internal variables, say self.cutoff, + # would mean a change to the hash of any JITTEd function that depends on the variable, requiring JAX to recompile + # the function, which is a slow operation. As such, it is also more efficient to just pass variables as arguments. + @partial(jax.jit, static_argnums=(0,)) def _pairs_mask(self, particle_ids: jnp.array): """ @@ -444,9 +529,18 @@ def _pairs_mask(self, particle_ids: jnp.array): return temp_mask + # note: since n_max_neighbors dictates the output size, we will define it as a static argument + # to allow us to jit this function @partial(jax.jit, static_argnums=(0, 5)) def _build_neighborlist( - self, particle_i, reduction_mask, pid, coordinates, n_max_neighbors + self, + particle_i, + reduction_mask, + pid, + positions, + n_max_neighbors, + cutoff_and_skin, + box_vectors, ): """ Jitted function to build the neighbor list for a single particle @@ -454,13 +548,17 @@ def _build_neighborlist( Parameters ---------- particle_i: jnp.array - X,Y,Z coordinates of particle i + X,Y,Z positions of particle i reduction_mask: jnp.array Mask to exclude self-interactions and double counting of pairs - coordinates: jnp.array - X,Y,Z coordinates of all particles + positions: jnp.array + X,Y,Z positions of all particles n_max_neighbors: int Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations + cutoff_and_skin: float + Cutoff distance for the neighborlist plus the skin distance, in nanometers. + box_vectors: jnp.array + Box vectors for the system Returns ------- @@ -472,19 +570,27 @@ def _build_neighborlist( Number of neighbors for the particle """ - # calculate the displacement between particle i and all other particles - r_ij, dist = self.space.displacement(particle_i, coordinates) + # Calculate the displacement between particle i and all other particles + # NOTE: It would be safer to pass the displacement calculate as a callable function, instead of referencing + # self.space. If someone changes the boundary conditions (i.e., changes space in the class), + # self.space.displacement will not change since the self is marked as status. + # However, I ran into issues passing a function through vmap, and I haven't been able to figure out how to + # resolve it yet. I do not want to remove vmap, as that would require substantially changing the flow of + # the code. For now, I've noted in the docstring that space should not change after initialization -- CRI + r_ij, dist = self.space.displacement(particle_i, positions, box_vectors) - # neighbor_mask will be an array of length n_particles (i.e., length of coordinates) + # neighbor_mask will be an array of length n_particles (i.e., length of positions) # where each element is True if the particle is a neighbor, False if it is not # subject to both the cutoff+skin and the reduction mask that eliminates double counting and self-interactions neighbor_mask = jnp.where( - (dist < self.cutoff_and_skin) & (reduction_mask), True, False + (dist < cutoff_and_skin) & (reduction_mask), True, False ) # when we pad the neighbor list, we will use last particle id in the neighbor list - # this choice was made such that when we use the neighbor list in the masked energy calculat + # this choice was made such that when we use the neighbor list in the masked energy calculation # the padded values will result in reasonably well defined values fill_value = jnp.argmax(neighbor_mask) + # if the max value is the same as the particle of interest, which can occur if particle 0 has no neighbors + # we will just increment by 1 to avoid calculating a self interaction fill_value = jnp.where(fill_value == pid, fill_value + 1, fill_value) # count up the number of neighbors @@ -506,16 +612,16 @@ def _build_neighborlist( def build( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build the neighborlist from an array of coordinates and box vectors. + Build the neighbor list from an array of positions and box vectors. Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates + positions: jnp.array + Shape[N,3] array of particle positions box_vectors: jnp.array Shape[3,3] array of box vectors @@ -525,14 +631,14 @@ def build( """ - # set our reference coordinates - # the call to x0 and box_vectors automatically convert these to jnp arrays in the correct unit system - if isinstance(coordinates, unit.Quantity): - if not coordinates.unit.is_compatible(unit.nanometer): + # set our reference positions + # the call to positions and box_vectors automatically convert these to jnp arrays in the correct unit system + if isinstance(positions, unit.Quantity): + if not positions.unit.is_compatible(unit.nanometer): raise ValueError( - f"Coordinates require distance units, not {coordinates.unit}" + f"Positions require distance units, not {positions.unit}" ) - coordinates = coordinates.value_in_unit_system(unit.md_unit_system) + positions = positions.value_in_unit_system(unit.md_unit_system) if isinstance(box_vectors, unit.Quantity): if not box_vectors.unit.is_compatible(unit.nanometer): @@ -546,16 +652,17 @@ def build( f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" ) - self.ref_coordinates = coordinates + self.ref_positions = positions self.box_vectors = box_vectors + cutoff_and_skin = self.cutoff + self.skin + # the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list # changes to the box vectors require rebuilding the neighbor list - self.space.box_vectors = self.box_vectors # store the ids of all the particles self.particle_ids = jnp.array( - range(0, self.ref_coordinates.shape[0]), dtype=jnp.uint32 + range(0, self.ref_positions.shape[0]), dtype=jnp.uint32 ) # calculate which pairs to exclude @@ -569,13 +676,15 @@ def build( # n_neighbors: an array of shape (n_particles) where each element is the number of neighbors for that particle self.neighbor_mask, self.neighbor_list, self.n_neighbors = jax.vmap( - self._build_neighborlist, in_axes=(0, 0, 0, None, None) + self._build_neighborlist, in_axes=(0, 0, 0, None, None, None, None) )( - self.ref_coordinates, + self.ref_positions, reduction_mask, self.particle_ids, - self.ref_coordinates, + self.ref_positions, self.n_max_neighbors, + cutoff_and_skin.value_in_unit_system(unit.md_unit_system), + self.box_vectors, ) self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) @@ -588,13 +697,15 @@ def build( self.n_max_neighbors = int(jnp.max(self.n_neighbors) + 10) self.neighbor_mask, self.neighbor_list, self.n_neighbors = jax.vmap( - self._build_neighborlist, in_axes=(0, 0, 0, None, None) + self._build_neighborlist, in_axes=(0, 0, 0, None, None, None, None) )( - self.ref_coordinates, + self.ref_positions, reduction_mask, self.particle_ids, - self.ref_coordinates, + self.ref_positions, self.n_max_neighbors, + cutoff_and_skin.value_in_unit_system(unit.md_unit_system), + self.box_vectors, ) self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) @@ -603,7 +714,7 @@ def build( @partial(jax.jit, static_argnums=(0,)) def _calc_distance_per_particle( - self, particle1, neighbors, neighbor_mask, coordinates + self, particle1, neighbors, neighbor_mask, positions, cutoff, box_vectors ): """ Jitted function to calculate the distance between a particle and its neighbors @@ -616,8 +727,12 @@ def _calc_distance_per_particle( Array of particle ids for the neighbors of particle1 neighbor_mask: jnp.array Mask to exclude padding from the neighbor list of particle1 - coordinates: jnp.array - X,Y,Z coordinates of all particles + positions: jnp.array + X,Y,Z positions of all particles + cutoff: float + Cutoff distance for the neighborlist, in nanometers + box_vectors: jnp.array + Box vectors for the system Returns ------- @@ -635,26 +750,27 @@ def _calc_distance_per_particle( particles1 = jnp.repeat(particle1, neighbors.shape[0]) # calculate the displacement between particle i and all neighbors + # See note above: if self.space changes, it will not show up here because self is static. r_ij, dist = self.space.displacement( - coordinates[particles1], coordinates[neighbors] + positions[particles1], positions[neighbors], box_vectors ) # calculate the mask to determine if the particle is a neighbor # this will be done based on the interaction cutoff and using the neighbor_mask to exclude padding - mask = jnp.where((dist < self.cutoff) & (neighbor_mask), 1, 0) + mask = jnp.where((dist < cutoff) & (neighbor_mask), 1, 0) # calculate the number of pairs n_pairs = mask.sum() return n_pairs, mask, dist, r_ij - def calculate(self, coordinates: jnp.array): + def calculate(self, positions: jnp.array): """ Calculate the neighbor list for the current state Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates + positions: jnp.array + Shape[N,3] array of particle positions Returns ------- @@ -669,20 +785,34 @@ def calculate(self, coordinates: jnp.array): r_ij: jnp.array Array of displacement vectors between each particle and its neighbors. Shape (n_particles, n_max_neighbors, 3) """ - # coordinates = sampler_state.x0 + # positions = sampler_state.positions # note, we assume the box vectors do not change between building and calculating the neighbor list # changes to the box vectors require rebuilding the neighbor list n_neighbors, padding_mask, dist, r_ij = jax.vmap( - self._calc_distance_per_particle, in_axes=(0, 0, 0, None) - )(self.particle_ids, self.neighbor_list, self.neighbor_mask, coordinates) + self._calc_distance_per_particle, in_axes=(0, 0, 0, None, None, None) + )( + self.particle_ids, + self.neighbor_list, + self.neighbor_mask, + positions, + self.cutoff.value_in_unit_system(unit.md_unit_system), + self.box_vectors, + ) # mask = mask.reshape(-1, self.n_max_neighbors) return n_neighbors, self.neighbor_list, padding_mask, dist, r_ij @partial(jax.jit, static_argnums=(0,)) - def _calculate_particle_displacement(self, particle, coordinates, ref_coordinates): + def _calculate_particle_displacement( + self, + particle: int, + positions: jnp.array, + ref_positions: jnp.array, + skin: float, + box_vectors: jnp.array, + ): """ - Calculate the displacement of a particle from the reference coordinates. + Calculate the displacement of a particle from the reference positions. If the displacement exceeds the half the skin distance, return True, otherwise return False. This function is designed to allow it to be jitted and vmapped over particle indices. @@ -691,50 +821,61 @@ def _calculate_particle_displacement(self, particle, coordinates, ref_coordinate ---------- particle: int Particle id - coordinates: jnp.array - Array of particle coordinates - ref_coordinates: jnp.array - Array of reference particle coordinates + positions: jnp.array + Array of particle positions + ref_positions: jnp.array + Array of reference particle positions + skin: float + Skin distance for the neighborlist, in nanometers + box_vectors: jnp.array + Box vectors for the system + Returns ------- bool True if the particle is outside the skin distance, False if it is not. """ - # calculate the displacement of a particle from the initial coordinates - + # calculate the displacement of a particle from the initial positions + # again, note that if self.space changes, it will not show up here because self is static. r_ij, displacement = self.space.displacement( - coordinates[particle], ref_coordinates[particle] + positions[particle], ref_positions[particle], box_vectors ) - status = jnp.where(displacement >= self.skin / 2.0, True, False) + status = jnp.where(displacement >= skin / 2.0, True, False) del displacement return status - def check(self, coordinates: jnp.array) -> bool: + def check(self, positions: jnp.array) -> bool: """ - Check if the neighbor list needs to be rebuilt based on displacement of the particles from the reference coordinates. + Check if the neighbor list needs to be rebuilt based on displacement of the particles from the reference positions. If a particle moves more than 0.5 skin distance, the neighborlist will be rebuilt. - Will also return True if the size of the coordinates array changes. + Will also return True if the size of the positions array changes. Note, this could also accept a user defined criteria for distance, but this is not implemented yet. Parameters ---------- - coordinates: jnp.array - Array of particle coordinates + positions: jnp.array + Array of particle positions Returns ------- bool True if the neighbor list needs to be rebuilt, False if it does not. """ - if self.ref_coordinates.shape[0] != coordinates.shape[0]: + if self.ref_positions.shape[0] != positions.shape[0]: return True status = jax.vmap( - self._calculate_particle_displacement, in_axes=(0, None, None) - )(self.particle_ids, coordinates, self.ref_coordinates) + self._calculate_particle_displacement, in_axes=(0, None, None, None, None) + )( + self.particle_ids, + positions, + self.ref_positions, + self.skin.value_in_unit_system(unit.md_unit_system), + self.box_vectors, + ) if jnp.any(status): del status return True @@ -743,56 +884,114 @@ def check(self, coordinates: jnp.array) -> bool: return False -class PairList(PairsBase): +class PairListNsqrd(PairsBase): """ - N^2 pairlist implementation that returns the particle pair ids, displacement vectors, and distances. + A class that implements a simple pair list using JAX that determine which pairs of particles are interacting. + This class can be defined with cutoff (i.e., only returning information about pairs separated by distances + less than the cutoff) or without a cutoff (i.e., information about all possible pairs are returned). + Note, in both cases, distances are calculated using the boundary conditions defined by the simulation Space class + and only unique pairs are returned (i.e., no double counting and no self-interactions). + + This performs an O(N^2) calculation each time the `calculate` function is called and thus will be inefficient + for all but very small system sizes. + + The calculate function will return various pieces of information about the interacting pairs + (e.g., number of neighbors, neighbor ids, distances, displacement vectors) that can be used to calculate the + interaction potential/force. For efficiency of the jitted functions, the `calculate` function array + sizes are fixed. For example, distance has shape (n_particles, n_particles-1), regardless of the number of particles + that are actually neighbors (note: self interactions are removed hence n_particles-1). The `padding_mask` array + returned by `calculate` is used to exclude those pairs that are not interacting. The `padding_mask` contains values + of 1 for interacting particles and 0 for non-interacting. Parameters ---------- space: Space Class that defines how to calculate the displacement between two points and apply the boundary conditions - cutoff: float, default = 2.5 - Cutoff distance for the pair list calculation + cutoff: Optional[unit.Quantity], default = None + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff, + applying the boundary conditions as defined in space. + Examples -------- - >>> from chiron.neighbors import PairList, OrthogonalPeriodicSpace - >>> from chiron.states import SamplerState >>> import jax.numpy as jnp + >>> import openmm.unit as unit >>> - >>> space = OrthogonalPeriodicSpace() - >>> pair_list = PairList(space, cutoff=2.5) - >>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), + >>> from chiron.states import SamplerState + >>> sampler_state = SamplerState(positions=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])) + >>> + >>> from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + >>> pair_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=1.2*unit.nanometer) >>> pair_list.build_from_state(sampler_state) >>> - >>> # mask and distances are of shape (n_particles, n_particles-1), - >>> displacement_vectors of shape (n_particles, n_particles-1, 3) - >>> # mask, is a bool array that is True if the particle is within the cutoff distance, False if it is not - >>> # n_pairs is of shape (n_particles) and is per row sum of the mask. The mask ensure we also do not double count pairs - >>> n_pairs, mask, distances, displacement_vectors = pair_list.calculate(sampler_state.x0) + >>> # n_pairs is of shape (n_particles) and is per row sum of the padding_mask. + >>> # pairs, padding mask and distances are of shape (n_particles, n_particles-1), + >>> # displacement_vectors are of shape (n_particles, n_particles-1, 3) + >>> # padding_mask, is a bool array that is True if the particle is within the cutoff distance, False if it is not + >>> n_pairs, pairs, padding_mask, distances, displacement_vectors = pair_list.calculate(sampler_state.positions) """ def __init__( self, space: Space, - cutoff: unit.Quantity = unit.Quantity(1.2, unit.nanometer), + cutoff: Optional[unit.Quantity] = None, ): + """ + Initialize the PairListNsqrd class + + Parameters + ---------- + space: Space + Class that defines how to calculate the displacement between two points and apply the boundary conditions. + This should not change after initialization. + cutoff: Optional[unit.Quantity], default = None + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff. + """ if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") - if not cutoff.unit.is_compatible(unit.angstrom): - raise ValueError( - f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" - ) - self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) + # keeping this public in case we want to change it later + # validation is performed in the setter + self.cutoff = cutoff + self.space = space - # set a a simple variable to know if this has at least been built once as opposed to just initialized - # this does not imply that the neighborlist is up to date + # the init function does not setup the internal arrays we need to use calculate + # this is handled in the `build` function + # this variable can be used to check that the pair list has been built before trying to use it self.is_built = False - # note, we need to use the partial decorator in order to use the jit decorate - # so that it knows to ignore the `self` argument + @property + def cutoff(self): + """ + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff. + + Returns + ------- + cutoff: unit.Quantity + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff. + """ + return self._cutoff + + @cutoff.setter + def cutoff(self, cutoff): + if cutoff is not None: + if not cutoff.unit.is_compatible(unit.angstrom): + raise ValueError( + f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" + ) + # Note, since this is just a simple pair list, we do not need to rebuild by changing the cutoff + self._cutoff = cutoff + + # Note, we need to use the partial decorator and declare self as static in order to JIT a function within a class. + # As mentioned in a comment above in the NeighborListNsqrd class, this approach treats internal variables of the + # class as static within this function; e.g., if set self.cutoff = 2, called the function, then changed it to 3, + # the value of self.cutoff in this function would still be 2. Thus, we need to pass any variables that may change + # as arguments, rather than referencing self.variable_name. While we could create a custom pytree instead of + # declaring the class as static (allowing us to reference class variables directly within the JITTED function), + # any changes to those internal variables, say self.cutoff, would mean a change to the hash of any JITTEd function + # that depends on the variable, requiring JAX to recompile the function, which is a slow operation. + # As such, it is also more efficient to just pass variables as arguments. @partial(jax.jit, static_argnums=(0,)) def _pairs_and_mask(self, particle_ids: jnp.array): """ @@ -826,12 +1025,14 @@ def _pairs_and_mask(self, particle_ids: jnp.array): particles_i = jnp.reshape(particle_ids, (particle_ids.shape[0], 1)) # create a mask to exclude self interactions and double counting temp_mask = particles_i != particles_j + # remove self interactions all_pairs = jax.vmap(self._remove_self_interactions, in_axes=(0, 0))( particles_j, temp_mask ) del temp_mask all_pairs = jnp.array(all_pairs[0], dtype=jnp.uint32) + # create the mask that will remove any double counting of pairs reduction_mask = jnp.where(particles_i < all_pairs, True, False) return all_pairs, reduction_mask @@ -844,16 +1045,16 @@ def _remove_self_interactions(self, particles, temp_mask): def build( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build the neighborlist from an array of coordinates and box vectors. + Build the list from an array of positions and box vectors. Parameters ---------- - coordinates: jnp.array - Shape[n_particles,3] array of particle coordinates + positions: jnp.array + Shape[n_particles,3] array of particle positions box_vectors: jnp.array Shape[3,3] array of box vectors @@ -863,27 +1064,80 @@ def build( """ - # set our reference coordinates - # this will set self.ref_coordinates=coordinates and self.box_vectors - self._validate_build_inputs(coordinates, box_vectors) + # validate the positions and box vectors + self._validate_build_inputs(positions, box_vectors) - self.n_particles = self.ref_coordinates.shape[0] + self.n_particles = self.ref_positions.shape[0] - # the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list - # changes to the box vectors require rebuilding the neighbor list - self.space.box_vectors = self.box_vectors + # the PairsList assumes that the box vectors do not change between building and calculating the neighbor list # store the ids of all the particles - self.particle_ids = jnp.array(range(0, coordinates.shape[0]), dtype=jnp.uint32) + self.particle_ids = jnp.array(range(0, positions.shape[0]), dtype=jnp.uint32) # calculate which pairs to exclude self.all_pairs, self.reduction_mask = self._pairs_and_mask(self.particle_ids) self.is_built = True - @partial(jax.jit, static_argnums=(0,)) - def _calc_distance_per_particle( - self, particle1, neighbors, neighbor_mask, coordinates + @partial(jax.jit, static_argnums=(0)) + def _calc_distance_per_particle_with_cutoff( + self, particle1, neighbors, neighbor_mask, positions, cutoff, box_vectors + ): + """ + Jitted function to calculate the distance between a particle and all possible neighbors + + Parameters + ---------- + particle1: int + Particle id + neighbors: jnp.array + Array of particle ids for the possible particle pairs of particle1 + neighbor_mask: jnp.array + Mask to exclude double particles to prevent double counting + positions: jnp.array + X,Y,Z positions of all particles, shaped (n_particles, 3) + cutoff: float + Cutoff distance for the interaction. + box_vectors: jnp.array + Box vectors for the system + + Returns + ------- + n_pairs: int + Number of interacting pairs for the particle + mask: jnp.array + Mask to exclude padding particles not within the cutoff particle1. + If a particle is within the interaction cutoff, the mask is 1, otherwise it is 0 + Array has shape (n_particles, n_particles-1) as it excludes self interactions + dist: jnp.array + Array of distances between the particle and all other particles in the system. + Array has shape (n_particles, n_particles-1) as it excludes self interactions + r_ij: jnp.array + Array of displacement vectors between the particle and all other particles in the system. + Array has shape (n_particles, n_particles-1, 3) as it excludes self interactions + . + + """ + # repeat the particle id for each neighbor + particles1 = jnp.repeat(particle1, neighbors.shape[0]) + + # calculate the displacement between particle i and all neighbors + # See note above: if self.space changes, it will not show up here because self is static. + r_ij, dist = self.space.displacement( + positions[particles1], positions[neighbors], box_vectors + ) + # calculate the mask to determine if the particle is a neighbor + # this will be done based on the interaction cutoff and using the neighbor_mask to exclude padding + mask = jnp.where((dist < cutoff) & (neighbor_mask), 1, 0) + + # calculate the number of pairs + n_pairs = mask.sum() + + return n_pairs, mask, dist, r_ij + + @partial(jax.jit, static_argnums=(0)) + def _calc_distance_per_particle_no_cutoff( + self, particle1, neighbors, neighbor_mask, positions, box_vectors ): """ Jitted function to calculate the distance between a particle and all possible neighbors @@ -896,8 +1150,10 @@ def _calc_distance_per_particle( Array of particle ids for the possible particle pairs of particle1 neighbor_mask: jnp.array Mask to exclude double particles to prevent double counting - coordinates: jnp.array - X,Y,Z coordinates of all particles, shaped (n_particles, 3) + positions: jnp.array + X,Y,Z positions of all particles, shaped (n_particles, 3) + box_vectors: jnp.array + Box vectors of the system Returns ------- @@ -914,31 +1170,33 @@ def _calc_distance_per_particle( Array of displacement vectors between the particle and all other particles in the system. Array has shape (n_particles, n_particles-1, 3) as it excludes self interactions + """ # repeat the particle id for each neighbor particles1 = jnp.repeat(particle1, neighbors.shape[0]) # calculate the displacement between particle i and all neighbors + # See note above: if self.space changes, it will not show up here because self is static. r_ij, dist = self.space.displacement( - coordinates[particles1], coordinates[neighbors] + positions[particles1], positions[neighbors], box_vectors ) # calculate the mask to determine if the particle is a neighbor # this will be done based on the interaction cutoff and using the neighbor_mask to exclude padding - mask = jnp.where((dist < self.cutoff) & (neighbor_mask), 1, 0) + mask = jnp.where(neighbor_mask, 1, 0) # calculate the number of pairs n_pairs = mask.sum() return n_pairs, mask, dist, r_ij - def calculate(self, coordinates: jnp.array): + def calculate(self, positions: jnp.array): """ - Calculate the neighbor list for the current state + Calculate the list of neighbor pairs for the current state Parameters ---------- - coordinates: jnp.array - Shape[n_particles,3] array of particle coordinates + positions: jnp.array + Shape[n_particles,3] array of particle positions Returns ------- @@ -953,35 +1211,53 @@ def calculate(self, coordinates: jnp.array): r_ij: jnp.array Array of displacement vectors between particle pairs. Shape: (n_particles, n_particles-1, 3). """ - if coordinates.shape[0] != self.n_particles: + if positions.shape[0] != self.n_particles: raise ValueError( f"Number of particles cannot changes without rebuilding. " - f"Coordinates must have shape ({self.n_particles}, 3), found {coordinates.shape}" + f"Positions must have shape ({self.n_particles}, 3), found {positions.shape}" ) - # coordinates = self.space.wrap(coordinates) - - n_neighbors, padding_mask, dist, r_ij = jax.vmap( - self._calc_distance_per_particle, in_axes=(0, 0, 0, None) - )(self.particle_ids, self.all_pairs, self.reduction_mask, coordinates) - + # if we did not define a cutoff, we will + if self.cutoff is None: + n_neighbors, padding_mask, dist, r_ij = jax.vmap( + self._calc_distance_per_particle_no_cutoff, + in_axes=(0, 0, 0, None, None), + )( + self.particle_ids, + self.all_pairs, + self.reduction_mask, + positions, + self.box_vectors, + ) + else: + n_neighbors, padding_mask, dist, r_ij = jax.vmap( + self._calc_distance_per_particle_with_cutoff, + in_axes=(0, 0, 0, None, None, None), + )( + self.particle_ids, + self.all_pairs, + self.reduction_mask, + positions, + self.cutoff.value_in_unit_system(unit.md_unit_system), + self.box_vectors, + ) return n_neighbors, self.all_pairs, padding_mask, dist, r_ij - def check(self, coordinates: jnp.array) -> bool: + def check(self, positions: jnp.array) -> bool: """ Check if we need to reconstruct internal arrays. For a simple pairlist this will always return False, unless the number of particles change. Parameters ---------- - coordinates: jnp.array - Array of particle coordinates + positions: jnp.array + Array of particle positions Returns ------- bool True if we need to rebuild the neighbor list, False if we do not. """ - if coordinates.shape[0] != self.n_particles: + if positions.shape[0] != self.n_particles: return True else: return False diff --git a/chiron/potential.py b/chiron/potential.py index 6d9b415..1d2d340 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -63,6 +63,70 @@ def compute_pairlist(self, positions, cutoff) -> jnp.array: return distance[interacting_mask], displacement_vectors[interacting_mask], pairs +class IdealGasPotential(NeuralNetworkPotential): + def __init__( + self, + topology: Topology, + ): + """ + Initialize the Ideal Gas potential. + + Parameters + ---------- + topology : Topology + The topology of the system + + """ + + if not isinstance(topology, (Topology, property)) and topology is not None: + raise TypeError( + f"Topology must be a Topology object, a property, or None, got type(topology) = {type(topology)}" + ) + + self.topology = topology + + def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): + """ + Compute the energy for an ideal gas, which is always 0. + + Parameters + ---------- + positions : jnp.array + The positions of the particles in the system + nbr_list : NeighborList, default=None + Instance of a neighbor list or pair list class to use. + If None, an unoptimized N^2 pairlist will be used without PBC conditions. + Returns + ------- + potential_energy : float + The total potential energy of the system. + + """ + # Compute the pair distances and displacement vectors + + return 0.0 + + def compute_force(self, positions: jnp.array, nbr_list=None) -> jnp.array: + """ + Compute the force for ideal gas particles, which is always 0. + + Parameters + ---------- + positions : jnp.array + The positions of the particles in the system + nbr_list : NeighborList, optional + Instance of the neighborlist class to use. By default, set to None, which will use an N^2 pairlist + + Returns + ------- + force : jnp.array + The forces on the particles in the system + + """ + + return 0.0 + + class LJPotential(NeuralNetworkPotential): def __init__( self, @@ -200,7 +264,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): raise ValueError("Neighborlist must be built before use") # ensure that the cutoff in the neighbor list is the same as the cutoff in the potential - if nbr_list.cutoff != self.cutoff: + if nbr_list.cutoff.value_in_unit_system(unit.md_unit_system) != self.cutoff: raise ValueError( f"Neighborlist cutoff ({nbr_list.cutoff}) must be the same as the potential cutoff ({self.cutoff})" ) @@ -285,7 +349,7 @@ def __init__( The topology object representing the molecular system. k : unit.Quantity, optional The spring constant of the harmonic potential. Default is 1.0 kcal/mol/Å^2. - x0 : unit.Quantity, optional + positions : unit.Quantity, optional The equilibrium position of the harmonic potential. Default is [0.0,0.0,0.0] Å. U0 : unit.Quantity, optional The offset potential energy of the harmonic potential. Default is 0.0 kcal/mol. @@ -302,7 +366,9 @@ def __init__( if not isinstance(k, unit.Quantity): raise TypeError(f"k must be a unit.Quantity, type(k) = {type(k)}") if not isinstance(x0, unit.Quantity): - raise TypeError(f"x0 must be a unit.Quantity, type(x0) = {type(x0)}") + raise TypeError( + f"positions must be a unit.Quantity, type(positions) = {type(x0)}" + ) if not isinstance(U0, unit.Quantity): raise TypeError(f"U0 must be a unit.Quantity, type(U0) = {type(U0)}") @@ -312,9 +378,11 @@ def __init__( ) if not x0.unit.is_compatible(unit.angstrom): raise ValueError( - f"x0 must be a unit.Quantity with units of distance, x0.unit = {x0.unit}" + f"positions must be a unit.Quantity with units of distance, positions.unit = {x0.unit}" ) - assert x0.shape[1] == 3, f"x0 must be a NX3 vector, x0.shape = {x0.shape}" + assert ( + x0.shape[1] == 3 + ), f"positions must be a NX3 vector, positions.shape = {x0.shape}" if not U0.unit.is_compatible(unit.kilocalories_per_mole): raise ValueError( f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}" @@ -324,9 +392,11 @@ def __init__( log.debug("Initializing HarmonicOscillatorPotential") log.debug(f"k = {k}") - log.debug(f"x0 = {x0}") + log.debug(f"positions = {x0}") log.debug(f"U0 = {U0}") - log.debug("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0") + log.debug( + "Energy is calculate: U(x) = (K/2) * ( (x-positions)^2 + y^2 + z^2 ) + U0" + ) self.k = jnp.array( k.value_in_unit_system(unit.md_unit_system) ) # spring constant @@ -339,7 +409,7 @@ def __init__( self.topology = topology def compute_energy(self, positions: jnp.array, nbr_list=None): - # the functional form is given by U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0 + # the functional form is given by U(x) = (K/2) * ( (x-positions)^2 + y^2 + z^2 ) + U0 # https://github.com/choderalab/openmmtools/blob/main/openmmtools/testsystems.py#L695 # compute the displacement vectors diff --git a/chiron/reporters.py b/chiron/reporters.py index 27457a6..156e86a 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -367,7 +367,7 @@ def _write_to_trajectory(self, positions: np.ndarray) -> None: file_handler=self._write_xtc_file_handle, positions=positions, iteration=self.buffer.get("step"), - box_vecotrs=self.buffer.get("box_vectors"), + box_vectors=self.buffer.get("box_vectors"), ) def read_from_trajectory(self) -> np.ndarray: @@ -409,7 +409,7 @@ def _write_to_xtc( file_handler: md.formats.XTCTrajectoryFile, positions: np.ndarray, iteration: np.ndarray, - box_vecotrs: Optional[np.ndarray] = None, + box_vectors: Optional[np.ndarray] = None, ): """ Write position data to an XTC file. @@ -428,5 +428,5 @@ def _write_to_xtc( file_handler.write( positions, time=iteration, - box=box_vecotrs, + box=box_vectors, ) diff --git a/chiron/states.py b/chiron/states.py index 99459ae..6ebdca1 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -11,26 +11,40 @@ class SamplerState: Parameters ---------- - x0 : unit.Quantity + positions : unit.Quantity The current positions of the particles in the simulation. velocities : unit.Quantity, optional The velocities of the particles in the simulation. box_vectors : unit.Quantity, optional The box vectors defining the simulation's periodic boundary conditions. + Examples + -------- + + from chiron.states import SamplerState + from chiron.utils import PRNG + from openmmtools.testsystems import HarmonicOscillator + + ho = HarmonicOscillator() + PRNG.set_seed(1234) + + sampler_state = SamplerState(positions = ho.positions, PRNG.get_random_key()) + """ def __init__( self, - x0: unit.Quantity, + positions: unit.Quantity, current_PRNG_key: random.PRNGKey, velocities: Optional[unit.Quantity] = None, box_vectors: Optional[unit.Quantity] = None, ) -> None: # NOTE: all units are internally in the openMM units system as documented here: # http://docs.openmm.org/latest/userguide/theory/01_introduction.html#units - if not isinstance(x0, unit.Quantity): - raise TypeError(f"x0 must be a unit.Quantity, got {type(x0)} instead.") + if not isinstance(positions, unit.Quantity): + raise TypeError( + f"positions must be a unit.Quantity, got {type(positions)} instead." + ) if velocities is not None and not isinstance(velocities, unit.Quantity): raise TypeError( f"velocities must be a unit.Quantity, got {type(velocities)} instead." @@ -45,8 +59,10 @@ def __init__( raise TypeError( f"box_vectors must be a unit.Quantity or openMM box, got {type(box_vectors)} instead." ) - if not x0.unit.is_compatible(unit.nanometer): - raise ValueError(f"x0 must have units of distance, got {x0.unit} instead.") + if not positions.unit.is_compatible(unit.nanometer): + raise ValueError( + f"positions must have units of distance, got {positions.unit} instead." + ) if velocities is not None and not velocities.unit.is_compatible( unit.nanometer / unit.picosecond ): @@ -63,26 +79,27 @@ def __init__( raise ValueError( f"box_vectors must be a 3x3 array, got {box_vectors.shape} instead." ) - if velocities is not None and x0.shape != velocities.shape: + if velocities is not None and positions.shape != velocities.shape: raise ValueError( - f"x0 and velocities must have the same shape, got {x0.shape} and {velocities.shape} instead." + f"positions and velocities must have the same shape, got {positions.shape} and {velocities.shape} instead." ) if current_PRNG_key is None: raise ValueError(f"random_seed must be set.") - self._x0 = x0 + self._positions = positions self._velocities = velocities self._current_PRNG_key = current_PRNG_key self._box_vectors = box_vectors self._distance_unit = unit.nanometer + self._time_unit = unit.picosecond @property def n_particles(self) -> int: - return self._x0.shape[0] + return self._positions.shape[0] @property - def x0(self) -> jnp.array: - return self._convert_to_jnp(self._x0) + def positions(self) -> jnp.array: + return self._convert_to_jnp(self._positions) @property def velocities(self) -> jnp.array: @@ -96,17 +113,40 @@ def box_vectors(self) -> jnp.array: return None return self._convert_to_jnp(self._box_vectors) - @x0.setter - def x0(self, x0: Union[jnp.array, unit.Quantity]) -> None: + @positions.setter + def positions(self, x0: Union[jnp.array, unit.Quantity]) -> None: if isinstance(x0, unit.Quantity): - self._x0 = x0 + self._positions = x0 + else: + self._positions = unit.Quantity(x0, self._distance_unit) + + @box_vectors.setter + def box_vectors(self, box_vectors: Union[jnp.array, unit.Quantity]) -> None: + if isinstance(box_vectors, unit.Quantity): + self._box_vectors = box_vectors else: - self._x0 = unit.Quantity(x0, self._distance_unit) + self._box_vectors = unit.Quantity(box_vectors, self._distance_unit) + + @velocities.setter + def velocities(self, velocities: Union[jnp.array, unit.Quantity]) -> None: + if velocities.shape != self._positions.shape: + raise ValueError( + f"velocities must have the same shape as positions, got {velocities.shape} and {self._positions.shape} instead." + ) + if isinstance(velocities, unit.Quantity): + self._velocities = velocities + else: + self._velocities = unit.Quantity( + velocities, self._distance_unit / self._time_unit + ) @property def distance_unit(self) -> unit.Unit: return self._distance_unit + def velocity_unit(self) -> unit.Unit: + return self._distance_unit / self._time_unit + @property def new_PRNG_key(self) -> random.PRNGKey: key, subkey = random.split(self._current_PRNG_key) @@ -201,7 +241,7 @@ def __init__( from .utils import get_nr_of_particles self.nr_of_particles = get_nr_of_particles(self.potential.topology) - self._check_completness() + self._check_completeness() def check_variables(self) -> None: """ @@ -215,7 +255,7 @@ def check_variables(self) -> None: set_variables = [var for var in variables if getattr(self, var) is not None] return set_variables - def _check_completness(self): + def _check_completeness(self): # check which variables are set set_variables = self.check_variables() from loguru import logger as log @@ -242,7 +282,7 @@ def get_reduced_potential( ---------- sampler_state : SamplerState The sampler state for which to compute the reduced potential. - nbr_list : NeighborList or PairList, optional + nbr_list : NeighborList or PairListNsqrd, optional The neighbor list or pair list routine to use for calculating the reduced potential. Returns @@ -263,17 +303,25 @@ def get_reduced_potential( self.beta = 1.0 / ( unit.BOLTZMANN_CONSTANT_kB * (self.temperature * unit.kelvin) ) - # log.debug(f"sample state: {sampler_state.x0}") + # log.debug(f"sample state: {sampler_state.positions}") reduced_potential = ( unit.Quantity( - self.potential.compute_energy(sampler_state.x0, nbr_list), + self.potential.compute_energy(sampler_state.positions, nbr_list), unit.kilojoule_per_mole, ) ) / unit.AVOGADRO_CONSTANT_NA # log.debug(f"reduced potential: {reduced_potential}") if self.pressure is not None: - reduced_potential += self.pressure * self.volume + self.volume = ( + sampler_state.box_vectors[0][0] + * sampler_state.box_vectors[1][1] + * sampler_state.box_vectors[2][2] + ) * unit.nanometer**3 + from loguru import logger as log + + reduced_potential += self.pressure * self.volume + # add chemical potential return self.beta * reduced_potential def kT_to_kJ_per_mol(self, energy): @@ -295,7 +343,7 @@ def calculate_reduced_potential_at_states( The sampler state for which to compute the reduced potential. thermodynamic_states : list of ThermodynamicState The thermodynamic states for which to compute the reduced potential. - nbr_list : NeighborList or PairList, optional + nbr_list : NeighborList or PairListNsqrd, optional Returns ------- list of float diff --git a/chiron/tests/conftest.py b/chiron/tests/conftest.py index 742d5ab..9f1cca6 100644 --- a/chiron/tests/conftest.py +++ b/chiron/tests/conftest.py @@ -29,8 +29,8 @@ def provide_testsystems_and_potentials(): import jax.numpy as jnp hoa_potential = HarmonicOscillatorPotential( - ho.topology, - ho.K, + hoa.topology, + hoa.K, x0=unit.Quantity( jnp.array( [ @@ -54,5 +54,3 @@ def provide_testsystems_and_potentials(): (hoa, hoa_potential), ] return TESTSYSTEM_AND_POTENTIAL - - diff --git a/chiron/tests/test_convergence_tests.py b/chiron/tests/test_convergence_tests.py index 16cfad3..df30515 100644 --- a/chiron/tests/test_convergence_tests.py +++ b/chiron/tests/test_convergence_tests.py @@ -44,9 +44,17 @@ def test_convergence_of_MC_estimator(prep_temp_dir): from chiron.states import ThermodynamicState, SamplerState thermodynamic_state = ThermodynamicState( - harmonic_potential, temperature=300, volume=30 * (unit.angstrom**3) + harmonic_potential, + temperature=300 * unit.kelvin, + volume=30 * (unit.angstrom**3), + ) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + sampler_state = SamplerState( + positions=ho.positions, current_PRNG_key=PRNG.get_random_key() ) - sampler_state = SamplerState(ho.positions) from chiron.reporters import _SimulationReporter @@ -55,16 +63,16 @@ def test_convergence_of_MC_estimator(prep_temp_dir): simulation_reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") # Initalize the move set (here only LangevinDynamicsMove) - from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler + from chiron.mcmc import MonteCarloDisplacementMove, MoveSchedule, MCMCSampler - mc_displacement_move = MetropolisDisplacementMove( - nr_of_moves=100_000, + mc_displacement_move = MonteCarloDisplacementMove( + number_of_moves=1_000, displacement_sigma=0.5 * unit.angstrom, atom_subset=[0], reporter=simulation_reporter, ) - move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MonteCarloDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) @@ -82,7 +90,9 @@ def test_convergence_of_MC_estimator(prep_temp_dir): plt.plot(chiron_energy) print("Expectation values generated with chiron") - es = chiron_energy + import jax.numpy as jnp + + es = jnp.array(chiron_energy) print(es.mean(), es.std()) print("Expectation values from openmmtools") @@ -133,12 +143,16 @@ def test_langevin_dynamics_with_LJ_fluid(prep_temp_dir): ) print(lj_fluid.system.getDefaultPeriodicBoxVectors()) + from chiron.utils import PRNG + + PRNG.set_seed(1234) sampler_state = SamplerState( - x0=lj_fluid.positions, + positions=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), + current_PRNG_key=PRNG.get_random_key(), ) - print(sampler_state.x0.shape) + print(sampler_state.positions.shape) print(sampler_state.box_vectors) nbr_list = NeighborListNsqrd( @@ -153,16 +167,137 @@ def test_langevin_dynamics_with_LJ_fluid(prep_temp_dir): potential=lj_potential, temperature=300 * unit.kelvin ) - from chiron.reporters import _SimulationReporter + from chiron.reporters import LangevinDynamicsReporter id = uuid.uuid4() - reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") + reporter = LangevinDynamicsReporter(f"{prep_temp_dir}/test_{id}.h5") - integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) + integrator = LangevinIntegrator(reporter=reporter, report_interval=100) integrator.run( sampler_state, thermodynamic_state, - n_steps=2000, + number_of_steps=1000, nbr_list=nbr_list, progress_bar=True, ) + + +@pytest.mark.skip(reason="Tests takes too long") +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test takes too long.") +def test_ideal_gas(prep_temp_dir): + from openmmtools.testsystems import IdealGas + from openmm import unit + + n_particles = 216 + temperature = 298 * unit.kelvin + pressure = 1 * unit.atmosphere + mass = unit.Quantity(39.9, unit.gram / unit.mole) + + ideal_gas = IdealGas( + nparticles=n_particles, temperature=temperature, pressure=pressure + ) + + from chiron.potential import IdealGasPotential + from chiron.utils import PRNG + import jax.numpy as jnp + + # + cutoff = 0.0 * unit.nanometer + ideal_gas_potential = IdealGasPotential(ideal_gas.topology) + + from chiron.states import SamplerState, ThermodynamicState + + # define the thermodynamic state + thermodynamic_state = ThermodynamicState( + potential=ideal_gas_potential, + temperature=temperature, + pressure=pressure, + ) + + PRNG.set_seed(1234) + + # define the sampler state + sampler_state = SamplerState( + positions=ideal_gas.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=ideal_gas.system.getDefaultPeriodicBoxVectors(), + ) + + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + + # define the pair list for an orthogonal periodic space + # since particles are non-interacting, this will not really do much + # but will appropriately wrap particles in space + nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) + nbr_list.build_from_state(sampler_state) + + from chiron.reporters import MCReporter + + # initialize a reporter to save the simulation data + filename = "test_mc_ideal_gas.h5" + import os + + if os.path.isfile(filename): + os.remove(filename) + reporter = MCReporter(filename, 1) + + from chiron.mcmc import ( + MonteCarloDisplacementMove, + MonteCarloBarostatMove, + MoveSchedule, + MCMCSampler, + ) + + mc_displacement_move = MonteCarloDisplacementMove( + displacement_sigma=0.1 * unit.nanometer, + number_of_moves=10, + reporter=reporter, + autotune=True, + autotune_interval=100, + ) + + mc_barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.2, + number_of_moves=100, + reporter=reporter, + autotune=True, + autotune_interval=100, + ) + move_set = MoveSchedule( + [ + ("MonteCarloDisplacementMove", mc_displacement_move), + ("MonteCarloBarostatMove", mc_barostat_move), + ] + ) + + sampler = MCMCSampler(move_set) + sampler.run( + sampler_state, thermodynamic_state, n_iterations=10, nbr_list=nbr_list + ) # how many times to repeat + + volume = reporter.get_property("volume") + + # get expectations + ideal_volume = ideal_gas.get_volume_expectation(thermodynamic_state) + ideal_volume_std = ideal_gas.get_volume_standard_deviation(thermodynamic_state) + + print(ideal_volume, ideal_volume_std) + + volume_mean = jnp.mean(jnp.array(volume)) * unit.nanometer**3 + volume_std = jnp.std(jnp.array(volume)) * unit.nanometer**3 + + print(volume_mean, volume_std) + + ideal_density = mass * n_particles / unit.AVOGADRO_CONSTANT_NA / ideal_volume + measured_density = mass * n_particles / unit.AVOGADRO_CONSTANT_NA / volume_mean + + assert jnp.isclose( + ideal_density.value_in_unit(unit.kilogram / unit.meter**3), + measured_density.value_in_unit(unit.kilogram / unit.meter**3), + atol=1e-1, + ) + # see if within 5% of ideal volume + assert abs(ideal_volume - volume_mean) / ideal_volume < 0.05 + + # see if within 10% of the ideal standard deviation of the volume + assert abs(ideal_volume_std - volume_std) / ideal_volume_std < 0.1 diff --git a/chiron/tests/test_integrators.py b/chiron/tests/test_integrators.py index 341c987..9ed5b4e 100644 --- a/chiron/tests/test_integrators.py +++ b/chiron/tests/test_integrators.py @@ -42,10 +42,12 @@ def test_langevin_dynamics(prep_temp_dir, provide_testsystems_and_potentials): reporter = LangevinDynamicsReporter() - integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) - integrator.run( + integrator = LangevinIntegrator( + reporter=reporter, report_interval=1, refresh_velocities=True + ) + updated_sampler_state, updated_nbr_list = integrator.run( sampler_state, thermodynamic_state, - n_steps=20, + number_of_steps=20, ) i = i + 1 diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index 82ef9a0..bef08ba 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -41,7 +41,7 @@ def test_sample_from_harmonic_osciallator(prep_temp_dir): PRNG.set_seed(1234) sampler_state = SamplerState( - x0=ho.positions, current_PRNG_key=PRNG.get_random_key() + positions=ho.positions, current_PRNG_key=PRNG.get_random_key() ) from chiron.integrators import LangevinIntegrator @@ -53,13 +53,16 @@ def test_sample_from_harmonic_osciallator(prep_temp_dir): reporter = LangevinDynamicsReporter() integrator = LangevinIntegrator( - stepsize=2 * unit.femtosecond, reporter=reporter, report_frequency=1 + timestep=2 * unit.femtosecond, + reporter=reporter, + report_interval=1, + refresh_velocities=True, ) integrator.run( sampler_state, thermodynamic_state, - n_steps=5, + number_of_steps=5, ) integrator.reporter.flush_buffer() import jax.numpy as jnp @@ -122,7 +125,32 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics BaseReporter.set_directory(prep_temp_dir) simulation_reporter = LangevinDynamicsReporter(1) - langevin_move = LangevinDynamicsMove(nr_of_steps=10, reporter=simulation_reporter) + + # the following will reinitialize the velocities for each iteration + langevin_move = LangevinDynamicsMove( + number_of_steps=10, refresh_velocities=True, reporter=simulation_reporter + ) + + move_set = MoveSchedule([("LangevinMove", langevin_move)]) + + # Initalize the sampler + sampler = MCMCSampler(move_set) + + # Run the sampler with the thermodynamic state and sampler state and return the sampler state + sampler.run( + sampler_state, thermodynamic_state, n_iterations=2 + ) # how many times to repeat + + # the following will use the initialize velocities function + from chiron.utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + thermodynamic_state.temperature, ho.topology, sampler_state._current_PRNG_key + ) + + langevin_move = LangevinDynamicsMove( + number_of_steps=10, refresh_velocities=False, reporter=simulation_reporter + ) move_set = MoveSchedule([("LangevinMove", langevin_move)]) @@ -146,7 +174,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla """ from openmm import unit from chiron.potential import HarmonicOscillatorPotential - from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler + from chiron.mcmc import MonteCarloDisplacementMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillator @@ -178,14 +206,14 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla BaseReporter.set_directory(wd) simulation_reporter = MCReporter(1) - mc_displacement_move = MetropolisDisplacementMove( - nr_of_moves=10, + mc_displacement_move = MonteCarloDisplacementMove( + number_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=[0], reporter=simulation_reporter, ) - move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MonteCarloDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set) @@ -206,7 +234,7 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis sampler states, and uses the Metropolis displacement move in an MCMC sampling scheme. """ from openmm import unit - from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler + from chiron.mcmc import MonteCarloDisplacementMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillatorArray @@ -240,14 +268,14 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis simulation_reporter = MCReporter(1) - mc_displacement_move = MetropolisDisplacementMove( - nr_of_moves=10, + mc_displacement_move = MonteCarloDisplacementMove( + number_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=None, reporter=simulation_reporter, ) - move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MonteCarloDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set) @@ -296,6 +324,134 @@ def test_thermodynamic_state_inputs(): ThermodynamicState(potential=harmonic_potential, pressure=100 * unit.atmosphere) +def test_mc_barostat_parameter_setting(): + import jax.numpy as jnp + from chiron.mcmc import MonteCarloBarostatMove + + barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.1, + number_of_moves=1, + ) + + assert barostat_move.volume_max_scale == 0.1 + assert barostat_move.number_of_moves == 1 + + +def test_mc_barostat(prep_temp_dir): + import jax.numpy as jnp + + from chiron.reporters import MCReporter, BaseReporter + + wd = prep_temp_dir.join(f"_test_{uuid.uuid4()}") + BaseReporter.set_directory(wd) + simulation_reporter = MCReporter(1) + + from chiron.mcmc import MonteCarloBarostatMove + + barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.1, + number_of_moves=10, + reporter=simulation_reporter, + report_interval=1, + ) + + from chiron.potential import IdealGasPotential + from openmm import unit + + positions = ( + jnp.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) + * unit.nanometer + ) + box_vectors = ( + jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + * unit.nanometer + ) + volume = box_vectors[0][0] * box_vectors[1][1] * box_vectors[2][2] + + from openmm.app import Topology, Element + + topology = Topology() + element = Element.getBySymbol("Ar") + chain = topology.addChain() + residue = topology.addResidue("system", chain) + for i in range(positions.shape[0]): + topology.addAtom("Ar", element, residue) + + ideal_gas_potential = IdealGasPotential(topology) + + from chiron.states import SamplerState, ThermodynamicState + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + # define the sampler state + sampler_state = SamplerState( + positions=positions, + box_vectors=box_vectors, + current_PRNG_key=PRNG.get_random_key(), + ) + + # define the thermodynamic state + thermodynamic_state = ThermodynamicState( + potential=ideal_gas_potential, + temperature=300 * unit.kelvin, + pressure=1.0 * unit.atmosphere, + ) + + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + + # since particles are non-interacting and we will not displacece them, the pair list basically + # does nothing in this case. + nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=0 * unit.nanometer) + + sampler_state, thermodynamic_state, nbr_list = barostat_move.update( + sampler_state, thermodynamic_state, nbr_list + ) + potential_energies = simulation_reporter.get_property("potential_energy") + volumes = simulation_reporter.get_property("volume") + + # ideal gas treatment, so stored energy will only be a + # consequence of pressure, volume, and temperature + from loguru import logger as log + + log.debug(f"PE {potential_energies * unit.kilojoules_per_mole}") + log.debug(thermodynamic_state.pressure) + log.debug(thermodynamic_state.beta) + log.debug(volumes) + log.debug(volumes * unit.nanometer**3) + + # assert that the PE is always zero + assert potential_energies[0] == 0 + assert potential_energies[-1] == 0 + + # the reduced potential will only be a consequence of the pressure, volume, and temperature + + assert jnp.isclose( + thermodynamic_state.get_reduced_potential(sampler_state), + ( + thermodynamic_state.pressure + * thermodynamic_state.beta + * (volumes[-1] * unit.nanometer**3) + ), + 1e-3, + ) + + print(barostat_move.statistics["n_accepted"]) + assert barostat_move.statistics["n_proposed"] == 10 + assert barostat_move.statistics["n_accepted"] == 8 + + def test_sample_from_joint_distribution_of_two_HO_with_local_moves_and_MC_updates(): # define two harmonic oscillators with different spring constants and equilibrium positions # sample from the joint distribution of the two HO using local langevin moves diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index cf0cf4e..4faa41d 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -3,7 +3,7 @@ def test_minimization(): import jax.numpy as jnp from chiron.states import SamplerState - from chiron.neighbors import PairList, OrthogonalPeriodicSpace + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace from openmm import unit # initialize testystem @@ -25,12 +25,14 @@ def test_minimization(): box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # use parilist - nbr_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) + nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) nbr_list.build_from_state(sampler_state) # compute intial energy with and without pairlist - initial_e_with_nbr_list = lj_potential.compute_energy(sampler_state.x0, nbr_list) - initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.x0) + initial_e_with_nbr_list = lj_potential.compute_energy( + sampler_state.positions, nbr_list + ) + initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.positions) print(f"initial_e_with_nbr_list: {initial_e_with_nbr_list}") print(f"initial_e_without_nbr_list: {initial_e_without_nbr_list}") assert not jnp.isclose( @@ -38,7 +40,7 @@ def test_minimization(): ), "initial_e_with_nbr_list and initial_e_without_nbr_list should not be close" # minimize energy for 0 steps results = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0 + sampler_state.positions, lj_potential.compute_energy, nbr_list, maxiter=0 ) # check that the minimization did not change the energy @@ -48,7 +50,7 @@ def test_minimization(): min_x, nbr_list ) after_0_steps_minimization_e_without_nbr_list = lj_potential.compute_energy( - sampler_state.x0 + sampler_state.positions ) print( f"after_0_steps_minimization_e_with_nbr_list: {after_0_steps_minimization_e_with_nbr_list}" @@ -67,7 +69,7 @@ def test_minimization(): # after 100 steps of minimization steps = 100 results = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=steps + sampler_state.positions, lj_potential.compute_energy, nbr_list, maxiter=steps ) min_x = results.params e_min = lj_potential.compute_energy(min_x, nbr_list) @@ -86,7 +88,7 @@ def test_minimize_two_particles(): import jax.numpy as jnp from chiron.states import SamplerState - from chiron.neighbors import PairList, OrthogonalPeriodicSpace + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace from openmm import unit from chiron.potential import LJPotential @@ -103,13 +105,13 @@ def test_minimize_two_particles(): # define the sampler state sampler_state = SamplerState( - x0=coordinates * unit.nanometer, + positions=coordinates * unit.nanometer, current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) * unit.nanometer, ) - pair_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) + pair_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) pair_list.build_from_state(sampler_state) e_start = lj_potential.compute_energy(coordinates, pair_list) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 0f153be..533d879 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -25,7 +25,9 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - lang_move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) + lang_move = LangevinDynamicsMove( + timestep=1.0 * unit.femtoseconds, number_of_steps=100 + ) BaseReporter.set_directory("multistate_test") reporter = MultistateReporter() reporter.reset_reporter_file() @@ -183,21 +185,24 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa ho_multistate_sampler_multiple_minima.minimize() assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[0].x0, + ho_multistate_sampler_multiple_minima.sampler_states[0].positions, np.array([[0.0, 0.0, 0.0]]), ) assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[1].x0, + ho_multistate_sampler_multiple_minima.sampler_states[1].positions, np.array([[0.05, 0.0, 0.0]]), atol=1e-2, ) assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[2].x0, + ho_multistate_sampler_multiple_minima.sampler_states[2].positions, np.array([[0.1, 0.0, 0.0]]), atol=1e-2, ) +@pytest.mark.skip( + reason="Multistate code still needs to be modified in the multistage branch" +) def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): """ Test function for running the multistate sampler. @@ -221,12 +226,13 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states - assert ho_sampler.iteration == n_iteratinos - assert ho_sampler._iteration == n_iteratinos + assert ho_sampler.iteration == n_iterations + assert ho_sampler._iteration == n_iterations assert ho_sampler.n_replicas == 4 assert ho_sampler.n_states == 4 u_kn = ho_sampler._reporter.get_property("u_kn") + assert u_kn.shape == (n_iteratinos, 4, 4) # check that the free energies are correct print(ho_sampler.analytical_f_i) diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index fb2bf2c..60df5ad 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -2,7 +2,7 @@ import pytest from chiron.neighbors import ( NeighborListNsqrd, - PairList, + PairListNsqrd, OrthogonalPeriodicSpace, OrthogonalNonperiodicSpace, ) @@ -13,78 +13,47 @@ def test_orthogonal_periodic_displacement(): # test that the incorrect box shapes throw an exception - with pytest.raises(ValueError): - space = OrthogonalPeriodicSpace(jnp.array([10.0, 10.0, 10.0])) - # test that incorrect units throw an exception - with pytest.raises(ValueError): - space = OrthogonalPeriodicSpace( - unit.Quantity( - jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]), - unit.radians, - ) - ) - - space = OrthogonalPeriodicSpace( - jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - ) - # test that the box vectors are set correctly - assert jnp.all( - space.box_vectors - == jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - ) - - # test that the box lengths for an orthogonal box are set correctly - assert jnp.all(space._box_lengths == jnp.array([10.0, 10.0, 10.0])) + space = OrthogonalPeriodicSpace() + box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) # test calculation of the displacement_vector and distance between two points p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]]) - r_ij, distance = space.displacement(p1, p2) + r_ij, distance = space.displacement(p1, p2, box_vectors) assert jnp.all(r_ij == jnp.array([[-1.0, 0.0, 0.0], [4.0, 0.0, 0.0]])) assert jnp.all(distance == jnp.array([1, 4])) # test that the periodic wrapping works as expected - wrapped_x = space.wrap(jnp.array([11, 0, 0])) + wrapped_x = space.wrap(jnp.array([11, 0, 0]), box_vectors) assert jnp.all(wrapped_x == jnp.array([1, 0, 0])) - wrapped_x = space.wrap(jnp.array([-1, 0, 0])) + wrapped_x = space.wrap(jnp.array([-1, 0, 0]), box_vectors) assert jnp.all(wrapped_x == jnp.array([9, 0, 0])) - wrapped_x = space.wrap(jnp.array([5, 0, 0])) + wrapped_x = space.wrap(jnp.array([5, 0, 0]), box_vectors) assert jnp.all(wrapped_x == jnp.array([5, 0, 0])) - wrapped_x = space.wrap(jnp.array([5, 12, -1])) + wrapped_x = space.wrap(jnp.array([5, 12, -1]), box_vectors) assert jnp.all(wrapped_x == jnp.array([5, 2, 9])) - # test the setter for the box vectors - space.box_vectors = jnp.array( - [[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]] - ) - assert jnp.all( - space._box_vectors - == jnp.array([[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]]) - ) - assert jnp.all(space._box_lengths == jnp.array([10.0, 20.0, 30.0])) - def test_orthogonal_nonperiodic_displacement(): - space = OrthogonalNonperiodicSpace( - jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - ) + space = OrthogonalNonperiodicSpace() + box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]]) - r_ij, distance = space.displacement(p1, p2) + r_ij, distance = space.displacement(p1, p2, box_vectors) assert jnp.all(r_ij == jnp.array([[-1.0, 0.0, 0.0], [-6.0, 0.0, 0.0]])) assert jnp.all(distance == jnp.array([1, 6])) - wrapped_x = space.wrap(jnp.array([11, -1, 2])) + wrapped_x = space.wrap(jnp.array([11, -1, 2]), box_vectors) assert jnp.all(wrapped_x == jnp.array([11, -1, 2])) @@ -100,34 +69,33 @@ def test_neighborlist_pair(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() - cutoff = 1.1 - skin = 0.1 + cutoff = 1.1 * unit.nanometer + skin = 0.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) assert nbr_list.cutoff == cutoff assert nbr_list.skin == skin - assert nbr_list.cutoff_and_skin == cutoff + skin assert nbr_list.n_max_neighbors == 5 nbr_list.build_from_state(state) - assert jnp.all(nbr_list.ref_coordinates == coordinates) + assert jnp.all(nbr_list.ref_positions == coordinates) assert jnp.all(nbr_list.box_vectors == box_vectors) assert nbr_list.is_built == True - nbr_list.build(state.x0, state.box_vectors) + nbr_list.build(state.positions, state.box_vectors) - assert jnp.all(nbr_list.ref_coordinates == coordinates) + assert jnp.all(nbr_list.ref_positions == coordinates) assert jnp.all(nbr_list.box_vectors == box_vectors) assert nbr_list.is_built == True @@ -195,12 +163,12 @@ def test_neighborlist_pair(): def test_inputs(): space = OrthogonalPeriodicSpace() # every particle should interact with every other particle - cutoff = 2.1 - skin = 0.1 + cutoff = 2.1 * unit.nanometer + skin = 0.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) # check that the state is of the correct type @@ -213,7 +181,7 @@ def test_inputs(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=None, ) @@ -247,24 +215,24 @@ def test_inputs(): with pytest.raises(TypeError): NeighborListNsqrd( 123, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) # check units of cutoff with pytest.raises(ValueError): NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.radian), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=unit.Quantity(123, unit.radian), + skin=unit.Quantity(123, unit.nanometer), n_max_neighbors=5, ) # check units of skin with pytest.raises(ValueError): NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.radian), + cutoff=unit.Quantity(123, unit.nanometer), + skin=unit.Quantity(123, unit.radian), n_max_neighbors=5, ) @@ -287,19 +255,19 @@ def test_neighborlist_pair_multiple_particles(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() # every particle should interact with every other particle - cutoff = 2.1 - skin = 0.1 + cutoff = 2.1 * unit.nanometer + skin = 0.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) nbr_list.build_from_state(state) @@ -310,12 +278,12 @@ def test_neighborlist_pair_multiple_particles(): assert jnp.all(n_interacting == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) # every particle should be in the nieghbor list, but only a subset in the interacting range - cutoff = 1.1 - skin = 1.1 + cutoff = 1.1 * unit.nanometer + skin = 1.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) nbr_list.build_from_state(state) @@ -342,7 +310,7 @@ def test_neighborlist_pair_multiple_particles(): ) ) # test passing coordinates and box vectors directly - nbr_list.build(state.x0, state.box_vectors) + nbr_list.build(state.positions, state.box_vectors) assert jnp.all(nbr_list.n_neighbors == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) @@ -362,17 +330,16 @@ def test_pairlist_pair(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() - cutoff = 1.1 - skin = 0.1 - pair_list = PairList( + cutoff = 1.1 * unit.nanometer + pair_list = PairListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), + cutoff=cutoff, ) assert pair_list.cutoff == cutoff @@ -382,7 +349,7 @@ def test_pairlist_pair(): assert jnp.all(pair_list.reduction_mask == jnp.array([[True], [False]])) assert pair_list.is_built == True - n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(coordinates) + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) assert jnp.all(n_pairs == jnp.array([1, 0])) assert jnp.all(all_pairs.shape == (2, 1)) @@ -394,10 +361,49 @@ def test_pairlist_pair(): assert pair_list.check(coordinates) == False - coordinates = coordinates = jnp.array([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) + coordinates = jnp.array([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) # we changed number of particles, and thus should rebuild assert pair_list.check(coordinates) == True + # test without using a cutoff + # this will be exactly the same as with a cutoff, given it is just two particles + cutoff = None + pair_list = PairListNsqrd( + space, + cutoff=None, + ) + pair_list.build_from_state(state) + + assert pair_list.cutoff == cutoff + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) + assert jnp.all(n_pairs == jnp.array([1, 0])) + assert jnp.all(all_pairs.shape == (2, 1)) + assert jnp.all(all_pairs == jnp.array([[1], [0]])) + assert jnp.all(mask == jnp.array([[1], [0]])) + assert jnp.all(dist == jnp.array([[1.0], [1.0]])) + assert displacement.shape == (2, 1, 3) + assert jnp.all(displacement == jnp.array([[[-1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0]]])) + + # test the difference between a short cutoff with no interactions and the same + # system with no cutoff. + + # this test ultimately have no particles in the neighbor list + # because the cutoff is really short + cutoff = 0.5 * unit.nanometer + pair_list = PairListNsqrd(space, cutoff=cutoff) + + assert pair_list.cutoff == cutoff + pair_list.build_from_state(state) + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) + # the mask will all be false because the cutoff is too short + assert jnp.all(mask == jnp.array([[0], [0]])) + + # set the cutoff to None, and calculate all pairs in the box + pair_list.cutoff = None + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) + # the mask will have the single pair in the box be true + assert jnp.all(mask == jnp.array([[1], [0]])) + def test_pair_list_multiple_particles(): # test the pair list for multiple particles @@ -416,18 +422,18 @@ def test_pair_list_multiple_particles(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() # every particle should interact with every other particle - cutoff = 2.1 - skin = 0.1 - pair_list = PairList( + cutoff = 2.1 * unit.nanometer + skin = 0.1 * unit.nanometer + pair_list = PairListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), + cutoff=cutoff, ) pair_list.build_from_state(state) @@ -454,8 +460,8 @@ def test_pair_list_multiple_particles(): # compare to nbr_list nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=20, ) nbr_list.build_from_state(state) diff --git a/chiron/tests/test_potential.py b/chiron/tests/test_potential.py index 230fa1c..3013df1 100644 --- a/chiron/tests/test_potential.py +++ b/chiron/tests/test_potential.py @@ -178,7 +178,7 @@ def test_lennard_jones(): positions = jnp.array([[0, 0, 0], [i * 0.25 * 2 ** (1 / 6), 0, 0]]) state = SamplerState( - x0=unit.Quantity(positions, unit.nanometer), + positions=unit.Quantity(positions, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index 4499f18..e437fad 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -31,14 +31,14 @@ def test_initialize_state(): sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) assert jnp.allclose( - sampler_state.x0, + sampler_state.positions, jnp.array([[0.0, 0.0, 0.0]]), ) def test_sampler_state_conversion(): """Test converting a sampler state to jnp arrays. - Note, testing the conversion of x0, where internal unit length is nanometers + Note, testing the conversion of positions, where internal unit length is nanometers and thus output jnp.arrays (with units dropped) should reflect this. """ from chiron.states import SamplerState @@ -54,7 +54,7 @@ def test_sampler_state_conversion(): ) assert jnp.allclose( - sampler_state.x0, + sampler_state.positions, jnp.array([[10.0, 10.0, 10.0]]), ) @@ -64,7 +64,7 @@ def test_sampler_state_conversion(): ) assert jnp.allclose( - sampler_state.x0, + sampler_state.positions, jnp.array([[1.0, 1.0, 1.0]]), ) @@ -81,11 +81,11 @@ def test_sampler_state_inputs(): # test input of positions # should have units with pytest.raises(TypeError): - SamplerState(x0=jnp.array([1, 2, 3])) + SamplerState(positions=jnp.array([1, 2, 3])) # throw and error because of incompatible units with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians), current_PRNG_key=PRNG.get_random_key(), ) @@ -93,14 +93,14 @@ def test_sampler_state_inputs(): # velocities should have units with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), velocities=jnp.array([1, 2, 3]), ) # velocities should have units of distance/time with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), velocities=unit.Quantity(jnp.array([1, 2, 3]), unit.nanometers), ) @@ -109,14 +109,14 @@ def test_sampler_state_inputs(): # box_vectors should have units with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([1, 2, 3]), ) # box_vectors should have units of distance with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), unit.radians @@ -125,7 +125,7 @@ def test_sampler_state_inputs(): # check to see that the size of the box vectors are correct with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0]]), unit.nanometers @@ -140,7 +140,7 @@ def test_sampler_state_inputs(): # check openmm_box conversion: state = SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=openmm_box, ) @@ -155,7 +155,7 @@ def test_sampler_state_inputs(): # openmm box vectors end up as a list with contents; check to make sure we capture an error if we pass a bad list with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=[123], ) diff --git a/chiron/tests/test_testsystems.py b/chiron/tests/test_testsystems.py index 3506b77..5b16ef8 100644 --- a/chiron/tests/test_testsystems.py +++ b/chiron/tests/test_testsystems.py @@ -1,3 +1,13 @@ +import pytest + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmpdir_factory): + """Create a temporary directory for the test.""" + tmpdir = tmpdir_factory.mktemp("test_testsystems") + return tmpdir + + def compute_openmm_reference_energy(testsystem, positions): from openmm import unit from openmm.app import Simulation @@ -186,7 +196,7 @@ def test_LJ_fluid(): PRNG.set_seed(1234) state = SamplerState( - x0=lj_openmm.positions, + positions=lj_openmm.positions, current_PRNG_key=PRNG.get_random_key(), box_vectors=lj_openmm.system.getDefaultPeriodicBoxVectors(), ) @@ -200,7 +210,7 @@ def test_LJ_fluid(): lj_openmm.topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff ) - e_chiron_energy = lj_chiron.compute_energy(state.x0, nbr_list) + e_chiron_energy = lj_chiron.compute_energy(state.positions, nbr_list) e_openmm_energy = compute_openmm_reference_energy( lj_openmm, lj_openmm.positions ) diff --git a/chiron/tests/test_utils.py b/chiron/tests/test_utils.py index 59f6d10..7ce4474 100644 --- a/chiron/tests/test_utils.py +++ b/chiron/tests/test_utils.py @@ -64,11 +64,14 @@ def test_reporter(prep_temp_dir, ho_multistate_sampler_multiple_ks): reporter = LangevinDynamicsReporter("langevin_test") reporter.reset_reporter_file() - integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) + integrator = LangevinIntegrator( + reporter=reporter, + report_interval=1, + ) integrator.run( sampler_state, thermodynamic_state, - n_steps=20, + number_of_steps=20, ) import numpy as np diff --git a/chiron/utils.py b/chiron/utils.py index a9a7da0..41daf33 100644 --- a/chiron/utils.py +++ b/chiron/utils.py @@ -11,20 +11,21 @@ def __init__(self) -> None: """ A PRNG class that can be used to generate random numbers in JAX. The intended use case is to initialize new PRN streams in the `SamplerState` class. - + Example: -------- from chiron.utils import PRNG from chiron.states import SamplerState from openmmtools.testsystems import HarmonicOscillator - + ho = HarmonicOscillator() PRNG.set_seed(1234) sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] - + """ - + pass + @classmethod def set_seed(cls, seed: int) -> None: cls._seed = seed @@ -37,6 +38,25 @@ def get_random_key(cls) -> int: return subkey +def get_full_path(relative_path: str) -> str: + """Get the fill path of a file that is defined relative to the chiron module root directory. + + Parameters + ---------- + relative_path : str + The relative path of the file. + + Returns + ------- + str + The full path of the file. + """ + from importlib.resources import files + + _MODULE_ROOT = files("chiron") + return f"{_MODULE_ROOT}/../{relative_path}" + + def get_data_file_path(relative_path: str) -> str: """Get the full path to one of the reference files in testsystems. In the source distribution, these files are in ``chiron/data/``, @@ -85,9 +105,40 @@ def get_nr_of_particles(topology: Topology) -> int: def get_list_of_mass(topology: Topology) -> unit.Quantity: """Get the mass of the system from the topology.""" - from simtk import unit + from openmm import unit mass = [] for atom in topology.atoms(): mass.append(atom.element.mass.value_in_unit(unit.amu)) return mass * unit.amu + + +def initialize_velocities( + temperature: unit.Quantity, topology: Topology, key +) -> unit.Quantity: + """Initialize the velocities from the Maxwell-Boltzmann distribution at the given temperature. + + Parameters + ---------- + temperature : unit.Quantity + The temperature of the system. + topology : Topology + The topology of the system. + key : int + The PRNG key. + + """ + from openmm import unit + import jax.numpy as jnp + + mass = get_list_of_mass(topology) + + kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA + + kbT_unitless = (kB * temperature).value_in_unit_system(unit.md_unit_system) + mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[:, None] + sigma_v = jnp.sqrt(kbT_unitless / mass_unitless) + + v0 = sigma_v * random.normal(key, [len(mass), 3]) + + return v0 * unit.nanometer / unit.picosecond From d38ed6d98bab2ac26ad8c696843bd6c226208e8b Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 29 Feb 2024 15:27:00 -0800 Subject: [PATCH 49/55] Modified multistate sampler to accept a list of Pair/Neighbor lists (one for each sampler state). Changed test system (HOA) to use nonperiodic space and the pair list; modified neighbor/pairlist classes to not fail if box_vectors = None (errors will be thrown in Space class if box vectors are needed). --- chiron/multistate.py | 51 +++++++++++++----- chiron/neighbors.py | 92 +++++++++++++++++++++------------ chiron/tests/test_multistate.py | 31 ++++++----- chiron/tests/test_pairs.py | 4 +- 4 files changed, 117 insertions(+), 61 deletions(-) diff --git a/chiron/multistate.py b/chiron/multistate.py index e67d7bb..89c9f8f 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,6 +1,6 @@ from typing import List, Optional, Union from chiron.states import SamplerState, ThermodynamicState -from chiron.neighbors import NeighborListNsqrd +from chiron.neighbors import PairsBase from openmm import unit import numpy as np from chiron.mcmc import MCMCMove, MCMCSampler @@ -36,7 +36,7 @@ class MultiStateSampler: Methods ------- - create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd) + create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_lists: List[PairsBase]) Creates a new multistate sampler simulation. minimize(tolerance: unit.Quantity = 1.0 * unit.kilojoules_per_mole / unit.nanometers, max_iterations: int = 1000) Minimizes all replicas in the sampler. @@ -75,6 +75,7 @@ def __init__( self._neighborhoods = None self._n_accepted_matrix = None self._n_proposed_matrix = None + self._nbr_lists = None self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None @@ -162,7 +163,7 @@ def is_periodic(self): """ if self._sampler_states is None: return None - return self._thermodynamic_states[0].is_periodic + return self.is_periodic @property def is_completed(self): @@ -187,9 +188,10 @@ def _compute_replica_energies(self, replica_id: int) -> np.ndarray: # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] + nbr_list = self._sampler_states[replica_id] # Compute energy for all thermodynamic states. energies = calculate_reduced_potential_at_states( - sampler_state, self._thermodynamic_states, self.nbr_list + sampler_state, self._thermodynamic_states, nbr_list ) return energies @@ -197,7 +199,7 @@ def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], - nbr_list: NeighborListNsqrd, + nbr_lists: List[PairsBase], ): """ Create a new multistate sampler simulation. @@ -208,8 +210,8 @@ def create( List of ThermodynamicStates to simulate, with one replica per state. sampler_states : List[SamplerState] List of initial SamplerStates. The number of states is the number of replicas. - nbr_list : NeighborListNsqrd - Neighbor list object for the simulation. + nbr_lists : List[PairsBase] + A list of objects used to efficiently calculate interacting pairs for each sampler state. Raises ------ @@ -227,14 +229,14 @@ def create( "Number of thermodynamic states and sampler states must be equal." ) - self.nbr_list = nbr_list - self._allocate_variables(thermodynamic_states, sampler_states) + self._allocate_variables(thermodynamic_states, sampler_states, nbr_lists) self._reporter = MultistateReporter() def _allocate_variables( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], + nbr_lists: List[PairsBase], ) -> None: """ Allocate and initialize internal variables for the sampler. @@ -245,6 +247,8 @@ def _allocate_variables( A list of ThermodynamicState objects to be used in the sampler. sampler_states : List[SamplerState] A list of SamplerState objects for initializing the sampler. + nbr_lists : List[PairsBase] + A list of objects used to efficiently calculate interacting pairs for each sampler state. Raises ------ @@ -255,8 +259,16 @@ def _allocate_variables( import numpy as np self._thermodynamic_states = copy.deepcopy(thermodynamic_states) - self._sampler_states = sampler_states + self._sampler_states = copy.deepcopy(sampler_states) + self._nbr_lists = copy.deepcopy(nbr_lists) + assert len(self._thermodynamic_states) == len(self._sampler_states) + assert len(self._thermodynamic_states) == len(self._nbr_lists) + + # initial build of neighborlists + for nbr_list, state in zip(self._nbr_lists, self._sampler_states): + nbr_list.build(state.positions, state.box_vectors) + self._replica_thermodynamic_states = np.arange( len(thermodynamic_states), dtype=int ) @@ -325,13 +337,23 @@ def _minimize_replica( minimized_state = minimize_energy( sampler_state.positions, thermodynamic_state.potential.compute_energy, - self.nbr_list, + self._nbr_lists[replica_id], maxiter=max_iterations, ) # Update the sampler state self._sampler_states[replica_id].positions = minimized_state.params + # it is not likely that we would need to rebuild after minimization + # but we should make sure check to make sure + if self._nbr_lists[replica_id].check( + self._sampler_states[replica_id].positions + ): + self._nbr_lists[replica_id].build( + self._sampler_states[replica_id].positions, + self._sampler_states[replica_id].box_vectors, + ) + # Compute and log final energy final_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug( @@ -395,6 +417,7 @@ def _propagate_replica(self, replica_id: int): thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + nbr_list = self._nbr_lists[replica_id] mcmc_sampler = self._mcmc_sampler[thermodynamic_state_id] # Propagate using the mcmc sampler @@ -402,8 +425,10 @@ def _propagate_replica(self, replica_id: int): ( self._sampler_states[replica_id], self._thermodynamic_states[thermodynamic_state_id], - nbr_list, - ) = mcmc_sampler.run(sampler_state, thermodynamic_state) + self._nbr_lists[replica_id], + ) = mcmc_sampler.run( + sampler_state, thermodynamic_state, self.number_of_iterations, nbr_list + ) # Append the new state to the trajectory for analysis. self._traj[replica_id].append(self._sampler_states[replica_id].positions) diff --git a/chiron/neighbors.py b/chiron/neighbors.py index be80cf0..ead53df 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -68,6 +68,9 @@ def displacement( # calculate uncorrected r_ij r_ij = xyz_1 - xyz_2 + if box_vectors is None: + raise ValueError("box_vectors must be provided for a periodic system") + box_lengths = jnp.array( [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] ) @@ -97,6 +100,9 @@ def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: Wrapped positions of the system """ + if box_vectors is None: + raise ValueError("box_vectors must be provided for a periodic system") + box_lengths = jnp.array( [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] ) @@ -106,16 +112,16 @@ def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: return xyz -class OrthogonalNonperiodicSpace(Space): +class OrthogonalNonPeriodicSpace(Space): @partial(jax.jit, static_argnums=(0,)) def displacement( self, xyz_1: jnp.array, xyz_2: jnp.array, - box_vectors: jnp.array, + box_vectors: Optional[jnp.array] = None, ) -> Tuple[jnp.array, jnp.array]: """ - Calculate the periodic distance between two points. + Calculate the distance between two points in a non-periodic system. Parameters ---------- @@ -123,8 +129,9 @@ def displacement( Positions of the first point xyz_2: jnp.array Positions of the second point - box_vectors: jnp.array + box_vectors: Optional[jnp.array]=None Box vectors for the system. + These are not needed for a non-periodic system, but are included for consistent API usage. Returns ------- @@ -134,7 +141,7 @@ def displacement( Distance between the two points """ - # calculate uncorrect r_ij + # calculate r_ij r_ij = xyz_1 - xyz_2 # calculate the scalar distance @@ -143,17 +150,21 @@ def displacement( return r_ij, dist @partial(jax.jit, static_argnums=(0,)) - def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: + def wrap( + self, xyz: jnp.array, box_vectors: Optional[jnp.array] = None + ) -> jnp.array: """ - Wrap the positions of the system. - For the Non-periodic system, this does not alter the positions + Wrap the positions of the system inside the box. + For the non-periodic system, this does not alter the positions. Parameters ---------- xyz: jnp.array Positions of the system - box_vectors: jnp.array - Box vectors for the system + box_vectors: Optional[jnp.array]=None + Box vectors for the system. + These are not needed for a non-periodic system, but are included for consistent API usage. + Returns ------- @@ -226,7 +237,7 @@ def __init__( def build( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Build list from an array of positions and array of box vectors. @@ -236,9 +247,10 @@ def build( positions: jnp.array or unit.Quantity Shape[n_particles,3] array of particle positions, either with or without units attached. If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. - box_vectors: jnp.array or unit.Quantity + box_vectors: jnp.array or unit.Quantity or None Shape[3,3] array of box vectors for the system, either with or without units attached. If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + If None, the system is assumed to be non-periodic and the Space class must reflect this. Returns ------- @@ -250,7 +262,7 @@ def build( def _validate_build_inputs( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Validate the inputs to the build function. @@ -292,6 +304,8 @@ def _validate_build_inputs( f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" ) self.box_vectors = box_vectors + if box_vectors is None: + self.box_vectors = None def build_from_state(self, sampler_state: SamplerState): """ @@ -310,8 +324,8 @@ def build_from_state(self, sampler_state: SamplerState): raise TypeError(f"Expected SamplerState, got {type(sampler_state)} instead") positions = sampler_state.positions - if sampler_state.box_vectors is None: - raise ValueError(f"SamplerState does not contain box vectors") + # if sampler_state.box_vectors is None: + # raise ValueError(f"SamplerState does not contain box vectors") box_vectors = sampler_state.box_vectors self.build(positions, box_vectors) @@ -557,8 +571,9 @@ def _build_neighborlist( Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations cutoff_and_skin: float Cutoff distance for the neighborlist plus the skin distance, in nanometers. - box_vectors: jnp.array - Box vectors for the system + box_vectors: Union[jnp.array, None] + Box vectors for the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -613,7 +628,7 @@ def _build_neighborlist( def build( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Build the neighbor list from an array of positions and box vectors. @@ -622,7 +637,7 @@ def build( ---------- positions: jnp.array Shape[N,3] array of particle positions - box_vectors: jnp.array + box_vectors: Union[jnp.array, None] Shape[3,3] array of box vectors Returns @@ -647,10 +662,11 @@ def build( ) box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) - if box_vectors.shape != (3, 3): - raise ValueError( - f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" - ) + if isinstance(box_vectors, jnp.ndarray): + if box_vectors.shape != (3, 3): + raise ValueError( + f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" + ) self.ref_positions = positions self.box_vectors = box_vectors @@ -714,7 +730,13 @@ def build( @partial(jax.jit, static_argnums=(0,)) def _calc_distance_per_particle( - self, particle1, neighbors, neighbor_mask, positions, cutoff, box_vectors + self, + particle1: int, + neighbors: jnp.array, + neighbor_mask: jnp.array, + positions: jnp.array, + cutoff: float, + box_vectors: Union[jnp.array, None], ): """ Jitted function to calculate the distance between a particle and its neighbors @@ -731,8 +753,9 @@ def _calc_distance_per_particle( X,Y,Z positions of all particles cutoff: float Cutoff distance for the neighborlist, in nanometers - box_vectors: jnp.array - Box vectors for the system + box_vectors: Union[jnp.array, None] + Box vectors for the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -1046,7 +1069,7 @@ def _remove_self_interactions(self, particles, temp_mask): def build( self, positions: Union[jnp.array, unit.Quantity], - box_vectors: Union[jnp.array, unit.Quantity], + box_vectors: Union[jnp.array, unit.Quantity, None], ): """ Build the list from an array of positions and box vectors. @@ -1055,8 +1078,9 @@ def build( ---------- positions: jnp.array Shape[n_particles,3] array of particle positions - box_vectors: jnp.array - Shape[3,3] array of box vectors + box_vectors: jnp.array or unit.Quantity, or None + Shape[3,3] array of box vectors, with or without units. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -1098,8 +1122,9 @@ def _calc_distance_per_particle_with_cutoff( X,Y,Z positions of all particles, shaped (n_particles, 3) cutoff: float Cutoff distance for the interaction. - box_vectors: jnp.array - Box vectors for the system + box_vectors: Union[jnp.array, None] + Box vectors for the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- @@ -1152,8 +1177,9 @@ def _calc_distance_per_particle_no_cutoff( Mask to exclude double particles to prevent double counting positions: jnp.array X,Y,Z positions of all particles, shaped (n_particles, 3) - box_vectors: jnp.array - Box vectors of the system + box_vectors: Union[jnp.array, None] + Box vectors of the system. + If None, the system is assumed to be non-periodic and the Space class must be compatible with this. Returns ------- diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 533d879..c850efb 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -1,29 +1,27 @@ +import copy + from chiron.multistate import MultiStateSampler -from chiron.neighbors import NeighborListNsqrd +from chiron.neighbors import PairListNsqrd import pytest from typing import Tuple -def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: +def setup_sampler() -> Tuple[PairListNsqrd, MultiStateSampler]: """ - Set up the neighbor list and multistate sampler for the simulation. + Set up the pair list and multistate sampler for the simulation. Returns: - Tuple: A tuple containing the neighbor list and multistate sampler objects. + Tuple: A tuple containing the pair list and multistate sampler objects. """ from openmm import unit from chiron.mcmc import LangevinDynamicsMove - from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + from chiron.neighbors import PairListNsqrd, OrthogonalNonPeriodicSpace from chiron.reporters import MultistateReporter, BaseReporter from chiron.mcmc import MCMCSampler, MoveSchedule - sigma = 0.34 * unit.nanometer - cutoff = 3.0 * sigma - skin = 0.5 * unit.nanometer + cutoff = 1.0 * unit.nanometer - nbr_list = NeighborListNsqrd( - OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 - ) + nbr_list = PairListNsqrd(OrthogonalNonPeriodicSpace(), cutoff=cutoff) lang_move = LangevinDynamicsMove( timestep=1.0 * unit.femtoseconds, number_of_steps=100 @@ -76,10 +74,14 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] nbr_list, multistate_sampler = setup_sampler() + import copy + + nbr_lists = [copy.deepcopy(nbr_list) for _ in x0s] + multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, - nbr_list=nbr_list, + nbr_lists=nbr_lists, ) return multistate_sampler @@ -135,11 +137,14 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: ) nbr_list, multistate_sampler = setup_sampler() + import copy + + nbr_lists = [copy.deepcopy(nbr_list) for _ in sigmas] multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, - nbr_list=nbr_list, + nbr_lists=nbr_lists, ) multistate_sampler.analytical_f_i = f_i multistate_sampler.delta_f_ij_analytical = f_i - f_i[:, np.newaxis] diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index 60df5ad..f463a43 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -4,7 +4,7 @@ NeighborListNsqrd, PairListNsqrd, OrthogonalPeriodicSpace, - OrthogonalNonperiodicSpace, + OrthogonalNonPeriodicSpace, ) from chiron.states import SamplerState @@ -42,7 +42,7 @@ def test_orthogonal_periodic_displacement(): def test_orthogonal_nonperiodic_displacement(): - space = OrthogonalNonperiodicSpace() + space = OrthogonalNonPeriodicSpace() box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]]) From e5bf78b68d86f2157d883f99f61944b76d4ba0a0 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 29 Feb 2024 20:19:45 -0800 Subject: [PATCH 50/55] Fixed multistate test of harmonic oscillator array. --- chiron/states.py | 11 ++++++++--- chiron/tests/test_multistate.py | 16 ++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/chiron/states.py b/chiron/states.py index 6ebdca1..805d02f 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -329,10 +329,13 @@ def kT_to_kJ_per_mol(self, energy): return energy / self.beta +from chiron.neighbors import PairsBase + + def calculate_reduced_potential_at_states( sampler_state: SamplerState, thermodynamic_states: List[ThermodynamicState], - nbr_list=None, + nbr_list: Optional[PairsBase] = None, ): """ Calculate the reduced potential for a list of thermodynamic states. @@ -343,7 +346,7 @@ def calculate_reduced_potential_at_states( The sampler state for which to compute the reduced potential. thermodynamic_states : list of ThermodynamicState The thermodynamic states for which to compute the reduced potential. - nbr_list : NeighborList or PairListNsqrd, optional + nbr_list : NeighborList or PairListNsqrd, or None, optional Returns ------- list of float @@ -355,6 +358,8 @@ def calculate_reduced_potential_at_states( reduced_potentials = np.zeros(len(thermodynamic_states)) for state_idx, state in enumerate(thermodynamic_states): - reduced_potentials[state_idx] = state.get_reduced_potential(sampler_state) + reduced_potentials[state_idx] = state.get_reduced_potential( + sampler_state, nbr_list + ) log.debug(f"reduced potentials per sampler sate: {reduced_potentials}") return reduced_potentials diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index c850efb..f23854d 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -205,9 +205,9 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa ) -@pytest.mark.skip( - reason="Multistate code still needs to be modified in the multistage branch" -) +# @pytest.mark.skip( +# reason="Multistate code still needs to be modified in the multistage branch" +# ) def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): """ Test function for running the multistate sampler. @@ -227,8 +227,8 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") - n_iteratinos = 25 - ho_sampler.run(n_iteratinos) + n_iterations = 25 + ho_sampler.run(n_iterations) # check that we have the correct number of iterations, replicas and states assert ho_sampler.iteration == n_iterations @@ -238,7 +238,11 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): u_kn = ho_sampler._reporter.get_property("u_kn") - assert u_kn.shape == (n_iteratinos, 4, 4) + # the u_kn array is transposed to be _states, n_replicas, n_iterations + # SHOULD THIS BE TRANSPOSED IN THE REPORTER? I feel safer to have it + # be transposed when used (if we want it in such a form). + # note n_iterations+1 because it logs time = 0 as well + assert u_kn.shape == (4, 4, n_iterations + 1) # check that the free energies are correct print(ho_sampler.analytical_f_i) # [ 0. , -0.28593054, -0.54696467, -0.78709279] From 76707b1f6167c098a0ad50d2d8aa59e438d0b672 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 29 Feb 2024 23:56:29 -0800 Subject: [PATCH 51/55] multistate test passing locally. --- chiron/potential.py | 15 ++++++++++++--- chiron/tests/test_multistate.py | 6 +++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/chiron/potential.py b/chiron/potential.py index 1d2d340..608b44b 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -408,12 +408,21 @@ def __init__( ) # offset potential energy self.topology = topology + from functools import partial + + @partial(jax.jit, static_argnums=(0,)) + def _compute_energy(self, positions: jnp.array, x0: jnp.array, k, U0): + displacement_vectors = positions - x0 + # Use the 3D harmonic oscillator potential to compute the potential energy + potential_energy = 0.5 * k * jnp.sum(displacement_vectors**2) + U0 + return potential_energy + def compute_energy(self, positions: jnp.array, nbr_list=None): # the functional form is given by U(x) = (K/2) * ( (x-positions)^2 + y^2 + z^2 ) + U0 # https://github.com/choderalab/openmmtools/blob/main/openmmtools/testsystems.py#L695 # compute the displacement vectors - displacement_vectors = positions - self.x0 + # displacement_vectors = positions - self.x0 # Uue the 3D harmonic oscillator potential to compute the potential energy - potential_energy = 0.5 * self.k * jnp.sum(displacement_vectors**2) + self.U0 - return potential_energy + # potential_energy = 0.5 * self.k * jnp.sum(displacement_vectors**2) + self.U0 + return self._compute_energy(positions, self.x0, self.k, self.U0) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index f23854d..4294d21 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -19,7 +19,7 @@ def setup_sampler() -> Tuple[PairListNsqrd, MultiStateSampler]: from chiron.reporters import MultistateReporter, BaseReporter from chiron.mcmc import MCMCSampler, MoveSchedule - cutoff = 1.0 * unit.nanometer + cutoff = 10.0 * unit.nanometer nbr_list = PairListNsqrd(OrthogonalNonPeriodicSpace(), cutoff=cutoff) @@ -135,7 +135,7 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: for sigma in sigmas ] ) - + log.info(f"Analytical free energy difference: {f_i}") nbr_list, multistate_sampler = setup_sampler() import copy @@ -227,7 +227,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") - n_iterations = 25 + n_iterations = 20 ho_sampler.run(n_iterations) # check that we have the correct number of iterations, replicas and states From 5522629822fd0aee74fa3129d401eedc391bdd54 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 22 Mar 2024 11:03:04 -0700 Subject: [PATCH 52/55] Working through my comment son the PR. --- chiron/analysis.py | 1 - chiron/multistate.py | 13 +++++++++---- chiron/states.py | 9 +++++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/chiron/analysis.py b/chiron/analysis.py index 47fdb86..3931377 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -56,5 +56,4 @@ def get_free_energy_difference(self): from loguru import logger as log log.debug(self.mbar.f_k[-1]) - self.f_k = self.mbar.f_k return self.mbar_f_k[-1] diff --git a/chiron/multistate.py b/chiron/multistate.py index 89c9f8f..b415fb5 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -77,7 +77,7 @@ def __init__( self._n_proposed_matrix = None self._nbr_lists = None - self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead + self._reporter = reporter # NOTE: reporter needs to be public, API change ahead self._metadata = None self._mcmc_sampler = copy.deepcopy(mcmc_sampler) self._online_estimator = None @@ -163,14 +163,19 @@ def is_periodic(self): """ if self._sampler_states is None: return None - return self.is_periodic + # if we define box vectors the system will be periodic + # I think we will need to check the sampler states to ensure that they all do have box vectors defined + if self._sampler_states[0].box_vectors is not None: + self._is_periodic = True + + return self._is_periodic @property def is_completed(self): """Check if we have reached any of the stop target criteria (read-only)""" return self._is_completed() - def _compute_replica_energies(self, replica_id: int) -> np.ndarray: + def _compute_replica_reduced_potential(self, replica_id: int) -> np.ndarray: """ Compute the energy of a replica across all thermodynamic states. @@ -514,7 +519,7 @@ def _compute_energies(self) -> None: for replica_id in range(self.n_replicas): self._energy_thermodynamic_states[ replica_id, : - ] = self._compute_replica_energies(replica_id) + ] = self._compute_replica_reduced_potential(replica_id) def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: """ diff --git a/chiron/states.py b/chiron/states.py index 805d02f..8ce9894 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -353,13 +353,14 @@ def calculate_reduced_potential_at_states( The reduced potential of the system for each thermodynamic state. """ - import numpy as np + import jax.numpy as jnp + import jax from loguru import logger as log - reduced_potentials = np.zeros(len(thermodynamic_states)) + reduced_potentials = jnp.zeros(len(thermodynamic_states)) for state_idx, state in enumerate(thermodynamic_states): - reduced_potentials[state_idx] = state.get_reduced_potential( - sampler_state, nbr_list + reduced_potentials = reduced_potentials.at[state_idx].set( + state.get_reduced_potential(sampler_state, nbr_list) ) log.debug(f"reduced potentials per sampler sate: {reduced_potentials}") return reduced_potentials From f7620825d5265ac3029f8c7939acf18f4631b6f5 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 22 Mar 2024 11:29:40 -0700 Subject: [PATCH 53/55] Removed transposing of data in reporter for u_kn. Data is now transposed when we call analysis --- chiron/analysis.py | 4 +++ chiron/mcmc.py | 6 ++-- chiron/multistate.py | 49 ++++++++++++++++++++------------- chiron/reporters.py | 14 +++++----- chiron/states.py | 2 +- chiron/tests/test_multistate.py | 14 ++++------ 6 files changed, 51 insertions(+), 38 deletions(-) diff --git a/chiron/analysis.py b/chiron/analysis.py index 3931377..d44e49f 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -28,6 +28,10 @@ def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): from loguru import logger as log log.debug(f"{N_k=}") + u_kn = np.transpose( + u_kn, (2, 1, 0) + ) # shape: n_states, n_replicas, n_iterations + self.mbar = MBAR(u_kn=u_kn, N_k=N_k) @property diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 64d71d2..ea83a0e 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -717,13 +717,13 @@ def _propose( if self.atom_subset is not None and self.atom_subset_mask is None: import jax.numpy as jnp - self.atom_subset_mask = jnp.zeros(current_sampler_state.n_particles) + self.atom_subset_mask = jnp.zeros(current_sampler_state.number_of_particles) for atom in self.atom_subset: self.atom_subset_mask = self.atom_subset_mask.at[atom].set(1) key = current_sampler_state.new_PRNG_key - nr_of_atoms = current_sampler_state.n_particles + nr_of_atoms = current_sampler_state.number_of_particles unitless_displacement_sigma = self.displacement_sigma.value_in_unit_system( unit.md_unit_system @@ -951,7 +951,7 @@ def _propose( import jax.random as jrandom - nr_of_atoms = current_sampler_state.n_particles + nr_of_atoms = current_sampler_state.number_of_particles initial_volume = ( current_sampler_state.box_vectors[0][0] diff --git a/chiron/multistate.py b/chiron/multistate.py index b415fb5..77119a2 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -84,7 +84,7 @@ def __init__( self._offline_estimator = MBAREstimator() @property - def n_states(self) -> int: + def number_of_thermodynamic_states(self) -> int: """ Get the number of thermodynamic states in the sampler. @@ -99,7 +99,7 @@ def n_states(self) -> int: return len(self._thermodynamic_states) @property - def n_replicas(self) -> int: + def number_of_replicas(self) -> int: """ Get the number of replicas in the sampler. @@ -279,23 +279,30 @@ def _allocate_variables( ) # Initialize matrices for tracking acceptance and proposal statistics. - self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64) - self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64) + self._n_accepted_matrix = np.zeros( + [self.number_of_thermodynamic_states, self.number_of_thermodynamic_states], + np.int64, + ) + self._n_proposed_matrix = np.zeros( + [self.number_of_thermodynamic_states, self.number_of_thermodynamic_states], + np.int64, + ) self._energy_thermodynamic_states = np.zeros( - [self.n_replicas, self.n_states], np.float64 + [self.number_of_replicas, self.number_of_thermodynamic_states], np.float64 ) - self._traj = [[] for _ in range(self.n_replicas)] + self._traj = [[] for _ in range(self.number_of_replicas)] # Ensure there is an MCMCSampler for each thermodynamic state. from chiron.mcmc import MCMCSampler if isinstance(self._mcmc_sampler, MCMCSampler): self._mcmc_sampler = [ - copy.deepcopy(self._mcmc_sampler) for _ in range(self.n_states) + copy.deepcopy(self._mcmc_sampler) + for _ in range(self.number_of_thermodynamic_states) ] - elif len(self._mcmc_sampler) != self.n_states: + elif len(self._mcmc_sampler) != self.number_of_thermodynamic_states: raise RuntimeError( - f"The number of MCMCMoves ({len(self._mcmc_sampler)}) and ThermodynamicStates ({self.n_states}) must be the same." + f"The number of MCMCMoves ({len(self._mcmc_sampler)}) and ThermodynamicStates ({self.number_of_thermodynamic_states}) must be the same." ) # Reset iteration counter. @@ -335,7 +342,7 @@ def _minimize_replica( # Compute the initial energy of the system for logging. initial_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug( - f"Replica {replica_id + 1}/{self.n_replicas}: initial energy {initial_energy:8.3f}kT" + f"Replica {replica_id + 1}/{self.number_of_replicas}: initial energy {initial_energy:8.3f}kT" ) # Perform minimization @@ -362,7 +369,7 @@ def _minimize_replica( # Compute and log final energy final_energy = thermodynamic_state.get_reduced_potential(sampler_state) log.debug( - f"Replica {replica_id + 1}/{self.n_replicas}: final energy {final_energy:8.3f}kT" + f"Replica {replica_id + 1}/{self.number_of_replicas}: final energy {final_energy:8.3f}kT" ) def minimize( @@ -393,7 +400,7 @@ def minimize( from loguru import logger as log # Check that simulation has been created. - if self.n_replicas == 0: + if self.number_of_replicas == 0: raise RuntimeError( "Cannot minimize replicas. The simulation must be created first." ) @@ -401,7 +408,7 @@ def minimize( log.debug("Minimizing all replicas...") # Iterate over all replicas and minimize them - for replica_id in range(self.n_replicas): + for replica_id in range(self.number_of_replicas): self._minimize_replica(replica_id, tolerance, max_iterations) def _propagate_replica(self, replica_id: int): @@ -499,7 +506,7 @@ def _propagate_replicas(self) -> None: log.debug("Propagating all replicas...") # Iterate over all replicas and propagate each one. - for replica_id in range(self.n_replicas): + for replica_id in range(self.number_of_replicas): self._propagate_replica(replica_id) def _compute_energies(self) -> None: @@ -513,10 +520,12 @@ def _compute_energies(self) -> None: log.debug("Computing energy matrix for all replicas...") # Initialize the energy matrix and neighborhoods - self._energy_thermodynamic_states = np.zeros((self.n_replicas, self.n_states)) + self._energy_thermodynamic_states = np.zeros( + (self.number_of_replicas, self.number_of_thermodynamic_states) + ) # Calculate and store energies for each replica. - for replica_id in range(self.n_replicas): + for replica_id in range(self.number_of_replicas): self._energy_thermodynamic_states[ replica_id, : ] = self._compute_replica_reduced_potential(replica_id) @@ -614,8 +623,10 @@ def _report_positions(self): log.debug("Reporting positions...") # numpy array with shape (n_replicas, n_atoms, 3) - xyz = np.zeros((self.n_replicas, self._sampler_states[0].positions.shape[0], 3)) - for replica_id in range(self.n_replicas): + xyz = np.zeros( + (self.number_of_replicas, self._sampler_states[0].positions.shape[0], 3) + ) + for replica_id in range(self.number_of_replicas): xyz[replica_id] = self._sampler_states[replica_id].positions return {"positions": xyz} @@ -691,7 +702,7 @@ def _update_analysis(self): # Perform offline free energy estimate if requested if self._offline_estimator: log.debug("Performing offline free energy estimate...") - N_k = [self._iteration] * self.n_states + N_k = [self._iteration] * self.number_of_thermodynamic_states u_kn = self._reporter.get_property("u_kn") self._offline_estimator.initialize( u_kn=u_kn, diff --git a/chiron/reporters.py b/chiron/reporters.py index 156e86a..ba5606c 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -194,13 +194,13 @@ def get_property(self, name: str) -> np.ndarray: log.warning(f"{name} not in HDF5 file") return None - if name == "u_kn": - return np.transpose( - data, (2, 1, 0) - ) # shape: n_states, n_replicas, n_iterations - - else: - return data + # if name == "u_kn": + # return np.transpose( + # data, (2, 1, 0) + # ) # shape: n_states, n_replicas, n_iterations + # + # else: + return data from typing import Optional diff --git a/chiron/states.py b/chiron/states.py index 8ce9894..0e45606 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -94,7 +94,7 @@ def __init__( self._time_unit = unit.picosecond @property - def n_particles(self) -> int: + def number_of_particles(self) -> int: return self._positions.shape[0] @property diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 4294d21..322cab6 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -166,8 +166,8 @@ def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampl """ assert ho_multistate_sampler_multiple_minima._iteration == 0 - assert ho_multistate_sampler_multiple_minima.n_replicas == 3 - assert ho_multistate_sampler_multiple_minima.n_states == 3 + assert ho_multistate_sampler_multiple_minima.number_of_replicas == 3 + assert ho_multistate_sampler_multiple_minima.number_of_thermodynamic_states == 3 assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == ( 3, 3, @@ -233,16 +233,14 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): # check that we have the correct number of iterations, replicas and states assert ho_sampler.iteration == n_iterations assert ho_sampler._iteration == n_iterations - assert ho_sampler.n_replicas == 4 - assert ho_sampler.n_states == 4 + assert ho_sampler.number_of_replicas == 4 + assert ho_sampler.number_of_thermodynamic_states == 4 u_kn = ho_sampler._reporter.get_property("u_kn") - # the u_kn array is transposed to be _states, n_replicas, n_iterations - # SHOULD THIS BE TRANSPOSED IN THE REPORTER? I feel safer to have it - # be transposed when used (if we want it in such a form). + # we no longer transpose the array in the reporter; it is transposed before analysis in mbar # note n_iterations+1 because it logs time = 0 as well - assert u_kn.shape == (4, 4, n_iterations + 1) + assert u_kn.shape == (n_iterations + 1, 4, 4) # check that the free energies are correct print(ho_sampler.analytical_f_i) # [ 0. , -0.28593054, -0.54696467, -0.78709279] From 92c84cfdd4876c6f875ebe0b5f11980c8c8d262e Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 22 Mar 2024 22:28:39 -0700 Subject: [PATCH 54/55] fixed assert in test now that data is not transposed. --- chiron/tests/test_multistate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 322cab6..9408500 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -241,6 +241,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): # we no longer transpose the array in the reporter; it is transposed before analysis in mbar # note n_iterations+1 because it logs time = 0 as well assert u_kn.shape == (n_iterations + 1, 4, 4) + # check that the free energies are correct print(ho_sampler.analytical_f_i) # [ 0. , -0.28593054, -0.54696467, -0.78709279] From bb58931366d35d33477a05f6216a9c94ef97bea4 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Fri, 22 Mar 2024 22:47:51 -0700 Subject: [PATCH 55/55] fixed assert in utils test of reporter now that data is not transposed. --- chiron/tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chiron/tests/test_utils.py b/chiron/tests/test_utils.py index 7ce4474..e36c1f3 100644 --- a/chiron/tests/test_utils.py +++ b/chiron/tests/test_utils.py @@ -133,5 +133,5 @@ def test_reporter(prep_temp_dir, ho_multistate_sampler_multiple_ks): "step", ] u_kn = ho_sampler._reporter.get_property("u_kn") - assert u_kn.shape == (4, 4, 6) + assert u_kn.shape == (6, 4, 4) assert os.path.exists(ho_sampler._reporter.log_file_path)