diff --git a/chiron/integrators.py b/chiron/integrators.py index b25d348..1870867 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -29,6 +29,7 @@ def __init__( collision_rate=1.0 / unit.picoseconds, save_frequency: int = 100, reporter: Optional[SimulationReporter] = None, + save_traj_in_memory: bool = False, ) -> None: """ Initialize the LangevinIntegrator object. @@ -56,7 +57,8 @@ def __init__( log.info(f"Using reporter {reporter} saving to {reporter.filename}") self.reporter = reporter self.save_frequency = save_frequency - + self.save_traj_in_memory = save_traj_in_memory + self.traj = [] self.velocities = None def set_velocities(self, vel: unit.Quantity) -> None: @@ -109,10 +111,10 @@ def run( temperature = thermodynamic_state.temperature x0 = sampler_state.x0 - log.info("Running Langevin dynamics") - log.info(f"n_steps = {n_steps}") - log.info(f"temperature = {temperature}") - log.info(f"Using seed: {key}") + log.debug("Running Langevin dynamics") + log.debug(f"n_steps = {n_steps}") + log.debug(f"temperature = {temperature}") + log.debug(f"Using seed: {key}") kbT_unitless = (self.kB * temperature).value_in_unit_system(unit.md_unit_system) mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[ @@ -169,6 +171,9 @@ def run( if step % self.save_frequency == 0: # log.debug(f"Saving at step {step}") # check if reporter is attribute of the class + # log.debug(f"step {step} energy {potential.compute_energy(x, nbr_list)}") + # log.debug(f"step {step} force {F}") + if hasattr(self, "reporter") and self.reporter is not None: d = { "traj": x, @@ -180,6 +185,8 @@ def run( # log.debug(d) self.reporter.report(d) + if self.save_traj_in_memory: + self.traj.append(x) log.debug("Finished running Langevin dynamics") # save the final state of the simulation in the sampler_state object diff --git a/chiron/mcmc.py b/chiron/mcmc.py index c67730e..cbd9979 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -71,6 +71,7 @@ def __init__( simulation_reporter: Optional[SimulationReporter] = None, nr_of_steps=1_000, seed: int = 1234, + save_traj_in_memory: bool = False, ): """ Initialize the LangevinDynamicsMove with a molecular system. @@ -88,13 +89,15 @@ def __init__( self.stepsize = stepsize self.collision_rate = collision_rate self.simulation_reporter = simulation_reporter - + self.save_traj_in_memory = save_traj_in_memory + self.traj = [] from chiron.integrators import LangevinIntegrator self.integrator = LangevinIntegrator( stepsize=self.stepsize, collision_rate=self.collision_rate, reporter=self.simulation_reporter, + save_traj_in_memory=save_traj_in_memory, ) def run( @@ -122,6 +125,9 @@ def run( n_steps=self.nr_of_moves, key=self.key, ) + if self.save_traj_in_memory: + self.traj.append(self.integrator.traj) + self.integrator.traj = [] class MCMove(MCMCMove): diff --git a/chiron/multistate.py b/chiron/multistate.py index 4a071d8..08d36d0 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -1,11 +1,10 @@ import copy -import time from typing import List, Optional from chiron.states import SamplerState, ThermodynamicState import datetime from loguru import logger as log import numpy as np -from openmmtools.utils import time_it, with_timer +from openmmtools.utils import with_timer from chiron.neighbors import NeighborListNsqrd from openmm import unit from chiron.mcmc import MCMCMove @@ -30,8 +29,6 @@ class MultiStateSampler(object): they will be assigned to the correspondent thermodynamic state on creation. If None is provided, Langevin dynamics with 2fm timestep, 5.0/ps collision rate, and 500 steps per iteration will be used. - number_of_iterations : int or infinity, optional, default: 1 - The number of iterations to perform. locality : int > 0, optional, default None If None, the energies at all states will be computed for every replica each iteration. @@ -41,14 +38,18 @@ class MultiStateSampler(object): ---------- n_replicas n_states - iteration mcmc_moves sampler_states metadata is_completed """ - def __init__(self, mcmc_moves=None, number_of_iterations=1, locality=None): + def __init__( + self, + mcmc_moves=None, + locality=None, + online_analysis_interval=5, + ): # These will be set on initialization. See function # create() for explanation of single variables. self._thermodynamic_states = None @@ -57,13 +58,17 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1, locality=None): self._replica_thermodynamic_states = None self._iteration = None self._energy_thermodynamic_states = None + self._energy_thermodynamic_states_for_each_iteration = None self._neighborhoods = None self._energy_unsampled_states = None self._n_accepted_matrix = None self._n_proposed_matrix = None self._reporter = None self._metadata = None + self._online_analysis_interval = online_analysis_interval self._timing_data = dict() + self.free_energy_estimator = None + self._traj = None # Handling default propagator. if mcmc_moves is None: @@ -78,10 +83,6 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1, locality=None): else: self._mcmc_moves = copy.deepcopy(mcmc_moves) - # Store constructor parameters. Everything is marked for internal - # usage because any change to these attribute implies a change - # in the storage file as well. Use properties for checks. - self.number_of_iterations = number_of_iterations self._last_mbar_f_k = None self._last_err_free_energy = None @@ -164,52 +165,22 @@ def _compute_replica_energies(self, replica_id: int) -> np.ndarray: import jax.numpy as jnp from chiron.states import calculate_reduced_potential_at_states - log.debug(f"{self._replica_thermodynamic_states=}") - - # Determine neighborhood - state_index = self._replica_thermodynamic_states[replica_id] - neighborhood = self._neighborhood(state_index) - log.debug(f"{neighborhood=}") # Only compute energies of the sampled states over neighborhoods. - energy_neighborhood_states = np.zeros(len(neighborhood)) - neighborhood_thermodynamic_states = [ - self._thermodynamic_states[n] for n in neighborhood + thermodynamic_states = [ + self._thermodynamic_states[n] for n in range(self.n_states) ] - # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] - log.debug(f"{sampler_state=}") # Compute energy for all thermodynamic states. - from openmmtools.states import group_by_compatibility - - for energies, the_states in [ - (energy_neighborhood_states, neighborhood_thermodynamic_states), - ]: - # Group thermodynamic states by compatibility. - compatible_groups, original_indices = group_by_compatibility(the_states) - - # Compute the reduced potentials of all the compatible states. - for compatible_group, state_indices in zip( - compatible_groups, original_indices - ): - # Compute and update the reduced potentials. - compatible_energies = calculate_reduced_potential_at_states( - sampler_state, compatible_group, self.nbr_list - ) - for energy_idx, state_idx in enumerate(state_indices): - energies[state_idx] = compatible_energies[energy_idx] - - # Return the new energies. - log.info(f"Computed energies for replica {replica_id}") - log.info(f"{energy_neighborhood_states=}") - return energy_neighborhood_states + return calculate_reduced_potential_at_states( + sampler_state, thermodynamic_states, self.nbr_list + ) def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd, - reporter: MultiStateReporter, metadata: Optional[dict] = None, ): """Create new multistate sampler simulation. @@ -221,8 +192,6 @@ def create( of sampler states provided. nbr_list : NeighborListNsqrd Neighbor list object to be used in the simulation. - reporter : MultiStateReporter - Reporter object to record simulation data. metadata : dict, optional Optional simulation metadata to be stored in the file. @@ -233,6 +202,7 @@ def create( """ # TODO: initialize reporter here # TODO: consider unsampled thermodynamic states for reweighting schemes + self.free_energy_estimator = "mbar" # Ensure the number of thermodynamic states matches the number of sampler states if len(thermodynamic_states) != len(sampler_states): @@ -242,31 +212,12 @@ def create( self._allocate_variables(thermodynamic_states, sampler_states) self.nbr_list = nbr_list - self._reporter = reporter - self._reporter.open(mode="a") - - @classmethod - def _default_initial_thermodynamic_states( - cls, - thermodynamic_states: List[ThermodynamicState], - sampler_states: List[SamplerState], - ): - """ - Create the initial_thermodynamic_states obeying the following rules: - - * ``len(thermodynamic_states) == len(sampler_states)``: 1-to-1 distribution - """ - n_thermo = len(thermodynamic_states) - n_sampler = len(sampler_states) - assert n_thermo == n_sampler, "Must have 1-to-1 distribution of states" - initial_thermo_states = np.arange(n_thermo, dtype=int) - return initial_thermo_states + self._reporter = None def _allocate_variables( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], - unsampled_thermodynamic_states: Optional[List[ThermodynamicState]] = None, ) -> None: """ Allocate and initialize internal variables for the sampler. @@ -299,16 +250,10 @@ def _allocate_variables( copy.deepcopy(sampler_state) for sampler_state in sampler_states ] - # Handle default unsampled thermodynamic states. - self._unsampled_states = ( - copy.deepcopy(unsampled_thermodynamic_states) - if unsampled_thermodynamic_states is not None - else [] - ) - + assert len(self._thermodynamic_states) == len(self._sampler_states) # Set initial thermodynamic state indices - initial_thermodynamic_states = self._default_initial_thermodynamic_states( - thermodynamic_states, sampler_states + initial_thermodynamic_states = np.arange( + len(self._thermodynamic_states), dtype=int ) self._replica_thermodynamic_states = np.array( initial_thermodynamic_states, np.int64 @@ -327,11 +272,7 @@ def _allocate_variables( self._energy_thermodynamic_states = np.zeros( [self.n_replicas, self.n_states], np.float64 ) - self._neighborhoods = np.zeros([self.n_replicas, self.n_states], "i1") - self._energy_unsampled_states = np.zeros( - [self.n_replicas, len(self._unsampled_states)], np.float64 - ) - + self._traj = [[] for _ in range(self.n_replicas)] # Ensure there is an MCMCMove for each thermodynamic state. if isinstance(self._mcmc_moves, MCMCMove): self._mcmc_moves = [ @@ -438,101 +379,6 @@ def minimize( for replica_id in range(self.n_replicas): self._minimize_replica(replica_id, tolerance, max_iterations) - def _equilibration_timings(self, timer, iteration: int, n_iterations: int): - iteration_time = timer.stop("Equilibration Iteration") - partial_total_time = timer.partial("Run Equilibration") - time_per_iteration = partial_total_time / iteration - estimated_time_remaining = time_per_iteration * (n_iterations - iteration) - estimated_total_time = time_per_iteration * n_iterations - estimated_finish_time = time.time() + estimated_time_remaining - # TODO: Transmit timing information - - log.info(f"Iteration took {iteration_time:.3f}s.") - if estimated_time_remaining != float("inf"): - log.info( - "Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format( - str(datetime.timedelta(seconds=estimated_time_remaining)), - time.ctime(estimated_finish_time), - str(datetime.timedelta(seconds=estimated_total_time)), - ) - ) - - def equilibrate( - self, n_iterations: int, mcmc_moves: Optional[List[MCMCMove]] = None - ): - """ - Equilibrate all replicas in the sampler. - - This method equilibrates the system by running a specified number of - MCMC iterations. The equilibration uses either the provided MCMC moves - or the default ones set during initialization. - - Parameters - ---------- - n_iterations : int - The number of equilibration iterations to perform. - mcmc_moves : Optional[List[mcmc.MCMCMove]], optional - A list of MCMCMove objects to use for equilibration. If None, the - MCMC moves used in production will be used. Defaults to None. - - Raises - ------ - RuntimeError - If the simulation has not been created before calling this method. - """ - # Check that simulation has been created. - if self.n_replicas == 0: - raise RuntimeError( - "Cannot equilibrate replicas. The simulation must be created first." - ) - - # Use production MCMC moves if none are provided - mcmc_moves = mcmc_moves or self._mcmc_moves - - # Make sure there is one MCMCMove per thermodynamic state. - if isinstance(mcmc_moves, MCMCMove): - mcmc_moves = [copy.deepcopy(mcmc_moves) for _ in range(self.n_states)] - - if len(mcmc_moves) != self.n_states: - raise RuntimeError( - f"The number of MCMCMoves ({len(self._mcmc_moves)}) and ThermodynamicStates ({self.n_states}) for equilibration must be the same." - ) - from openmmtools.utils import Timer - - timer = Timer() - timer.start("Run Equilibration") - - # Temporarily set the equilibration MCMCMoves. - production_mcmc_moves = self._mcmc_moves - self._mcmc_moves = mcmc_moves - - for iteration in range(1, n_iterations + 1): - log.info(f"Equilibration iteration {iteration}/{n_iterations}") - timer.start("Equilibration Iteration") - - # NOTE: Unlike run(), do NOT increment iteration counter. - # self._iteration += 1 - - # Propagate replicas. - self._propagate_replicas() - - # Compute energies of all replicas at all states - self._compute_energies() - - # Update thermodynamic states - self._replica_thermodynamic_states = self._mix_replicas() - - # Computing timing information - self._equilibration_timings( - timer, iteration=iteration, n_iterations=n_iterations - ) - timer.report_timing() - - # Restore production MCMCMoves. - self._mcmc_moves = production_mcmc_moves - - # TODO: Update stored positions. - def _propagate_replica(self, replica_id: int): """ Propagate the state of a single replica. @@ -550,19 +396,14 @@ def _propagate_replica(self, replica_id: int): If an error occurs during the propagation of the replica. """ # Retrieve thermodynamic, sampler states, and MCMC move of this replica. - # Retrieve thermodynamic and sampler states for the replica thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] mcmc_move = self._mcmc_moves[thermodynamic_state_id] - # Apply MCMC move. - try: - mcmc_move.run(sampler_state, thermodynamic_state) - except Exception as e: - log.warning(e) - raise e + mcmc_move.run(sampler_state, thermodynamic_state) + self._traj[replica_id].append(sampler_state.x0) def _perform_swap_proposals(self): """ @@ -615,7 +456,6 @@ def _mix_replicas(self) -> np.ndarray: log.debug( f"Accepted {n_swaps_accepted}/{n_swaps_proposed} attempted swaps ({swap_fraction_accepted * 100.0:.1f}%)" ) - return new_replica_states @with_timer("Propagating all replicas") def _propagate_replicas(self) -> None: @@ -631,38 +471,6 @@ def _propagate_replicas(self) -> None: for replica_id in range(self.n_replicas): self._propagate_replica(replica_id) - def _neighborhood(self, state_index: int) -> List[int]: - """ - Compute the indices of neighboring states for a given state. - - This method determines the neighborhood of states around a given state index, - considering the 'locality' parameter. If 'locality' is None, the neighborhood - includes all states; otherwise, it includes states within the 'locality' range. - - Parameters - ---------- - state_index : int - The index of the state for which the neighborhood is to be calculated. - - Returns - ------- - List[int] - A list of state indices that are considered neighbors of the given state. - """ - if self.locality is None: - # Global neighborhood - return list(range(self.n_states)) - else: - # Local neighborhood specified by 'locality' - lower_bound = max(0, state_index - self.locality) - upper_bound = min(self.n_states, state_index + self.locality + 1) - return list( - range( - lower_bound, - upper_bound, - ) - ) - @with_timer("Computing energy matrix") def _compute_energies(self) -> None: """ @@ -676,23 +484,14 @@ def _compute_energies(self) -> None: log.debug("Computing energy matrix for all replicas...") # Initialize the energy matrix and neighborhoods self._energy_thermodynamic_states = np.zeros((self.n_replicas, self.n_states)) - self._neighborhoods = np.zeros((self.n_replicas, self.n_states), dtype=bool) # Calculate energies for each replica for replica_id in range(self.n_replicas): - neighborhood = self._neighborhood( - self._replica_thermodynamic_states[replica_id] - ) - self._neighborhoods[replica_id, neighborhood] = True - # Compute and store energies for the neighborhood states self._energy_thermodynamic_states[ - replica_id, neighborhood + replica_id, : ] = self._compute_replica_energies(replica_id) - log.debug(self._energy_thermodynamic_states) - log.debug(self._neighborhoods) - def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: """ Determine if the sampling process has met its completion criteria. @@ -744,10 +543,7 @@ def _update_run_progress(self, timer, run_initial_iteration, iteration_limit): f"Estimated completion in {self._timing_data['estimated_time_remaining']}, at {self._timing_data['estimated_localtime_finish_date']} (consuming total wall clock time {self._timing_data['estimated_total_time']})." ) - # Perform sanity checks to see if we should terminate here. - self._check_nan_energy() - - def run(self, n_iterations: Optional[int] = None) -> None: + def run(self, n_iterations: int = 10) -> None: """ Execute the replica-exchange simulation. @@ -757,9 +553,8 @@ def run(self, n_iterations: Optional[int] = None) -> None: Parameters ---------- - n_iterations : Optional[int], default=None - The number of iterations to run. If None, the sampler runs for the - number of iterations specified at initialization. + n_iterations : int, default=10 + The number of iterations to run. Raises ------ @@ -769,34 +564,31 @@ def run(self, n_iterations: Optional[int] = None) -> None: # If this is the first iteration, compute and store the # starting energies of the minimized/equilibrated structures. + self.number_of_iterations = n_iterations log.info("Running simulation...") + self._energy_thermodynamic_states_for_each_iteration_in_run = np.zeros( + [self.n_replicas, self.n_states, n_iterations + 1], np.float64 + ) # Initialize energies if this is the first iteration if self._iteration == 0: self._compute_energies() - - self._reporter.write_energies( - energy_thermodynamic_states=self._energy_thermodynamic_states, - energy_neighborhoods=self._neighborhoods, - energy_unsampled_states=self._energy_unsampled_states, - iteration=self._iteration, - ) + # store energies for mbar analysis + self._energy_thermodynamic_states_for_each_iteration_in_run[ + :, :, self._iteration + ] = self._energy_thermodynamic_states + # TODO report energies from openmmtools.utils import Timer timer = Timer() timer.start("Run ReplicaExchange") - run_initial_iteration = self._iteration - # Determine the number of iterations to run before stopping. - iteration_limit = ( - min(self._iteration + n_iterations, self.number_of_iterations) - if n_iterations is not None - else self.number_of_iterations - ) + iteration_limit = n_iterations - # Main loop. + # start the sampling loop + log.debug(f"{iteration_limit=}") while not self._is_completed(iteration_limit): # Increment iteration counter. self._iteration += 1 @@ -807,7 +599,7 @@ def run(self, n_iterations: Optional[int] = None) -> None: timer.start("Iteration") # Update thermodynamic states - self._replica_thermodynamic_states = self._mix_replicas() + self._mix_replicas() # Propagate replicas. self._propagate_replicas() @@ -815,26 +607,31 @@ def run(self, n_iterations: Optional[int] = None) -> None: # Compute energies of all replicas at all states self._compute_energies() + # Add energies to the energy matrix + self._energy_thermodynamic_states_for_each_iteration_in_run[ + :, :, self._iteration + ] = self._energy_thermodynamic_states # Write iteration to storage file - self._report_iteration() + # TODO + # self._report_iteration() - # TODO: Update analysis - # self._update_analysis() + # Update analysis + self._update_analysis() - # Update timing and progress information - self._update_run_progress( - timer=timer, - run_initial_iteration=run_initial_iteration, - iteration_limit=iteration_limit, - ) - - @with_timer("Writing iteration information to storage") def _report_iteration(self): - """Store positions, states, and energies of current iteration.n""" - # Call report_iteration_items for a subclass-friendly function - self._report_iteration_items() - self._reporter.write_timestamp(self._iteration) - self._reporter.write_last_iteration(self._iteration) + """Store positions, states, and energies of current iteration.""" + + # TODO: write energies + + # TODO: write trajectory + + # TODO: write mixing statistics + self._reporter.write_energies( + self._energy_thermodynamic_states, + self._neighborhoods, + self._energy_unsampled_states, + self._iteration, + ) def _report_iteration_items(self): """ @@ -924,45 +721,36 @@ def flatten(iterator): return flatten(self.mcmc_moves) - def _check_nan_energy(self): - """Checks that energies are finite and abort otherwise. + def _update_analysis(self): + """Update analysis of free energies""" + + if self._online_analysis_interval is None: + log.debug("No online analysis requested") + # Perform no analysis and exit function + return + + # Perform offline free energy estimate if requested + if self.free_energy_estimator == "mbar": + self._last_err_free_energy = self._mbar_analysis() - Checks both sampled and unsampled thermodynamic states. + return + def _mbar_analysis(self): """ - # Find faulty replicas to create error message. - nan_replicas = [] + Perform mbar analysis + """ + from pymbar import MBAR - # Check sampled thermodynamic states first. - state_type = "thermodynamic state" - for replica_id, state_id in enumerate(self._replica_thermodynamic_states): - neighborhood = self._neighborhood(state_id) - energies_neighborhood = self._energy_thermodynamic_states[ - replica_id, neighborhood - ] - if np.any(np.isnan(energies_neighborhood)): - nan_replicas.append((replica_id, energies_neighborhood)) - - # If there are no NaNs in energies, look for NaNs in the unsampled states energies. - if (len(nan_replicas) == 0) and (self._energy_unsampled_states.shape[1] > 0): - state_type = "unsampled thermodynamic state" - for replica_id in range(self.n_replicas): - if np.any(np.isnan(self._energy_unsampled_states[replica_id])): - nan_replicas.append( - (replica_id, self._energy_unsampled_states[replica_id]) - ) - - # Raise exception if we have found some NaN energies. - if len(nan_replicas) > 0: - # Log failed replica, its thermo state, and the energy matrix row. - err_msg = "NaN encountered in {} energies for the following replicas and states".format( - state_type - ) - for replica_id, energy_row in nan_replicas: - err_msg += "\n\tEnergies for positions at replica {} (current state {}): {} kT".format( - replica_id, - self._replica_thermodynamic_states[replica_id], - energy_row, - ) - log.critical(err_msg) - raise RuntimeError(err_msg) + self._last_mbar_f_k_offline = np.zeros(len(self._thermodynamic_states)) + + log.debug( + f"{self._energy_thermodynamic_states_for_each_iteration_in_run.shape=}" + ) + log.debug(f"{self.n_states=}") + u_kn = self._energy_thermodynamic_states_for_each_iteration_in_run + log.debug(f"{self._iteration=}") + N_k = [self._iteration] * self.n_states + log.debug(f"{N_k=}") + mbar = MBAR(u_kn=u_kn, N_k=N_k) + log.debug(mbar.f_k) + self._last_mbar_f_k_offline = mbar.f_k diff --git a/chiron/reporters.py b/chiron/reporters.py index 0ec29b4..b2fae35 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -21,8 +21,6 @@ def __init__(self, filename: str, topology: Topology, buffer_size: int = 1): Number of data points to buffer before writing to disk (default is 1). """ - import mdtraj as md - self.filename = filename self.buffer_size = buffer_size self.topology = topology @@ -52,7 +50,7 @@ def report(self, data_dict): if len(self.buffer[key]) >= self.buffer_size: self._write_to_disk(key) - def _write_to_disk(self, key): + def _write_to_disk(self, key:str): """ Write buffered data of a given key to the HDF5 file. diff --git a/chiron/states.py b/chiron/states.py index 3a71dbf..8ba5e52 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -217,32 +217,6 @@ def _check_completness(self): if self.temperature and self.pressure and self.nr_of_particles: log.info("NpT ensemble is simulated.") - def is_state_compatible(self, thermodynamic_state): - """Check compatibility between ThermodynamicStates. - - Parameters - ---------- - thermodynamic_state : ThermodynamicState - The thermodynamic state to test. - - Returns - ------- - is_compatible : bool - True if the states are compatible, False otherwise. - - Examples - -------- - States in the same ensemble (NVT or NPT) are compatible. - States in different ensembles are not compatible. - States that store different systems (that differ by more than - barostat and thermostat pressure and temperature) are also not - compatible. - """ - - # Check that the states are in the same ensemble. - # TODO: implement this - pass - def get_reduced_potential( self, sampler_state: SamplerState, nbr_list=None ) -> float: @@ -274,14 +248,14 @@ def get_reduced_potential( self.beta = 1.0 / ( unit.BOLTZMANN_CONSTANT_kB * (self.temperature * unit.kelvin) ) - log.debug(f"sample state: {sampler_state.x0}") + # log.debug(f"sample state: {sampler_state.x0}") reduced_potential = ( unit.Quantity( self.potential.compute_energy(sampler_state.x0, nbr_list), unit.kilojoule_per_mole, ) ) / unit.AVOGADRO_CONSTANT_NA - log.debug(f"reduced potential: {reduced_potential}") + # log.debug(f"reduced potential: {reduced_potential}") if self.pressure is not None: reduced_potential += self.pressure * self.volume @@ -294,7 +268,7 @@ def kT_to_kJ_per_mol(self, energy): def calculate_reduced_potential_at_states( sampler_state: SamplerState, - themrodynamic_states: List[ThermodynamicState], + thermodynamic_states: List[ThermodynamicState], nbr_list=None, ): """ @@ -313,7 +287,10 @@ def calculate_reduced_potential_at_states( The reduced potential of the system for each thermodynamic state. """ - reduced_potentials = [] - for state in themrodynamic_states: - reduced_potentials.append(state.get_reduced_potential(sampler_state)) + import numpy as np + + reduced_potentials = np.zeros(len(thermodynamic_states)) + for state_idx, state in enumerate(thermodynamic_states): + reduced_potentials[state_idx] = state.get_reduced_potential(sampler_state) + log.debug(f"reduced potentials per sampler sate: {reduced_potentials}") return reduced_potentials diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 8499e31..ce74e1d 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -1,149 +1,217 @@ from chiron.multistate import MultiStateSampler +from chiron.neighbors import NeighborListNsqrd import pytest +from typing import Tuple -@pytest.fixture -def ho_multistate_sampler() -> MultiStateSampler: +def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: """ - Create a multi-state sampler for a harmonic oscillator system. + Set up the neighbor list and multistate sampler for the simulation. Returns: - MultiStateSampler: The multi-state sampler object. + Tuple: A tuple containing the neighbor list and multistate sampler objects. """ - import math from openmm import unit from chiron.mcmc import LangevinDynamicsMove - from chiron.states import ThermodynamicState, SamplerState - from openmmtools.testsystems import HarmonicOscillator - from chiron.potential import HarmonicOscillatorPotential from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace - ho = HarmonicOscillator() - n_replicas = 3 - T_min = 298.0 * unit.kelvin # Minimum temperature. - T_max = 600.0 * unit.kelvin # Maximum temperature. - temperatures = [ - T_min - + (T_max - T_min) - * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) - / (math.e - 1.0) - for i in range(n_replicas) - ] + sigma = 0.34 * unit.nanometer + cutoff = 3.0 * sigma + skin = 0.5 * unit.nanometer + + nbr_list = NeighborListNsqrd( + OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 + ) + + move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=500) + + multistate_sampler = MultiStateSampler(mcmc_moves=move) + return nbr_list, multistate_sampler + + +@pytest.fixture +def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: + """ + Create a multi-state sampler for multiple harmonic oscillators with different minimum values. + + Returns: + MultiStateSampler: The multi-state sampler object. + """ + from chiron.states import ThermodynamicState, SamplerState + from chiron.potential import HarmonicOscillatorPotential import jax.numpy as jnp + from openmm import unit + n_replicas = 3 + T = 300.0 * unit.kelvin # Minimum temperature. x0s = [ unit.Quantity(jnp.array([[x0, 0.0, 0.0]]), unit.angstrom) for x0 in jnp.linspace(0.0, 1.0, n_replicas) ] + + from openmmtools.testsystems import HarmonicOscillator + + ho = HarmonicOscillator() + thermodynamic_states = [ ThermodynamicState( HarmonicOscillatorPotential(ho.topology, x0=x0), temperature=T ) - for T, x0 in zip(temperatures, x0s) + for x0 in x0s ] - sampler_state = [SamplerState(ho.positions) for _ in temperatures] + sampler_state = [SamplerState(ho.positions) for _ in x0s] + nbr_list, multistate_sampler = setup_sampler() + multistate_sampler.create( + thermodynamic_states=thermodynamic_states, + sampler_states=sampler_state, + nbr_list=nbr_list, + ) - # Initialize simulation object with options. Run with a langevin integrator. - # initialize the LennardJones potential in chiron - # - sigma = 0.34 * unit.nanometer - cutoff = 3.0 * sigma - skin = 0.5 * unit.nanometer + return multistate_sampler - nbr_list = NeighborListNsqrd( - OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 - ) - move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) +@pytest.fixture +def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: + """ + Create a multi-state sampler for a harmonic oscillator system with different spring constants. + Returns + ------- + MultiStateSampler + The multi-state sampler object. + """ + from openmm import unit + from chiron.states import ThermodynamicState, SamplerState + from openmmtools.testsystems import HarmonicOscillator + from chiron.potential import HarmonicOscillatorPotential + + ho = HarmonicOscillator() + n_states = 4 + + T = 300.0 * unit.kelvin # Minimum temperature. + kT = unit.BOLTZMANN_CONSTANT_kB * T * unit.AVOGADRO_CONSTANT_NA + sigmas = [ + unit.Quantity(2.0 + 0.2 * state_index, unit.angstrom) + for state_index in range(n_states) + ] + Ks = [kT / sigma**2 for sigma in sigmas] + + thermodynamic_states = [ + ThermodynamicState(HarmonicOscillatorPotential(ho.topology, k=k), temperature=T) + for k in Ks + ] + from loguru import logger as log + + log.info(f"Initialize harmonic oscillator with {n_states} states and ks {Ks}") + + sampler_state = [SamplerState(ho.positions) for _ in sigmas] + import numpy as np - from openmmtools.multistate import MultiStateReporter + f_i = np.array( + [ + -np.log(2 * np.pi * (sigma / unit.angstroms) ** 2) * (3.0 / 2.0) + for sigma in sigmas + ] + ) - reporter = MultiStateReporter("test.nc") + nbr_list, multistate_sampler = setup_sampler() - multistate_sampler = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) multistate_sampler.create( thermodynamic_states=thermodynamic_states, sampler_states=sampler_state, nbr_list=nbr_list, - reporter=reporter, ) - + multistate_sampler.analytical_f_i = f_i + multistate_sampler.delta_f_ij_analytical = f_i - f_i[:, np.newaxis] return multistate_sampler -def test_multistate_class(ho_multistate_sampler): - # test the multistate_sampler object - assert ho_multistate_sampler.number_of_iterations == 2 - assert ho_multistate_sampler.n_replicas == 3 - assert ho_multistate_sampler.n_states == 3 - assert ho_multistate_sampler._energy_thermodynamic_states.shape == (3, 3) - assert ho_multistate_sampler._n_proposed_matrix.shape == (3, 3) +def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampler): + """ + Test initialization for the MultiStateSampler class. + Parameters: + ------- + ho_multistate_sampler_multiple_minima: MultiStateSampler + An instance of the MultiStateSampler class. + Raises: + ------- + AssertionError: + If any of the assertions fail. -def test_multistate_minimize(ho_multistate_sampler): + """ + assert ho_multistate_sampler_multiple_minima._iteration == 0 + assert ho_multistate_sampler_multiple_minima.n_replicas == 3 + assert ho_multistate_sampler_multiple_minima.n_states == 3 + assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == ( + 3, + 3, + ) + assert ho_multistate_sampler_multiple_minima._n_proposed_matrix.shape == (3, 3) + + +def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSampler): """ Test function for the `minimize` method of the `ho_multistate_sampler` object. - It checks if the sampler states are correctly minimized. + Check if the sampler states are correctly minimized. Parameters ---------- - ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + ho_multistate_sampler: MultiStateSampler """ import numpy as np - ho_multistate_sampler.minimize() + ho_multistate_sampler_multiple_minima.minimize() assert np.allclose( - ho_multistate_sampler.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) + ho_multistate_sampler_multiple_minima.sampler_states[0].x0, + np.array([[0.0, 0.0, 0.0]]), ) assert np.allclose( - ho_multistate_sampler.sampler_states[1].x0, + ho_multistate_sampler_multiple_minima.sampler_states[1].x0, np.array([[0.05, 0.0, 0.0]]), atol=1e-2, ) assert np.allclose( - ho_multistate_sampler.sampler_states[2].x0, + ho_multistate_sampler_multiple_minima.sampler_states[2].x0, np.array([[0.1, 0.0, 0.0]]), atol=1e-2, ) -def test_multistate_equilibration(ho_multistate_sampler): - import numpy as np +def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): + """ + Test function for running the multistate sampler. - ho_multistate_sampler.equilibrate(10) + Parameters + ---------- + ho_multistate_sampler_multiple_ks: MultiStateSampler + The multistate sampler object. + Raises + ------- + AssertionError: If free energy does not converge to the analytical free energy difference. - assert np.allclose( - ho_multistate_sampler._replica_thermodynamic_states, np.array([0, 1, 2]) - ) - assert np.allclose( - ho_multistate_sampler._energy_thermodynamic_states, - np.array( - [ - [4.81132936, 3.84872651, 3.10585403], - [6.54490519, 5.0176239, 3.85019779], - [9.48260307, 7.07196712, 5.21255827], - ] - ), - ) + """ - a = 7 + ho_sampler = ho_multistate_sampler_multiple_ks + import numpy as np + print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") -def test_multistate_run(ho_multistate_sampler): - import numpy as np + n_iteratinos = 25 + ho_sampler.run(n_iteratinos) + + # check that we have the correct number of iterations, replicas and states + assert ho_sampler.iteration == n_iteratinos + assert ho_sampler._iteration == n_iteratinos + assert ho_sampler.n_replicas == 4 + assert ho_sampler.n_states == 4 + + # check that the free energies are correct + print(ho_sampler.analytical_f_i) + print(ho_sampler.delta_f_ij_analytical) + print(ho_sampler._last_mbar_f_k_offline) - ho_multistate_sampler.equilibrate(10) assert np.allclose( - ho_multistate_sampler._energy_thermodynamic_states, - np.array( - [ - [4.81132936, 3.84872651, 3.10585403], - [6.54490519, 5.0176239, 3.85019779], - [9.48260307, 7.07196712, 5.21255827], - ] - ), + ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline, atol=0.1 ) - ho_multistate_sampler.run(10) - diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index c9640a2..2e314a9 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -124,11 +124,12 @@ def test_sampler_state_inputs(): x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), box_vectors=openmm_box, ) - assert jnp.all( + assert jnp.allclose( state.box_vectors == jnp.array( [[4.0311456, 0.0, 0.0], [0.0, 4.0311456, 0.0], [0.0, 0.0, 4.0311456]] - ) + ), + atol=1e-4, ) # openmm box vectors end up as a list with contents; check to make sure we capture an error if we pass a bad list