From e97aefcae9f68198e19652580da108795ee330dc Mon Sep 17 00:00:00 2001 From: Matt Wittmann Date: Tue, 22 Feb 2022 12:46:28 -0800 Subject: [PATCH] Wrap NonbondedAllPairs and NonbondedPairList, add tests --- tests/nonbonded/conftest.py | 56 ++++++++ tests/nonbonded/test_nonbonded_all_pairs.py | 126 ++++++++++++++++++ .../test_nonbonded_interaction_group.py | 53 -------- tests/nonbonded/test_nonbonded_pair_list.py | 102 ++++++++++++++ timemachine/cpp/src/nonbonded_pair_list.cu | 9 +- timemachine/cpp/src/wrap_kernels.cpp | 124 +++++++++++++++++ timemachine/lib/potentials.py | 34 +++++ 7 files changed, 448 insertions(+), 56 deletions(-) create mode 100644 tests/nonbonded/conftest.py create mode 100644 tests/nonbonded/test_nonbonded_all_pairs.py rename tests/{ => nonbonded}/test_nonbonded_interaction_group.py (89%) create mode 100644 tests/nonbonded/test_nonbonded_pair_list.py diff --git a/tests/nonbonded/conftest.py b/tests/nonbonded/conftest.py new file mode 100644 index 0000000000..9e56720c68 --- /dev/null +++ b/tests/nonbonded/conftest.py @@ -0,0 +1,56 @@ +import numpy as np +import pytest +from simtk.openmm import app + +from timemachine.fe.utils import to_md_units +from timemachine.ff.handlers import openmm_deserializer +from timemachine.lib import potentials + + +@pytest.fixture(autouse=True) +def set_random_seed(): + np.random.seed(2022) + yield + + +@pytest.fixture() +def rng(): + return np.random.default_rng(2022) + + +@pytest.fixture +def example_system(): + pdb_path = "tests/data/5dfr_solv_equil.pdb" + host_pdb = app.PDBFile(pdb_path) + ff = app.ForceField("amber99sbildn.xml", "tip3p.xml") + return ( + ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False), + host_pdb.positions, + host_pdb.topology.getPeriodicBoxVectors(), + ) + + +@pytest.fixture +def example_nonbonded_params(example_system): + host_system, _, _ = example_system + host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0) + + nonbonded_fn = None + for f in host_fns: + if isinstance(f, potentials.Nonbonded): + nonbonded_fn = f + + assert nonbonded_fn is not None + return nonbonded_fn.params + + +@pytest.fixture +def example_conf(example_system): + _, host_conf, _ = example_system + return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf]) + + +@pytest.fixture +def example_box(example_system): + _, _, box = example_system + return np.asarray(box / box.unit) diff --git a/tests/nonbonded/test_nonbonded_all_pairs.py b/tests/nonbonded/test_nonbonded_all_pairs.py new file mode 100644 index 0000000000..5d81e767e6 --- /dev/null +++ b/tests/nonbonded/test_nonbonded_all_pairs.py @@ -0,0 +1,126 @@ +import functools + +import numpy as np +import pytest +from common import GradientTest + +from timemachine.lib.potentials import NonbondedAllPairs, NonbondedAllPairsInterpolated +from timemachine.potentials import nonbonded + + +def test_nonbonded_all_pairs_invalid_planes_offsets(): + with pytest.raises(RuntimeError) as e: + NonbondedAllPairs([0], [0, 0], 2.0, 1.1).unbound_impl(np.float32) + + assert "lambda offset idxs and plane idxs need to be equivalent" in str(e) + + +def test_nonbonded_all_pairs_invalid_num_atoms(): + potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32) + with pytest.raises(RuntimeError) as e: + potential.execute(np.zeros((2, 3)), np.zeros((1, 3)), np.eye(3), 0) + + assert "NonbondedAllPairs::execute_device(): expected N == N_, got N=2, N_=1" in str(e) + + +def test_nonbonded_all_pairs_invalid_num_params(): + potential = NonbondedAllPairs([0], [0], 2.0, 1.1).unbound_impl(np.float32) + with pytest.raises(RuntimeError) as e: + potential.execute(np.zeros((1, 3)), np.zeros((2, 3)), np.eye(3), 0) + + assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=6, M*N_*3=3" in str(e) + + potential_interp = NonbondedAllPairsInterpolated([0], [0], 2.0, 1.1).unbound_impl(np.float32) + with pytest.raises(RuntimeError) as e: + potential_interp.execute(np.zeros((1, 3)), np.zeros((1, 3)), np.eye(3), 0) + + assert "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=3, M*N_*3=6" in str(e) + + +def make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff): + @functools.wraps(nonbonded.nonbonded_v3) + def wrapped(conf, params, box, lamb): + num_atoms, _ = conf + no_rescale = np.ones((num_atoms, num_atoms)) + return nonbonded.nonbonded_v3( + conf, + params, + box, + lamb, + charge_rescale_mask=no_rescale, + lj_rescale_mask=no_rescale, + beta=beta, + cutoff=cutoff, + lambda_plane_idxs=lambda_plane_idxs, + lambda_offset_idxs=lambda_offset_idxs, + ) + + return wrapped + + +@pytest.mark.parametrize("lamb", [0.0, 0.1]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms", [33, 65, 231, 1050, 4080]) +def test_nonbonded_all_pairs_correctness( + num_atoms, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng: np.random.Generator, +): + conf = example_conf[:num_atoms] + params = example_nonbonded_params[:num_atoms, :] + + lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + + ref_potential = make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff) + test_potential = NonbondedAllPairs(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff) + + GradientTest().compare_forces( + conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol + ) + + +@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms", [33, 231, 4080]) +def test_nonbonded_all_pairs_interpolated_correctness( + num_atoms, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng: np.random.Generator, +): + "Compares with jax reference implementation." + + conf = example_conf[:num_atoms] + params_initial = example_nonbonded_params[:num_atoms, :] + params_final = params_initial + rng.normal(0, 0.01, size=params_initial.shape) + params = np.concatenate((params_initial, params_final)) + + lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + + ref_potential = nonbonded.interpolated(make_ref_potential(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)) + test_potential = NonbondedAllPairsInterpolated(lambda_plane_idxs, lambda_offset_idxs, beta, cutoff) + + GradientTest().compare_forces( + conf, params, example_box, lamb, ref_potential, test_potential, precision=precision, rtol=rtol, atol=atol + ) diff --git a/tests/test_nonbonded_interaction_group.py b/tests/nonbonded/test_nonbonded_interaction_group.py similarity index 89% rename from tests/test_nonbonded_interaction_group.py rename to tests/nonbonded/test_nonbonded_interaction_group.py index a786c2c208..2172980c17 100644 --- a/tests/test_nonbonded_interaction_group.py +++ b/tests/nonbonded/test_nonbonded_interaction_group.py @@ -5,64 +5,11 @@ import numpy as np import pytest from common import GradientTest, prepare_reference_nonbonded -from simtk.openmm import app -from timemachine.fe.utils import to_md_units -from timemachine.ff.handlers import openmm_deserializer -from timemachine.lib import potentials from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated from timemachine.potentials import jax_utils, nonbonded -@pytest.fixture(autouse=True) -def set_random_seed(): - np.random.seed(2022) - yield - - -@pytest.fixture() -def rng(): - return np.random.default_rng(2022) - - -@pytest.fixture -def example_system(): - pdb_path = "tests/data/5dfr_solv_equil.pdb" - host_pdb = app.PDBFile(pdb_path) - ff = app.ForceField("amber99sbildn.xml", "tip3p.xml") - return ( - ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False), - host_pdb.positions, - host_pdb.topology.getPeriodicBoxVectors(), - ) - - -@pytest.fixture -def example_nonbonded_params(example_system): - host_system, _, _ = example_system - host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0) - - nonbonded_fn = None - for f in host_fns: - if isinstance(f, potentials.Nonbonded): - nonbonded_fn = f - - assert nonbonded_fn is not None - return nonbonded_fn.params - - -@pytest.fixture -def example_conf(example_system): - _, host_conf, _ = example_system - return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf]) - - -@pytest.fixture -def example_box(example_system): - _, _, box = example_system - return np.asarray(box / box.unit) - - def test_nonbonded_interaction_group_invalid_indices(): def make_potential(ligand_idxs, num_atoms): lambda_plane_idxs = [0] * num_atoms diff --git a/tests/nonbonded/test_nonbonded_pair_list.py b/tests/nonbonded/test_nonbonded_pair_list.py new file mode 100644 index 0000000000..ad2abe10f3 --- /dev/null +++ b/tests/nonbonded/test_nonbonded_pair_list.py @@ -0,0 +1,102 @@ +import jax + +jax.config.update("jax_enable_x64", True) + +import functools + +import numpy as np +import pytest +from common import GradientTest + +from timemachine.lib.potentials import NonbondedPairList +from timemachine.potentials import jax_utils, nonbonded + + +def test_nonbonded_pair_list_invalid_pair_idxs(): + with pytest.raises(RuntimeError) as e: + NonbondedPairList([0], [0], [0], [0], 2.0, 1.1).unbound_impl(np.float32) + + assert "pair_idxs.size() must be even, but got 1" in str(e) + + with pytest.raises(RuntimeError) as e: + NonbondedPairList([(0, 0)], [(1, 1)], [0], [0], 2.0, 1.1).unbound_impl(np.float32) + + assert "illegal pair with src == dst: 0, 0" in str(e) + + with pytest.raises(RuntimeError) as e: + NonbondedPairList([(0, 1)], [(1, 1), (2, 2)], [0], [0], 2.0, 1.1).unbound_impl(np.float32) + + assert "expected same number of pairs and scale tuples, but got 1 != 2" in str(e) + + +def make_ref_potential(pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff): + @functools.wraps(nonbonded.nonbonded_v3_on_specific_pairs) + def wrapped(conf, params, box, lamb): + + # compute 4d coordinates + w = jax_utils.compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) + conf_4d = jax_utils.augment_dim(conf, w) + box_4d = (1000 * jax.numpy.eye(4)).at[:3, :3].set(box) + + vdW, electrostatics = nonbonded.nonbonded_v3_on_specific_pairs( + conf_4d, params, box_4d, pair_idxs[:, 0], pair_idxs[:, 1], beta, cutoff + ) + return jax.numpy.sum(scales[:, 1] * vdW + scales[:, 0] * electrostatics) + + return wrapped + + +@pytest.mark.parametrize("lamb", [0.0, 0.1]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms", [4080]) +@pytest.mark.parametrize("num_atoms_interacting", [1, 30, 1000]) +def test_nonbonded_interaction_group_correctness( + num_atoms_interacting, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng: np.random.Generator, +): + "Compares with jax reference implementation." + + num_atoms, _ = example_conf.shape + + atom_idxs = rng.choice( + num_atoms, + size=( + 2, + num_atoms_interacting, + ), + replace=False, + ).astype(np.int32) + + pair_idxs = np.stack(np.meshgrid(atom_idxs[0, :], atom_idxs[1, :])).reshape(2, -1).T + num_pairs, _ = pair_idxs.shape + + scales = rng.uniform(0, 1, size=(num_pairs, 2)) + + lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + + ref_potential = make_ref_potential(pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff) + test_potential = NonbondedPairList(pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff) + + GradientTest().compare_forces( + example_conf, + example_nonbonded_params, + example_box, + lamb, + ref_potential, + test_potential, + precision=precision, + rtol=rtol, + atol=atol, + ) diff --git a/timemachine/cpp/src/nonbonded_pair_list.cu b/timemachine/cpp/src/nonbonded_pair_list.cu index f08cc7eca7..cfc75d1a48 100644 --- a/timemachine/cpp/src/nonbonded_pair_list.cu +++ b/timemachine/cpp/src/nonbonded_pair_list.cu @@ -24,19 +24,22 @@ NonbondedPairList::NonbondedPairList( kernel_cache_.program(kernel_src.c_str()).kernel("k_add_du_dp_interpolated").instantiate()) { if (pair_idxs.size() % 2 != 0) { - throw std::runtime_error("pair_idxs.size() must be exactly 2*M"); + throw std::runtime_error("pair_idxs.size() must be even, but got " + std::to_string(pair_idxs.size())); } for (int i = 0; i < M_; i++) { auto src = pair_idxs[i * 2 + 0]; auto dst = pair_idxs[i * 2 + 1]; if (src == dst) { - throw std::runtime_error("illegal pair with src == dst"); + throw std::runtime_error( + "illegal pair with src == dst: " + std::to_string(src) + ", " + std::to_string(dst)); } } if (scales.size() / 2 != M_) { - throw std::runtime_error("bad scales size!"); + throw std::runtime_error( + "expected same number of pairs and scale tuples, but got " + std::to_string(M_) + + " != " + std::to_string(scales.size() / 2)); } gpuErrchk(cudaMalloc(&d_pair_idxs_, M_ * 2 * sizeof(*d_pair_idxs_))); diff --git a/timemachine/cpp/src/wrap_kernels.cpp b/timemachine/cpp/src/wrap_kernels.cpp index d695bb14c2..92247826e1 100644 --- a/timemachine/cpp/src/wrap_kernels.cpp +++ b/timemachine/cpp/src/wrap_kernels.cpp @@ -706,6 +706,58 @@ template void declare_nonbonded(py::modul py::arg("transform_lambda_w") = "lambda"); } +template void declare_nonbonded_all_pairs(py::module &m, const char *typestr) { + + using Class = timemachine::NonbondedAllPairs; + std::string pyclass_name = std::string("NonbondedAllPairs_") + typestr; + py::class_, timemachine::Potential>( + m, pyclass_name.c_str(), py::buffer_protocol(), py::dynamic_attr()) + .def("set_nblist_padding", &timemachine::NonbondedAllPairs::set_nblist_padding) + .def("disable_hilbert_sort", &timemachine::NonbondedAllPairs::disable_hilbert_sort) + .def( + py::init([](const py::array_t &lambda_plane_idxs_i, + const py::array_t &lambda_offset_idxs_i, + const double beta, + const double cutoff, + const std::string &transform_lambda_charge = "lambda", + const std::string &transform_lambda_sigma = "lambda", + const std::string &transform_lambda_epsilon = "lambda", + const std::string &transform_lambda_w = "lambda") { + std::vector lambda_plane_idxs(lambda_plane_idxs_i.size()); + std::memcpy( + lambda_plane_idxs.data(), lambda_plane_idxs_i.data(), lambda_plane_idxs_i.size() * sizeof(int)); + + std::vector lambda_offset_idxs(lambda_offset_idxs_i.size()); + std::memcpy( + lambda_offset_idxs.data(), lambda_offset_idxs_i.data(), lambda_offset_idxs_i.size() * sizeof(int)); + + std::string dir_path = dirname(__FILE__); + std::string kernel_dir = dir_path + "/kernels"; + std::string src_path = kernel_dir + "/k_lambda_transformer_jit.cuh"; + std::ifstream t(src_path); + std::string source_str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); + source_str = std::regex_replace(source_str, std::regex("KERNEL_DIR"), kernel_dir); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_CHARGE"), transform_lambda_charge); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_SIGMA"), transform_lambda_sigma); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_EPSILON"), transform_lambda_epsilon); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_W"), transform_lambda_w); + + return new timemachine::NonbondedAllPairs( + lambda_plane_idxs, lambda_offset_idxs, beta, cutoff, source_str); + }), + py::arg("lambda_plane_idxs_i"), + py::arg("lambda_offset_idxs_i"), + py::arg("beta"), + py::arg("cutoff"), + py::arg("transform_lambda_charge") = "lambda", + py::arg("transform_lambda_sigma") = "lambda", + py::arg("transform_lambda_epsilon") = "lambda", + py::arg("transform_lambda_w") = "lambda"); +} + std::set unique_idxs(const std::vector &idxs) { std::set unique_idxs(idxs.begin(), idxs.end()); if (unique_idxs.size() < idxs.size()) { @@ -774,6 +826,66 @@ void declare_nonbonded_interaction_group(py::module &m, const char *typestr) { py::arg("transform_lambda_w") = "lambda"); } +template void declare_nonbonded_pair_list(py::module &m, const char *typestr) { + const bool Negated = false; + using Class = timemachine::NonbondedPairList; + std::string pyclass_name = std::string("NonbondedPairList_") + typestr; + py::class_, timemachine::Potential>( + m, pyclass_name.c_str(), py::buffer_protocol(), py::dynamic_attr()) + .def( + py::init([](const py::array_t &pair_idxs_i, + const py::array_t &scales_i, + const py::array_t &lambda_plane_idxs_i, + const py::array_t &lambda_offset_idxs_i, + const double beta, + const double cutoff, + const std::string &transform_lambda_charge = "lambda", + const std::string &transform_lambda_sigma = "lambda", + const std::string &transform_lambda_epsilon = "lambda", + const std::string &transform_lambda_w = "lambda") { + std::vector pair_idxs(pair_idxs_i.size()); + std::memcpy(pair_idxs.data(), pair_idxs_i.data(), pair_idxs_i.size() * sizeof(int)); + + std::vector scales(scales_i.size()); + std::memcpy(scales.data(), scales_i.data(), scales_i.size() * sizeof(double)); + + std::vector lambda_plane_idxs(lambda_plane_idxs_i.size()); + std::memcpy( + lambda_plane_idxs.data(), lambda_plane_idxs_i.data(), lambda_plane_idxs_i.size() * sizeof(int)); + + std::vector lambda_offset_idxs(lambda_offset_idxs_i.size()); + std::memcpy( + lambda_offset_idxs.data(), lambda_offset_idxs_i.data(), lambda_offset_idxs_i.size() * sizeof(int)); + + std::string dir_path = dirname(__FILE__); + std::string kernel_dir = dir_path + "/kernels"; + std::string src_path = kernel_dir + "/k_lambda_transformer_jit.cuh"; + std::ifstream t(src_path); + std::string source_str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); + source_str = std::regex_replace(source_str, std::regex("KERNEL_DIR"), kernel_dir); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_CHARGE"), transform_lambda_charge); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_SIGMA"), transform_lambda_sigma); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_EPSILON"), transform_lambda_epsilon); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_W"), transform_lambda_w); + + return new timemachine::NonbondedPairList( + pair_idxs, scales, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff, source_str); + }), + py::arg("pair_idxs_i"), + py::arg("scales_i"), + py::arg("lambda_plane_idxs_i"), + py::arg("lambda_offset_idxs_i"), + py::arg("beta"), + py::arg("cutoff"), + py::arg("transform_lambda_charge") = "lambda", + py::arg("transform_lambda_sigma") = "lambda", + py::arg("transform_lambda_epsilon") = "lambda", + py::arg("transform_lambda_w") = "lambda"); +} + void declare_barostat(py::module &m) { using Class = timemachine::MonteCarloBarostat; @@ -879,11 +991,23 @@ PYBIND11_MODULE(custom_ops, m) { declare_nonbonded(m, "f64"); declare_nonbonded(m, "f32"); + declare_nonbonded_all_pairs(m, "f64_interpolated"); + declare_nonbonded_all_pairs(m, "f32_interpolated"); + + declare_nonbonded_all_pairs(m, "f64"); + declare_nonbonded_all_pairs(m, "f32"); + declare_nonbonded_interaction_group(m, "f64_interpolated"); declare_nonbonded_interaction_group(m, "f32_interpolated"); declare_nonbonded_interaction_group(m, "f64"); declare_nonbonded_interaction_group(m, "f32"); + declare_nonbonded_pair_list(m, "f64_interpolated"); + declare_nonbonded_pair_list(m, "f32_interpolated"); + + declare_nonbonded_pair_list(m, "f64"); + declare_nonbonded_pair_list(m, "f32"); + declare_context(m); } diff --git a/timemachine/lib/potentials.py b/timemachine/lib/potentials.py index 26346a5d83..9c63a0c1d9 100644 --- a/timemachine/lib/potentials.py +++ b/timemachine/lib/potentials.py @@ -280,6 +280,23 @@ def unbound_impl(self, precision): return custom_ctor(*self.args) +class NonbondedAllPairs(CustomOpWrapper): + pass + + +class NonbondedAllPairsInterpolated(NonbondedAllPairs): + def unbound_impl(self, precision): + cls_name_base = "NonbondedAllPairs" + if precision == np.float64: + cls_name_base += "_f64_interpolated" + else: + cls_name_base += "_f32_interpolated" + + custom_ctor = getattr(custom_ops, cls_name_base) + + return custom_ctor(*self.args) + + class NonbondedInteractionGroup(CustomOpWrapper): pass @@ -295,3 +312,20 @@ def unbound_impl(self, precision): custom_ctor = getattr(custom_ops, cls_name_base) return custom_ctor(*self.args) + + +class NonbondedPairList(CustomOpWrapper): + pass + + +class NonbondedPairListInterpolated(NonbondedPairList): + def unbound_impl(self, precision): + cls_name_base = "NonbondedPairList" + if precision == np.float64: + cls_name_base += "_f64_interpolated" + else: + cls_name_base += "_f32_interpolated" + + custom_ctor = getattr(custom_ops, cls_name_base) + + return custom_ctor(*self.args)