diff --git a/tests/test_barostat.py b/tests/test_barostat.py index d9a830aaa..c10ce5d71 100644 --- a/tests/test_barostat.py +++ b/tests/test_barostat.py @@ -3,6 +3,8 @@ from simtk import unit import time +import pytest + from testsystems.relative import hif2a_ligand_pair from md.builders import build_water_system @@ -24,6 +26,66 @@ from timemachine.constants import BOLTZ, ENERGY_UNIT, DISTANCE_UNIT +def test_barostat_zero_interval(): + pressure = 1. * unit.atmosphere + temperature = 300.0 * unit.kelvin + initial_waterbox_width = 2.0 * unit.nanometer + barostat_interval = 0 + seed = 2021 + np.random.seed(seed) + + mol_a = hif2a_ligand_pair.mol_a + ff = hif2a_ligand_pair.ff + complex_system, complex_coords, complex_box, complex_top = build_water_system( + initial_waterbox_width.value_in_unit(unit.nanometer)) + + min_complex_coords = minimize_host_4d([mol_a], complex_system, complex_coords, ff, complex_box) + afe = AbsoluteFreeEnergy(mol_a, ff) + + unbound_potentials, sys_params, masses, coords = afe.prepare_host_edge( + ff.get_ordered_params(), complex_system, min_complex_coords + ) + + # get list of molecules for barostat by looking at bond table + harmonic_bond_potential = unbound_potentials[0] + bond_list = get_bond_list(harmonic_bond_potential) + group_indices = get_group_indices(bond_list) + + lam = 1.0 + + bound_potentials = [] + for params, unbound_pot in zip(sys_params, unbound_potentials): + bp = unbound_pot.bind(np.asarray(params)) + bound_potentials.append(bp) + + u_impls = [] + for bp in bound_potentials: + bp_impl = bp.bound_impl(precision=np.float32) + u_impls.append(bp_impl) + + with pytest.raises(RuntimeError): + custom_ops.MonteCarloBarostat( + coords.shape[0], + pressure.value_in_unit(unit.bar), + temperature.value_in_unit(unit.kelvin), + group_indices, + 0, + u_impls, + seed + ) + # Setting it to 1 should be valid. + baro = custom_ops.MonteCarloBarostat( + coords.shape[0], + pressure.value_in_unit(unit.bar), + temperature.value_in_unit(unit.kelvin), + group_indices, + 1, + u_impls, + seed + ) + # Setting back to 0 should raise another error + with pytest.raises(RuntimeError): + baro.set_interval(0) def test_barostat_partial_group_idxs(): """Verify that the barostat can handle a subset of the molecules diff --git a/timemachine/cpp/src/barostat.cu b/timemachine/cpp/src/barostat.cu index 9e800492f..c8f4e1c0a 100644 --- a/timemachine/cpp/src/barostat.cu +++ b/timemachine/cpp/src/barostat.cu @@ -33,6 +33,9 @@ MonteCarloBarostat::MonteCarloBarostat( seed_(seed), step_(0) { + // Trigger check that interval is valid + this->set_interval(interval_); + // lets not have another facepalm moment again... if(temperature < 100.0) { std::cout << "warning temperature less than 100K" << std::endl; @@ -314,6 +317,9 @@ void MonteCarloBarostat::inplace_move( }; void MonteCarloBarostat::set_interval(const int interval){ + if (interval <= 0) { + throw std::runtime_error("Barostat interval must be greater than 0"); + } interval_ = interval; // Clear the step, to ensure user can expect that in N steps the barostat will trigger step_ = 0;