Skip to content

Commit

Permalink
Add barostat to benchmarks
Browse files Browse the repository at this point in the history
* Its way too slow right now
* Use barostat interval of 0 as flag to disable barostat
  • Loading branch information
badisa committed Jul 19, 2021
1 parent d3d6117 commit e9c528b
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions slow_tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

from simtk.openmm import app

from timemachine.lib import custom_ops
from timemachine.lib import LangevinIntegrator
from timemachine.lib import custom_ops, LangevinIntegrator, MonteCarloBarostat

from fe.utils import to_md_units
from fe import free_energy
from fe.topology import SingleTopology

from md import builders, minimizer
from md.barostat.utils import get_bond_list, get_group_indices

def recenter(conf, box):

Expand Down Expand Up @@ -47,7 +47,8 @@ def benchmark(
num_batches=100,
steps_per_batch=1000,
compute_du_dp_interval=100,
compute_du_dl_interval=0
compute_du_dl_interval=0,
barostat_interval=0,
):
"""
TODO: configuration blob containing num_batches, steps_per_batch, and any other options
Expand All @@ -56,10 +57,12 @@ def benchmark(

seed = 1234
dt = 1.5e-3
temperature =300
pressure = 1.0
seconds_per_day = 86400

intg = LangevinIntegrator(
300,
temperature,
dt,
1.0,
np.array(masses),
Expand All @@ -71,12 +74,28 @@ def benchmark(
for potential in bound_potentials:
bps.append(potential.bound_impl(precision=np.float32)) # get the bound implementation

baro_impl = None
if barostat_interval > 0:
harmonic_bond_potential = bound_potentials[0]
bond_list = get_bond_list(harmonic_bond_potential)
group_idxs = get_group_indices(bond_list)
baro = MonteCarloBarostat(
x0.shape[0],
pressure,
temperature,
group_idxs,
barostat_interval,
seed,
)
baro_impl = baro.impl(bps)

ctxt = custom_ops.Context(
x0,
v0,
box,
intg,
bps
bps,
barostat=baro_impl,
)

# initialize observables
Expand Down Expand Up @@ -156,6 +175,7 @@ def benchmark_dhfr(verbose=False, num_batches=100, steps_per_batch=1000):
v0 = np.zeros_like(host_conf)

benchmark("dhfr-apo", host_masses, 0.0, x0, v0, box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch)
benchmark("dhfr-apo-barostat-interval-25", host_masses, 0.0, x0, v0, box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch, barostat_interval=25)

def benchmark_hif2a(verbose=False, num_batches=100, steps_per_batch=1000):

Expand Down Expand Up @@ -194,6 +214,7 @@ def benchmark_hif2a(verbose=False, num_batches=100, steps_per_batch=1000):

# lamb = 0.0
benchmark(stage+"-apo", host_masses, 0.0, x0, v0, host_box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch)
benchmark(stage+"-apo-barostat-interval-25", host_masses, 0.0, x0, v0, host_box, host_fns, verbose, num_batches=num_batches, steps_per_batch=steps_per_batch, barostat_interval=25)

# RBFE
unbound_potentials, sys_params, masses, coords = rfe.prepare_host_edge(ff_params, host_system, x0)
Expand Down

0 comments on commit e9c528b

Please sign in to comment.