From 38ff9a87013448e32fa48da4a93ffeae15c509b8 Mon Sep 17 00:00:00 2001 From: Jason Cho Date: Tue, 21 Jan 2025 10:42:59 -0800 Subject: [PATCH] Add test coverage for futures in ShapleyValueSampling 2/n (#1490) Summary: This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that handle shapley sampling with boolean inputs Reviewed By: cyrjano Differential Revision: D68230069 --- captum/testing/helpers/basic_models.py | 16 ++++++ tests/attr/test_shapley.py | 67 ++++++++++++++++++-------- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 8d22adafc..77a96aa01 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -533,6 +533,22 @@ def forward( return result +class BasicModelBoolInput_with_Future(nn.Module): + def __init__(self) -> None: + super().__init__() + self.mod = BasicModel_MultiLayer_with_Future() + + # pyre-fixme[3]: Return type must be annotated. + def forward( + self, + x: Tensor, + add_input: Optional[Tensor] = None, + mult: float = 10.0, + ): + assert x.dtype is torch.bool, "Input must be boolean" + return self.mod(x.float() * mult, add_input) + + class BasicModelBoolInput(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index a9be6f6c4..137074ac4 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -16,6 +16,7 @@ BasicModel_MultiLayer_MultiInput, BasicModel_MultiLayer_with_Future, BasicModelBoolInput, + BasicModelBoolInput_with_Future, ) from parameterized import parameterized from torch.futures import Future @@ -66,28 +67,56 @@ def test_simple_shapley_sampling_with_mask(self, use_future) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_simple_shapley_sampling_boolean(self) -> None: - net = BasicModelBoolInput() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_boolean(self, use_future) -> None: + if use_future: + net = BasicModelBoolInput_with_Future() + else: + net = BasicModelBoolInput() inp = torch.tensor([[True, False, True]]) - self._shapley_test_assert( - net, - inp, - [[35.0, 35.0, 35.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + self._shapley_test_assert_future( + net, + inp, + [[35.0, 35.0, 35.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + self._shapley_test_assert( + net, + inp, + [[35.0, 35.0, 35.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) - def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: - net = BasicModelBoolInput() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None: + if use_future: + net = BasicModelBoolInput_with_Future() + else: + net = BasicModelBoolInput() inp = torch.tensor([[True, False, True]]) - self._shapley_test_assert( - net, - inp, - [[-40.0, -40.0, 0.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=True, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + self._shapley_test_assert_future( + net, + inp, + [[-40.0, -40.0, 0.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=True, + perturbations_per_eval=(1, 2, 3), + ) + else: + + self._shapley_test_assert( + net, + inp, + [[-40.0, -40.0, 0.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=True, + perturbations_per_eval=(1, 2, 3), + ) @parameterized.expand([True, False]) def test_simple_shapley_sampling_with_baselines(self, use_future) -> None: