From 94f94d90208b43ca083ce7978ca6e7ddadd9085b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 18 Mar 2024 14:07:14 +0100 Subject: [PATCH] Add test for ensemble --- tests/test_model.py | 99 ++++++++++++++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index b792595b8..31a09a213 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,7 @@ import torch import lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model +from torchmdnet.models.model import create_model, load_model from torchmdnet.models import output_modules from torchmdnet.models.utils import dtype_mapping @@ -23,7 +23,9 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): z, pos, batch = create_example_batch() pos = pos.to(dtype=dtype_mapping[precision]) - model = create_model(load_example_args(model_name, prior_model=None, precision=precision)) + model = create_model( + load_example_args(model_name, prior_model=None, precision=precision) + ) batch = batch if use_batch else None if explicit_q_s: model(z, pos, batch=batch, q=None, s=None) @@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("output_model", output_modules.__all__) -@mark.parametrize("precision", [32,64]) +@mark.parametrize("precision", [32, 64]) def test_forward_output_modules(model_name, output_model, precision): z, pos, batch = create_example_batch() - args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision) + args = load_example_args( + model_name, remove_prior=True, output_model=output_model, precision=precision + ) model = create_model(args) model(z, pos, batch=batch) @@ -61,18 +65,25 @@ def test_torchscript(model_name, device): grad_outputs=grad_outputs, )[0] + def test_torchscript_output_modification(): - model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True)) + model = create_model( + load_example_args("tensornet", remove_prior=True, derivative=True) + ) + class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.model = model + def forward(self, z, pos, batch): y, neg_dy = self.model(z, pos, batch=batch) # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor] - return y, 2*neg_dy + return y, 2 * neg_dy + torch.jit.script(MyModel()) + @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("device", ["cpu", "cuda"]) def test_torchscript_dynamic_shapes(model_name, device): @@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device): model = torch.jit.script( create_model(load_example_args(model_name, remove_prior=True, derivative=True)) ).to(device=device) - #Repeat the input to make it dynamic + # Repeat the input to make it dynamic for rep in range(0, 5): print(rep) - zi = z.repeat_interleave(rep+1, dim=0).to(device=device) - posi = pos.repeat_interleave(rep+1, dim=0).to(device=device) + zi = z.repeat_interleave(rep + 1, dim=0).to(device=device) + posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device) batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device) y, neg_dy = model(zi, posi, batch=batchi) grad_outputs = [torch.ones_like(neg_dy)] @@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device): grad_outputs=grad_outputs, )[0] -#Currently only tensornet is CUDA graph compatible + +# Currently only tensornet is CUDA graph compatible @mark.parametrize("model_name", ["tensornet"]) def test_cuda_graph_compatible(model_name): if not torch.cuda.is_available(): pytest.skip("CUDA not available") z, pos, batch = create_example_batch() - args = {"model": model_name, - "embedding_dimension": 128, - "num_layers": 2, - "num_rbf": 32, - "rbf_type": "expnorm", - "trainable_rbf": False, - "activation": "silu", - "cutoff_lower": 0.0, - "cutoff_upper": 5.0, - "max_z": 100, - "max_num_neighbors": 128, - "equivariance_invariance_group": "O(3)", - "prior_model": None, - "atom_filter": -1, - "derivative": True, - "check_error": False, - "static_shapes": True, - "output_model": "Scalar", - "reduce_op": "sum", - "precision": 32 } + args = { + "model": model_name, + "embedding_dimension": 128, + "num_layers": 2, + "num_rbf": 32, + "rbf_type": "expnorm", + "trainable_rbf": False, + "activation": "silu", + "cutoff_lower": 0.0, + "cutoff_upper": 5.0, + "max_z": 100, + "max_num_neighbors": 128, + "equivariance_invariance_group": "O(3)", + "prior_model": None, + "atom_filter": -1, + "derivative": True, + "check_error": False, + "static_shapes": True, + "output_model": "Scalar", + "reduce_op": "sum", + "precision": 32, + } model = create_model(args).to(device="cuda") model.eval() z = z.to("cuda") @@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name): assert torch.allclose(y, y2) assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5) + @mark.parametrize("model_name", models.__all_models__) def test_seed(model_name): args = load_example_args(model_name, remove_prior=True) @@ -153,6 +168,7 @@ def test_seed(model_name): for p1, p2 in zip(m1.parameters(), m2.parameters()): assert (p1 == p2).all(), "Parameters don't match although using the same seed." + @mark.parametrize("model_name", models.__all_models__) @mark.parametrize( "output_model", @@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): ), f"Set new reference outputs for {model_name} with output model {output_model}." # compare actual ouput with reference - torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5 + ) if derivative: torch.testing.assert_close( deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5 @@ -218,7 +236,7 @@ def test_gradients(model_name): remove_prior=True, output_model=output_model, derivative=derivative, - precision=precision + precision=precision, ) model = create_model(args) z, pos, batch = create_example_batch(n_atoms=5) @@ -227,3 +245,20 @@ def test_gradients(model_name): torch.autograd.gradcheck( model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3 ) + + +def test_ensemble(): + ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3 + model = load_model(ckpts[0]) + ensemble_model = load_model(ckpts) + z, pos, batch = create_example_batch(n_atoms=5) + + pred, deriv = model(z, pos, batch) + pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) + + torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5) + assert y_std.shape == pred.shape + assert neg_dy_std.shape == deriv.shape + assert (y_std == 0).all() + assert (neg_dy_std == 0).all()