From 223afa3da3cdceee4c0eef631c0c1168f80ff71c Mon Sep 17 00:00:00 2001 From: Jason Cho Date: Tue, 21 Jan 2025 13:35:55 -0800 Subject: [PATCH] Add test coverage for futures in ShapleyValueSampling 1/n (#1491) Summary: This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that handle simple shapley sampling and multi shapley sampling using the BasicModel_MultiLayer model note: decided to parameterize the high-level unit tests since I did not want to clutter the unit test file with tests which were doing the same calculations but with the pytorch futures api. Separated the model and shapley assert methods between non-future vs future so that the differences between the two would be clear Reviewed By: cyrjano Differential Revision: D68229981 --- tests/attr/test_shapley.py | 374 +++++++++++++++++++++++++------------ 1 file changed, 251 insertions(+), 123 deletions(-) diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index b88b478b6..b0292a7da 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -17,41 +17,54 @@ BasicModel_MultiLayer_with_Future, BasicModelBoolInput, ) +from parameterized import parameterized +from torch.futures import Future class Test(BaseTest): - def test_simple_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) - - def test_simple_shapley_sampling_future(self) -> None: - net = BasicModel_MultiLayer_with_Future() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert_future( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) - def test_simple_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[275.0, 275.0, 115.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) def test_simple_shapley_sampling_boolean(self) -> None: net = BasicModelBoolInput() @@ -76,40 +89,74 @@ def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_simple_shapley_sampling_with_baselines(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_baselines(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[248.0, 248.0, 104.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=4, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) - def test_multi_sample_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=200, - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) - def test_multi_sample_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) - self._shapley_test_assert( - net, - inp, - [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], - feature_mask=mask, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) def test_multi_input_shapley_sampling_without_mask(self) -> None: net = BasicModel_MultiLayer_MultiInput() @@ -165,86 +212,166 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_shapley_sampling_multi_task_output(self) -> None: + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output(self, use_future) -> None: # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 2) - output = torch.cat( - [ - net_output, - constant, - ], - dim=-1, - ) - return output - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + if use_future: + net1_fut = BasicModel_MultiLayer_with_Future() + + def forward_func(*args, **kwargs): + net_output = net1_fut(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [76.66666, 196.66666, 116.66666], - [76.66666, 196.66666, 116.66666], - [0, 0, 0], - [0, 0, 0], - ] - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - ) - - def test_shapley_sampling_multi_task_output_with_mask(self) -> None: - # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 1) + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] + ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + ) + else: + net1 = BasicModel_MultiLayer() + + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output - output = torch.cat( + # return shape (batch size, 4) + self._shapley_test_assert( + forward_func, + inp, [ - net_output, - constant, + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] ], - dim=-1, + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, ) - return output + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output_with_mask(self, use_future) -> None: + # return shape (batch size, 2) inp = torch.tensor([[20.0, 50.0, 30.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[1, 1, 0], [0, 1, 1]]) + if use_future: + net1_fut = BasicModel_MultiLayer_with_Future() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1_fut(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [275.0, 275.0, 115.0], - [275.0, 275.0, 115.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) + else: + + net1 = BasicModel_MultiLayer() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output + + self._shapley_test_assert( + forward_func, + inp, [ - [75.0, 315.0, 315.0], - [75.0, 315.0, 315.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - feature_mask=mask, - ) + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) # Remaining tests are for cases where forward function returns a scalar # per batch, as either a float, integer, 0d tensor or 1d tensor. @@ -551,7 +678,7 @@ def _shapley_test_assert_future( ) if test_true_shapley: shapley_val = ShapleyValues(model) - attributions = shapley_val.attribute( + attributions = shapley_val.attribute_future( test_input, target=target, feature_mask=feature_mask, @@ -560,8 +687,9 @@ def _shapley_test_assert_future( perturbations_per_eval=batch_size, show_progress=show_progress, ) + attributions.wait() assertTensorTuplesAlmostEqual( - self, attributions, expected_attr, mode="max", delta=0.001 + self, attributions.value(), expected_attr, mode="max", delta=0.001 )