diff --git a/slow_tests/test_benchmark.py b/slow_tests/test_benchmark.py index 58dc77c5d..4cfaeee25 100644 --- a/slow_tests/test_benchmark.py +++ b/slow_tests/test_benchmark.py @@ -11,14 +11,14 @@ from simtk.openmm import app -from timemachine.lib import custom_ops -from timemachine.lib import LangevinIntegrator +from timemachine.lib import custom_ops, LangevinIntegrator, MonteCarloBarostat from fe.utils import to_md_units from fe import free_energy from fe.topology import SingleTopology from md import builders, minimizer +from md.barostat.utils import get_bond_list, get_group_indices def recenter(conf, box): @@ -47,7 +47,8 @@ def benchmark( num_batches=100, steps_per_batch=1000, compute_du_dp_interval=100, - compute_du_dl_interval=0 + compute_du_dl_interval=0, + barostat_interval=0, ): """ TODO: configuration blob containing num_batches, steps_per_batch, and any other options @@ -56,10 +57,12 @@ def benchmark( seed = 1234 dt = 1.5e-3 + temperature =300 + pressure = 1.0 seconds_per_day = 86400 intg = LangevinIntegrator( - 300, + temperature, dt, 1.0, np.array(masses), @@ -71,12 +74,28 @@ def benchmark( for potential in bound_potentials: bps.append(potential.bound_impl(precision=np.float32)) # get the bound implementation + baro_impl = None + if barostat_interval > 0: + harmonic_bond_potential = bound_potentials[0] + bond_list = get_bond_list(harmonic_bond_potential) + group_idxs = get_group_indices(bond_list) + baro = MonteCarloBarostat( + x0.shape[0], + pressure, + temperature, + group_idxs, + barostat_interval, + seed, + ) + baro_impl = baro.impl(bps) + ctxt = custom_ops.Context( x0, v0, box, intg, - bps + bps, + barostat=baro_impl, ) # initialize observables @@ -156,6 +175,7 @@ def benchmark_dhfr(verbose=False, num_batches=100, steps_per_batch=1000): v0 = np.zeros_like(host_conf) benchmark("dhfr-apo", host_masses, 0.0, x0, v0, box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch) + benchmark("dhfr-apo-barostat-interval-25", host_masses, 0.0, x0, v0, box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch, barostat_interval=25) def benchmark_hif2a(verbose=False, num_batches=100, steps_per_batch=1000): @@ -194,6 +214,7 @@ def benchmark_hif2a(verbose=False, num_batches=100, steps_per_batch=1000): # lamb = 0.0 benchmark(stage+"-apo", host_masses, 0.0, x0, v0, host_box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch) + benchmark(stage+"-apo-barostat-interval-25", host_masses, 0.0, x0, v0, host_box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch, barostat_interval=25) # RBFE unbound_potentials, sys_params, masses, coords = rfe.prepare_host_edge(ff_params, host_system, x0)