From 8f0b8b28beedd1e22de4f69a33ba9f64ef7aaf01 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 16 Feb 2024 10:21:57 +0100 Subject: [PATCH] Add test --- tests/test_model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 044e6180b..b792595b8 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -61,6 +61,18 @@ def test_torchscript(model_name, device): grad_outputs=grad_outputs, )[0] +def test_torchscript_output_modification(): + model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True)) + class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.model = model + def forward(self, z, pos, batch): + y, neg_dy = self.model(z, pos, batch=batch) + # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor] + return y, 2*neg_dy + torch.jit.script(MyModel()) + @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("device", ["cpu", "cuda"]) def test_torchscript_dynamic_shapes(model_name, device):