From 5d93e40c34e9ac3f6d2fb63ddf75a46b7b626a15 Mon Sep 17 00:00:00 2001 From: Jason Cho Date: Wed, 22 Jan 2025 09:05:05 -0800 Subject: [PATCH 1/2] Adding async future functionality to ShapleyValues (#1487) Summary: This diff implements the attribute_future method for the ShapleyValueSampling class. Reviewed By: cyrjano Differential Revision: D68158802 --- captum/_utils/exceptions.py | 8 + captum/attr/_core/shapley_value.py | 352 ++++++++++++++++++++++++- captum/testing/helpers/basic_models.py | 56 ++++ tests/attr/test_shapley.py | 68 ++++- 4 files changed, 467 insertions(+), 17 deletions(-) diff --git a/captum/_utils/exceptions.py b/captum/_utils/exceptions.py index b952d37406..f548ba2075 100644 --- a/captum/_utils/exceptions.py +++ b/captum/_utils/exceptions.py @@ -9,3 +9,11 @@ class FeatureAblationFutureError(Exception): FeatureAblation attribution call""" pass + + +class ShapleyValueFutureError(Exception): + """This custom error is raised when an error + occurs within the callback chain of a + ShapleyValue attribution call""" + + pass diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 83f1811aed..ca7f6f7e98 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -5,7 +5,7 @@ import itertools import math import warnings -from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union +from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -20,6 +20,7 @@ _is_tuple, _run_forward, ) +from captum._utils.exceptions import ShapleyValueFutureError from captum._utils.progress import progress from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution @@ -29,7 +30,8 @@ _tensorize_baseline, ) from captum.log import log_usage -from torch import dtype, Tensor +from torch import dtype, Size, Tensor +from torch.futures import collect_all, Future def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]: @@ -394,7 +396,6 @@ def attribute( ) if show_progress: attr_progress.update() - if agg_output_mode: eval_diff = modified_eval - prev_results prev_results = modified_eval @@ -438,7 +439,6 @@ def attribute( # (*output_shape, *input_feature_shape) total_attrib[j] += cur_attr - if show_progress: attr_progress.close() @@ -452,15 +452,318 @@ def attribute( # `Tuple[Tensor, ...]`. return formatted_attr - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Optional[Tuple[object, ...]] = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + n_samples: int = 25, + perturbations_per_eval: int = 1, + show_progress: bool = False, + ) -> Future[TensorOrTupleOfTensorsGeneric]: r""" This method is not implemented for ShapleyValueSampling. """ - raise NotImplementedError( - "attribute_future is not implemented for ShapleyValueSampling" + is_inputs_tuple = _is_tuple(inputs) + inputs_tuple, baselines = _format_input_baseline(inputs, baselines) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple) + reshaped_feature_mask = _shape_feature_mask( + formatted_feature_mask, inputs_tuple ) + assert ( + isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 + ), "Ablations per evaluation must be at least 1." + + with torch.no_grad(): + baselines = _tensorize_baseline(inputs_tuple, baselines) + num_examples = inputs_tuple[0].shape[0] + + total_features = _get_max_feature_index(reshaped_feature_mask) + 1 + + if show_progress: + attr_progress = progress( + desc=f"{self.get_name()} attribution", + total=self._get_n_evaluations( + total_features, n_samples, perturbations_per_eval + ) + + 1, # add 1 for the initial eval + ) + attr_progress.update(0) + + initial_eval: Future[Tensor] = self._strict_run_forward_future( + self.forward_func, baselines, target, additional_forward_args + ) + + if show_progress: + attr_progress.update() + + prev_result_tuple: Future[ + Tuple[Tensor, Tensor, Size, List[Tensor], bool] + ] = initial_eval.then( + lambda inp=initial_eval: self._initialEvalToPrevResultsTuple( # type: ignore # noqa: E501 line too long + inp, + num_examples, + perturbations_per_eval, + reshaped_feature_mask, + inputs_tuple, + ) + ) + + iter_count = 0 + # Iterate for number of samples, generate a permutation of the features + # and evalute the incremental increase for each feature. + for feature_permutation in self.permutation_generator( + total_features, n_samples + ): + prev_result_tuple = prev_result_tuple.then( + lambda inp=prev_result_tuple: self._setPrevResultsToInitialEval(inp) # type: ignore # noqa: E501 line too long + ) + + iter_count += 1 + for ( + current_inputs, + current_add_args, + current_target, + current_masks, + ) in self._perturbation_generator( + inputs_tuple, + additional_forward_args, + target, + baselines, + reshaped_feature_mask, + feature_permutation, + perturbations_per_eval, + ): + if sum(torch.sum(mask).item() for mask in current_masks) == 0: + warnings.warn( + "Feature mask is missing some integers between 0 and " + "num_features, for optimal performance, make sure each" + " consecutive integer corresponds to a feature.", + stacklevel=1, + ) + # modified_eval dimensions: 1D tensor with length + # equal to #num_examples * #features in batch + modified_eval = self._strict_run_forward_future( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + if show_progress: + attr_progress.update() + + assert isinstance(modified_eval, torch.Future), ( + "when using futures method, modified_eval should have " + f"Future type rather than {type(modified_eval)}" + ) + eval_futs: Future[ + List[ + Future[ + Union[ + Tuple[Tensor, Tensor, Size, List[Tensor], bool], + Tensor, + ] + ] + ] + ] = collect_all([prev_result_tuple, modified_eval]) + + prev_result_tuple = eval_futs.then( + lambda evals=eval_futs, masks=current_masks: self._evalFutToPrevResultsTuple( # type: ignore # noqa: E501 line too long + evals, num_examples, inputs_tuple, masks + ) + ) + + if show_progress: + attr_progress.close() + + # Divide total attributions by number of random permutations and return + # formatted attributions. + formatted_attr: Future[Union[Tensor, tuple[Tensor, ...]]] = ( + prev_result_tuple.then( + lambda inp=prev_result_tuple: self._prevResultTupleToFormattedAttr( # type: ignore # noqa: E501 line too long + inp, iter_count, is_inputs_tuple + ) + ) + ) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. + return formatted_attr # type: ignore + + def _initialEvalToPrevResultsTuple( + self, + initial_eval: Future[Tensor], + num_examples: int, + perturbations_per_eval: int, + reshaped_feature_mask: TensorOrTupleOfTensorsGeneric, + inputs_tuple: Tuple[Tensor, ...], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """Since the initial eval is a Future, it is easier to bundle the prev_result, + agg_output_mode, output_shape, and total_attrib together + as Shapley Value Feature Attributions are being calculated""" + try: + initial_eval_processed = initial_eval.value() + prev_result = initial_eval_processed + if not isinstance(initial_eval_processed, Tensor): + raise AssertionError( + "initial_eval_to_processed_initial_eval_fut: " + "initial_eval should be a Tensor" + ) + agg_output_mode = _find_output_mode_and_verify( + initial_eval_processed, + num_examples, + perturbations_per_eval, + reshaped_feature_mask, + allow_multi_outputs=True, + ) + output_shape = initial_eval_processed.shape + total_attrib: List[Tensor] = [ + torch.zeros( + tuple(output_shape) + tuple(input.shape[1:]), + dtype=torch.float, + device=inputs_tuple[0].device, + ) + for input in inputs_tuple + ] + result = ( + initial_eval_processed, + prev_result, + output_shape, + total_attrib, + agg_output_mode, + ) + except ShapleyValueFutureError as e: + raise ShapleyValueFutureError( + "_initial_eval_to_prev_results_tuple func failed" + ) from e + return result + + def _setPrevResultsToInitialEval( + self, + processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """At the beginning of each feature permutation, the prev_results is + reset to the initial eval, and this method helps set that up""" + (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) = ( + processed_initial_eval.value() + ) + prev_results = initial_eval + return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) + + def _evalFutToPrevResultsTuple( + self, + eval_futs: Future[ + List[ + Union[ + Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]], + Future[Tensor], + ] + ] + ], + num_examples: int, + inputs_tuple: Tuple[Tensor, ...], + current_masks: Tuple[Tensor, ...], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """Helper method responsible for calculating + eval differences between the modified eval and prev_results + Tensor and storing them in total_attrib. Returns prev_results_tuple + with modified total_attrib and prev_results""" + prev_results_tuple = eval_futs.value()[0].value() + modified_eval = eval_futs.value()[1].value() + if not isinstance(modified_eval, Tensor) or not isinstance( + prev_results_tuple, tuple + ): + raise ShapleyValueFutureError( + "_eval_fut_to_prev_results_tuple func failed due to type mismatch" + ) + ( + initial_eval, + prev_results, + output_shape, + total_attrib, + agg_output_mode, + ) = prev_results_tuple + if agg_output_mode: + eval_diff = modified_eval - prev_results + prev_results = modified_eval + else: + # when perturb_per_eval > 1, every num_examples stands for + # one perturb. Since the perturbs are from a consecutive + # perumuation, each diff of a perturb is its eval minus + # the eval of the previous perturb + + all_eval = torch.cat((prev_results, modified_eval), dim=0) + eval_diff = all_eval[num_examples:] - all_eval[:-num_examples] + prev_results = all_eval[-num_examples:] + + for j in range(len(total_attrib)): + # format eval_diff to shape + # (n_perturb, *output_shape, 1,.. 1) + # where n_perturb may not be perturb_per_eval + # Append n_input_feature dim of 1 to make the tensor + # have the same dim as the mask tensor. + formatted_eval_diff = eval_diff.reshape( + (-1,) + tuple(output_shape) + (len(inputs_tuple[j].shape) - 1) * (1,) + ) + + # mask in shape (n_perturb, *mask_shape_broadcastable_to_input) + # reshape to + # ( + # n_perturb, + # *broadcastable_to_output_shape + # *broadcastable_to_input_feature_shape + # ) + cur_mask = current_masks[j] + cur_mask = cur_mask.reshape( + tuple(cur_mask.shape[:2]) + + (len(output_shape) - 1) * (1,) + + tuple(cur_mask.shape[2:]) + ) + + # aggregate n_perturb + cur_attr = (formatted_eval_diff * cur_mask.float()).sum(dim=0) + # (*output_shape, *input_feature_shape) + total_attrib[j] += cur_attr + + result = ( + initial_eval, + prev_results, + output_shape, + total_attrib, + agg_output_mode, + ) + return result + + def _prevResultTupleToFormattedAttr( + self, + prev_result_tuple: Future[ + Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool] + ], + iter_count: int, + is_inputs_tuple: bool, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + """Helper method to format total_attrib, which is a + list of tensors, into formatted attributions, which + are either a single tensor or a tuple of tensors""" + + ( + _, + _, + _, + total_attrib, + _, + ) = prev_result_tuple.value() + attrib = tuple( + tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib + ) + formatted_attr = _format_output(is_inputs_tuple, attrib) + return formatted_attr + def _perturbation_generator( self, inputs: Tuple[Tensor, ...], @@ -574,6 +877,39 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor: # ref: https://github.com/pytorch/pytorch/pull/21215 return torch.tensor([forward_output], dtype=cast(dtype, output_type)) + # pyre-fixme[2]: Parameter must be annotated. + def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]: + """ + A temp wrapper for global _run_forward util to force + forward outputtype assertion & conversion, but takes + into account the Future tensor type + """ + + def process_strict_run_forward(fut: Future[Tensor]) -> Tensor: + output = fut.value() + if isinstance(output, Tensor): + # format scalar to shape (1) so we can always + # assume non-empty output_shape + if not output.shape: + output = output.reshape(1) + return output + output_type = type(output) + assert output_type is int or output_type is float, ( + "the return of forward_func must be a Future of tensor, int, or float," + f" received: {output_type}" + ) + output = torch.tensor([output], dtype=cast(dtype, output_type)) + return output + + forward_output = _run_forward(*args, **kwargs) + assert isinstance(forward_output, torch.Future), ( + "The return type of forward_func must be a Future" + f" received: {type(forward_output)}" + ) + + return_output = forward_output.then(process_strict_run_forward) + return return_output + class ShapleyValues(ShapleyValueSampling): """ diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 8c4685f752..8d22adafc3 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.futures import Future """ @no_type_check annotation is applied to type-hinted models to avoid errors @@ -477,6 +478,61 @@ def forward( return lin2_out +class BasicModel_MultiLayer_with_Future(nn.Module): + # This model is used to test the case where the model returns a future + def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None: + super().__init__() + # Linear 0 is simply identity transform + self.multi_input_module = multi_input_module + self.linear0 = nn.Linear(3, 3) + self.linear0.weight = nn.Parameter(torch.eye(3)) + self.linear0.bias = nn.Parameter(torch.zeros(3)) + self.linear1 = nn.Linear(3, 4) + self.linear1.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + + self.linear1_alt = nn.Linear(3, 4) + self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + self.multi_relu = MultiRelu(inplace=inplace) + self.relu = nn.ReLU(inplace=inplace) + + self.linear2 = nn.Linear(4, 2) + self.linear2.weight = nn.Parameter(torch.ones(2, 4)) + self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0])) + + @no_type_check + # pyre-fixme[3]: Return type must be annotated. + def forward( + self, + x: Tensor, + add_input: Optional[Tensor] = None, + multidim_output: bool = False, + ): + input = x if add_input is None else x + add_input + lin0_out = self.linear0(input) + lin1_out = self.linear1(lin0_out) + if self.multi_input_module: + relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input)) + relu_out = relu_out1 + relu_out2 + # relu is not used when multi_input_module set to True, + # so this is to set an unsued layer intentionally for testing + # and it won't be part of return + self.relu(lin1_out) + else: + relu_out = self.relu(lin1_out) + # pyre-fixme [29]: `typing.Type[Future]` is not a function + result = Future() + lin2_out = self.linear2(relu_out) + if multidim_output: + stack_mid = torch.stack((lin2_out, 2 * lin2_out), dim=2) + result.set_result(torch.stack((stack_mid, 4 * stack_mid), dim=3)) + return result + else: + result.set_result(lin2_out) + return result + + class BasicModelBoolInput(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index 976adc55f2..b88b478b60 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -14,6 +14,7 @@ from captum.testing.helpers.basic_models import ( BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, + BasicModel_MultiLayer_with_Future, BasicModelBoolInput, ) @@ -30,6 +31,17 @@ def test_simple_shapley_sampling(self) -> None: n_samples=250, ) + def test_simple_shapley_sampling_future(self) -> None: + net = BasicModel_MultiLayer_with_Future() + inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + self._shapley_test_assert_future( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) + def test_simple_shapley_sampling_with_mask(self) -> None: net = BasicModel_MultiLayer() inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) @@ -388,15 +400,6 @@ def test_shapley_sampling_with_mask_and_show_progress(self, mock_stderr) -> None mock_stderr.seek(0) mock_stderr.truncate(0) - def test_futures_not_implemented(self) -> None: - net = BasicModel_MultiLayer() - - attributions = None - shapley_samp = ShapleyValueSampling(net) - with self.assertRaises(NotImplementedError): - attributions = shapley_samp.attribute_future() - self.assertEqual(attributions, None) - def _single_input_one_sample_batch_scalar_shapley_assert( self, func: Callable ) -> None: @@ -514,6 +517,53 @@ def _shapley_test_assert( self, attributions, expected_attr, mode="max", delta=0.001 ) + def _shapley_test_assert_future( + self, + model: Callable, + test_input: TensorOrTupleOfTensorsGeneric, + expected_attr, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + additional_input: Any = None, + perturbations_per_eval: Tuple[int, ...] = (1,), + baselines: BaselineType = None, + target: Union[None, int] = 0, + n_samples: int = 100, + delta: float = 1.0, + # leaving this false as it is not supported for future + test_true_shapley: bool = False, + show_progress: bool = False, + ) -> None: + for batch_size in perturbations_per_eval: + shapley_samp = ShapleyValueSampling(model) + attributions = shapley_samp.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + n_samples=n_samples, + show_progress=show_progress, + ) + attributions.wait() + assertTensorTuplesAlmostEqual( + self, attributions.value(), expected_attr, delta=delta, mode="max" + ) + if test_true_shapley: + shapley_val = ShapleyValues(model) + attributions = shapley_val.attribute( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + show_progress=show_progress, + ) + assertTensorTuplesAlmostEqual( + self, attributions, expected_attr, mode="max", delta=0.001 + ) + if __name__ == "__main__": unittest.main() From a123ac7d83ed7057f0dbd40e829a82c5f5f425ce Mon Sep 17 00:00:00 2001 From: Jason Cho Date: Wed, 22 Jan 2025 09:05:05 -0800 Subject: [PATCH 2/2] Add test coverage for futures in ShapleyValueSampling 1/n (#1491) Summary: This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that handle simple shapley sampling and multi shapley sampling using the BasicModel_MultiLayer model note: decided to parameterize the high-level unit tests since I did not want to clutter the unit test file with tests which were doing the same calculations but with the pytorch futures api. Separated the model and shapley assert methods between non-future vs future so that the differences between the two would be clear Reviewed By: cyrjano Differential Revision: D68229981 --- tests/attr/test_shapley.py | 374 +++++++++++++++++++++++++------------ 1 file changed, 251 insertions(+), 123 deletions(-) diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index b88b478b60..b0292a7da8 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -17,41 +17,54 @@ BasicModel_MultiLayer_with_Future, BasicModelBoolInput, ) +from parameterized import parameterized +from torch.futures import Future class Test(BaseTest): - def test_simple_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) - - def test_simple_shapley_sampling_future(self) -> None: - net = BasicModel_MultiLayer_with_Future() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert_future( - net, - inp, - [[76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=250, - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) - def test_simple_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) - self._shapley_test_assert( - net, - inp, - [[275.0, 275.0, 115.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[275.0, 275.0, 115.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + perturbations_per_eval=(1, 2, 3), + ) def test_simple_shapley_sampling_boolean(self) -> None: net = BasicModelBoolInput() @@ -76,40 +89,74 @@ def test_simple_shapley_sampling_boolean_with_baseline(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_simple_shapley_sampling_with_baselines(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_simple_shapley_sampling_with_baselines(self, use_future) -> None: inp = torch.tensor([[20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[248.0, 248.0, 104.0]], - feature_mask=torch.tensor([[0, 0, 1]]), - baselines=4, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[248.0, 248.0, 104.0]], + feature_mask=torch.tensor([[0, 0, 1]]), + baselines=4, + perturbations_per_eval=(1, 2, 3), + ) - def test_multi_sample_shapley_sampling(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) - self._shapley_test_assert( - net, - inp, - [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], - perturbations_per_eval=(1, 2, 3), - n_samples=200, - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[7.0, 32.5, 10.5], [76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=200, + ) - def test_multi_sample_shapley_sampling_with_mask(self) -> None: - net = BasicModel_MultiLayer() + @parameterized.expand([True, False]) + def test_multi_sample_shapley_sampling_with_mask(self, use_future) -> None: inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) - self._shapley_test_assert( - net, - inp, - [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], - feature_mask=mask, - perturbations_per_eval=(1, 2, 3), - ) + if use_future: + net_fut = BasicModel_MultiLayer_with_Future() + self._shapley_test_assert_future( + net_fut, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + else: + net = BasicModel_MultiLayer() + self._shapley_test_assert( + net, + inp, + [[39.5, 39.5, 10.5], [275.0, 275.0, 115.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) def test_multi_input_shapley_sampling_without_mask(self) -> None: net = BasicModel_MultiLayer_MultiInput() @@ -165,86 +212,166 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) - def test_shapley_sampling_multi_task_output(self) -> None: + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output(self, use_future) -> None: # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 2) - output = torch.cat( - [ - net_output, - constant, - ], - dim=-1, - ) - return output - inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + if use_future: + net1_fut = BasicModel_MultiLayer_with_Future() + + def forward_func(*args, **kwargs): + net_output = net1_fut(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [76.66666, 196.66666, 116.66666], - [76.66666, 196.66666, 116.66666], - [0, 0, 0], - [0, 0, 0], - ] - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - ) - - def test_shapley_sampling_multi_task_output_with_mask(self) -> None: - # return shape (batch size, 2) - net1 = BasicModel_MultiLayer() - - # return shape (batch size, 4) - def forward_func(*args, **kwargs): - net_output = net1(*args, **kwargs) - batch_size = net_output.size(0) - constant = torch.ones(batch_size, 1) + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] + ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + ) + else: + net1 = BasicModel_MultiLayer() + + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 2) + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output - output = torch.cat( + # return shape (batch size, 4) + self._shapley_test_assert( + forward_func, + inp, [ - net_output, - constant, + [ + [76.66666, 196.66666, 116.66666], + [76.66666, 196.66666, 116.66666], + [0, 0, 0], + [0, 0, 0], + ] ], - dim=-1, + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, ) - return output + @parameterized.expand([True, False]) + def test_shapley_sampling_multi_task_output_with_mask(self, use_future) -> None: + # return shape (batch size, 2) inp = torch.tensor([[20.0, 50.0, 30.0], [20.0, 50.0, 30.0]], requires_grad=True) mask = torch.tensor([[1, 1, 0], [0, 1, 1]]) + if use_future: + net1_fut = BasicModel_MultiLayer_with_Future() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1_fut(*args, **kwargs) + net_output.wait() + batch_size = net_output.value().size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output.value(), + constant, + ], + dim=-1, + ) + fut = Future() + fut.set_result(output) + return fut - self._shapley_test_assert( - forward_func, - inp, - [ + self._shapley_test_assert_future( + forward_func, + inp, [ - [275.0, 275.0, 115.0], - [275.0, 275.0, 115.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) + else: + + net1 = BasicModel_MultiLayer() + + # return shape (batch size, 4) + def forward_func(*args, **kwargs): + net_output = net1(*args, **kwargs) + batch_size = net_output.size(0) + constant = torch.ones(batch_size, 1) + + output = torch.cat( + [ + net_output, + constant, + ], + dim=-1, + ) + return output + + self._shapley_test_assert( + forward_func, + inp, [ - [75.0, 315.0, 315.0], - [75.0, 315.0, 315.0], - [0, 0, 0], + [ + [275.0, 275.0, 115.0], + [275.0, 275.0, 115.0], + [0, 0, 0], + ], + [ + [75.0, 315.0, 315.0], + [75.0, 315.0, 315.0], + [0, 0, 0], + ], ], - ], - target=None, # no target, multi-task output for all classes - perturbations_per_eval=(1, 2, 3), - n_samples=150, - test_true_shapley=True, - feature_mask=mask, - ) + target=None, # no target, multi-task output for all classes + perturbations_per_eval=(1, 2, 3), + n_samples=150, + test_true_shapley=True, + feature_mask=mask, + ) # Remaining tests are for cases where forward function returns a scalar # per batch, as either a float, integer, 0d tensor or 1d tensor. @@ -551,7 +678,7 @@ def _shapley_test_assert_future( ) if test_true_shapley: shapley_val = ShapleyValues(model) - attributions = shapley_val.attribute( + attributions = shapley_val.attribute_future( test_input, target=target, feature_mask=feature_mask, @@ -560,8 +687,9 @@ def _shapley_test_assert_future( perturbations_per_eval=batch_size, show_progress=show_progress, ) + attributions.wait() assertTensorTuplesAlmostEqual( - self, attributions, expected_attr, mode="max", delta=0.001 + self, attributions.value(), expected_attr, mode="max", delta=0.001 )