diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 8d22adafc..77a96aa01 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -533,6 +533,22 @@ def forward( return result +class BasicModelBoolInput_with_Future(nn.Module): + def __init__(self) -> None: + super().__init__() + self.mod = BasicModel_MultiLayer_with_Future() + + # pyre-fixme[3]: Return type must be annotated. + def forward( + self, + x: Tensor, + add_input: Optional[Tensor] = None, + mult: float = 10.0, + ): + assert x.dtype is torch.bool, "Input must be boolean" + return self.mod(x.float() * mult, add_input) + + class BasicModelBoolInput(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index b0292a7da..444c12dfc 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -16,6 +16,7 @@ BasicModel_MultiLayer_MultiInput, BasicModel_MultiLayer_with_Future, BasicModelBoolInput, + BasicModelBoolInput_with_Future, ) from parameterized import parameterized from torch.futures import Future @@ -66,28 +67,51 @@ def test_simple_shapley_sampling_with_mask(self, use_future) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_simple_shapley_sampling_boolean(self) -> None: - net = BasicModelBoolInput() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_boolean(self, use_future) -> None: inp = torch.tensor([[True, False, True]]) - self._shapley_test_assert( - net, - inp, - [[35.0, 35.0, 35.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModelBoolInput_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[35.0, 35.0, 35.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModelBoolInput() + self._shapley_test_assert( + net, + inp, + [[35.0, 35.0, 35.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) - def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: - net = BasicModelBoolInput() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None: inp = torch.tensor([[True, False, True]]) - self._shapley_test_assert( - net, - inp, - [[-40.0, -40.0, 0.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=True, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModelBoolInput_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[-40.0, -40.0, 0.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=True, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModelBoolInput() + self._shapley_test_assert( + net, + inp, + [[-40.0, -40.0, 0.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=True, + perturbations_per_eval=(1, 2, 3), + ) @parameterized.expand([True, False]) def test_simple_shapley_sampling_with_baselines(self, use_future) -> None: