Add test for ensemble
RaulPPelaez committed Mar 18, 2024
1 parent d071cd3 commit 94f94d9
Showing 1 changed file with 67 additions and 32 deletions.
tests/test_model.py (99 changes: 67 additions & 32 deletions)
@@ -9,7 +9,7 @@
import torch
import lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping

@@ -23,7 +23,9 @@
def test_forward(model_name, use_batch, explicit_q_s, precision):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
model = create_model(
load_example_args(model_name, prior_model=None, precision=precision)
)
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
@@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision):

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("output_model", output_modules.__all__)
@mark.parametrize("precision", [32,64])
@mark.parametrize("precision", [32, 64])
def test_forward_output_modules(model_name, output_model, precision):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
args = load_example_args(
model_name, remove_prior=True, output_model=output_model, precision=precision
)
model = create_model(args)
model(z, pos, batch=batch)

@@ -61,18 +65,25 @@ def test_torchscript(model_name, device):
grad_outputs=grad_outputs,
)[0]


def test_torchscript_output_modification():
model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
model = create_model(
load_example_args("tensornet", remove_prior=True, derivative=True)
)

class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.model = model

def forward(self, z, pos, batch):
y, neg_dy = self.model(z, pos, batch=batch)
# A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
return y, 2*neg_dy
return y, 2 * neg_dy

torch.jit.script(MyModel())


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
@@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
#Repeat the input to make it dynamic
# Repeat the input to make it dynamic
for rep in range(0, 5):
print(rep)
zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
y, neg_dy = model(zi, posi, batch=batchi)
grad_outputs = [torch.ones_like(neg_dy)]
@@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device):
grad_outputs=grad_outputs,
)[0]

#Currently only tensornet is CUDA graph compatible

# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
args = {"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_error": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
args = {
"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_error": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
model = create_model(args).to(device="cuda")
model.eval()
z = z.to("cuda")
@@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name):
assert torch.allclose(y, y2)
assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)


@mark.parametrize("model_name", models.__all_models__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
@@ -153,6 +168,7 @@ def test_seed(model_name):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
assert (p1 == p2).all(), "Parameters don't match although using the same seed."


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize(
"output_model",
@@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
), f"Set new reference outputs for {model_name} with output model {output_model}."

# compare actual output with reference
torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
torch.testing.assert_close(
pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
)
if derivative:
torch.testing.assert_close(
deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
Expand All @@ -218,7 +236,7 @@ def test_gradients(model_name):
remove_prior=True,
output_model=output_model,
derivative=derivative,
precision=precision
precision=precision,
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
@@ -227,3 +245,20 @@ def test_gradients(model_name):
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


def test_ensemble():
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts)
z, pos, batch = create_example_batch(n_atoms=5)

pred, deriv = model(z, pos, batch)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()
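
The new test_ensemble exercises the ensemble path of load_model: passing a list of checkpoint paths returns a model whose forward pass averages the members' predictions and additionally returns their standard deviations. A minimal usage sketch of that behavior, with illustrative checkpoint paths (any trained checkpoints would do):

    import torch
    from torchmdnet.models.model import load_model

    # Three copies of one checkpoint form a trivial ensemble, so the
    # standard deviations returned below are expected to be zero.
    ckpts = ["example.ckpt"] * 3  # illustrative paths
    ensemble = load_model(ckpts)

    z = torch.tensor([8, 1, 1], dtype=torch.long)  # atomic numbers
    pos = torch.rand(3, 3)                         # random positions
    batch = torch.zeros(3, dtype=torch.long)       # all atoms in one molecule

    # Mean prediction and mean negative gradient over the ensemble,
    # plus the per-element standard deviations across members.
    y_mean, neg_dy_mean, y_std, neg_dy_std = ensemble(z, pos, batch)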
