diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index b88b478b6..a9be6f6c4 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -17,41 +17,54 @@ BasicModel_MultiLayer_with_Future, BasicModelBoolInput, ) +from parameterized import parameterized +from torch.futures import Future class Test(BaseTest): - def test_simple_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) - - def test_simple_shapley_sampling_future(self) -> None: - net = BasicModel_MultiLayer_with_Future() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert_future( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) + if use_future: + net = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) - def test_simple_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[275.0, 275.0, 115.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) def test_simple_shapley_sampling_boolean(self) -> None: net = BasicModelBoolInput() @@ -76,40 +89,74 @@ def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_simple_shapley_sampling_with_baselines(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_baselines(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[248.0, 248.0, 104.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=4, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) - def test_multi_sample_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=200, - ) + if use_future: + net = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) - def test_multi_sample_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) - self._shapley_test_assert( - net, - inp, - [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], - feature_mask=mask, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) def test_multi_input_shapley_sampling_without_mask(self) -> None: net = BasicModel_MultiLayer_MultiInput() @@ -165,86 +212,166 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_shapley_sampling_multi_task_output(self) -> None: + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output(self, use_future) -> None: # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 2) - output = torch.cat( - [ - net_output, - constant, - ], - dim=-1, - ) - return output - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + if use_future: + net1 = BasicModel_MultiLayer_with_Future() + + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [76.66666, 196.66666, 116.66666], - [76.66666, 196.66666, 116.66666], - [0, 0, 0], - [0, 0, 0], - ] - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - ) - - def test_shapley_sampling_multi_task_output_with_mask(self) -> None: - # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 1) + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] + ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + ) + else: + net1 = BasicModel_MultiLayer() + + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output - output = torch.cat( + # return shape (batch size, 4) + self._shapley_test_assert( + forward_func, + inp, [ - net_output, - constant, + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] ], - dim=-1, + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, ) - return output + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output_with_mask(self, use_future) -> None: + # return shape (batch size, 2) inp = torch.tensor([[20.0, 50.0, 30.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[1, 1, 0], [0, 1, 1]]) + if use_future: + net1 = BasicModel_MultiLayer_with_Future() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [275.0, 275.0, 115.0], - [275.0, 275.0, 115.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) + else: + + net1 = BasicModel_MultiLayer() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output + + self._shapley_test_assert( + forward_func, + inp, [ - [75.0, 315.0, 315.0], - [75.0, 315.0, 315.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - feature_mask=mask, - ) + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) # Remaining tests are for cases where forward function returns a scalar # per batch, as either a float, integer, 0d tensor or 1d tensor. @@ -551,7 +678,7 @@ def _shapley_test_assert_future( ) if test_true_shapley: shapley_val = ShapleyValues(model) - attributions = shapley_val.attribute( + attributions = shapley_val.attribute_future( test_input, target=target, feature_mask=feature_mask, @@ -560,8 +687,9 @@ def _shapley_test_assert_future( perturbations_per_eval=batch_size, show_progress=show_progress, ) + attributions.wait() assertTensorTuplesAlmostEqual( - self, attributions, expected_attr, mode="max", delta=0.001 + self, attributions.value(), expected_attr, mode="max", delta=0.001 )