Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Feb 16, 2024
1 parent d3cf557 commit 8f0b8b2
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def test_torchscript(model_name, device):
grad_outputs=grad_outputs,
)[0]

def test_torchscript_output_modification():
model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.model = model
def forward(self, z, pos, batch):
y, neg_dy = self.model(z, pos, batch=batch)
# A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
return y, 2*neg_dy
torch.jit.script(MyModel())

@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
Expand Down

0 comments on commit 8f0b8b2

Please sign in to comment.