From f258fc87bc9f15a63093c6a7cc2dca5fddb135e0 Mon Sep 17 00:00:00 2001 From: Jason Cho Date: Wed, 22 Jan 2025 08:58:09 -0800 Subject: [PATCH] Adding async future functionality to ShapleyValues (#1487) Summary: This diff implements the attribute_future method for the ShapleyValueSampling class. Reviewed By: cyrjano Differential Revision: D68158802 --- captum/_utils/exceptions.py | 8 + captum/attr/_core/shapley_value.py | 352 ++++++++++++++++++++++++- captum/testing/helpers/basic_models.py | 56 ++++ tests/attr/test_shapley.py | 68 ++++- 4 files changed, 467 insertions(+), 17 deletions(-) diff --git a/captum/_utils/exceptions.py b/captum/_utils/exceptions.py index b952d37406..f548ba2075 100644 --- a/captum/_utils/exceptions.py +++ b/captum/_utils/exceptions.py @@ -9,3 +9,11 @@ class FeatureAblationFutureError(Exception): FeatureAblation attribution call""" pass + + +class ShapleyValueFutureError(Exception): + """This custom error is raised when an error + occurs within the callback chain of a + ShapleyValue attribution call""" + + pass diff --git a/captum/attr/_core/shapley_value.py b/captum/attr/_core/shapley_value.py index 83f1811aed..ca7f6f7e98 100644 --- a/captum/attr/_core/shapley_value.py +++ b/captum/attr/_core/shapley_value.py @@ -5,7 +5,7 @@ import itertools import math import warnings -from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union +from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union import torch from captum._utils.common import ( @@ -20,6 +20,7 @@ _is_tuple, _run_forward, ) +from captum._utils.exceptions import ShapleyValueFutureError from captum._utils.progress import progress from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import PerturbationAttribution @@ -29,7 +30,8 @@ _tensorize_baseline, ) from captum.log import log_usage -from torch import dtype, Tensor +from torch import dtype, Size, Tensor +from torch.futures 
import collect_all, Future def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]: @@ -394,7 +396,6 @@ def attribute( ) if show_progress: attr_progress.update() - if agg_output_mode: eval_diff = modified_eval - prev_results prev_results = modified_eval @@ -438,7 +439,6 @@ def attribute( # (*output_shape, *input_feature_shape) total_attrib[j] += cur_attr - if show_progress: attr_progress.close() @@ -452,15 +452,318 @@ def attribute( # `Tuple[Tensor, ...]`. return formatted_attr - # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. - def attribute_future(self) -> Callable: + def attribute_future( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Optional[Tuple[object, ...]] = None, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + n_samples: int = 25, + perturbations_per_eval: int = 1, + show_progress: bool = False, + ) -> Future[TensorOrTupleOfTensorsGeneric]: r""" This method is not implemented for ShapleyValueSampling. """ - raise NotImplementedError( - "attribute_future is not implemented for ShapleyValueSampling" + is_inputs_tuple = _is_tuple(inputs) + inputs_tuple, baselines = _format_input_baseline(inputs, baselines) + additional_forward_args = _format_additional_forward_args( + additional_forward_args + ) + formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple) + reshaped_feature_mask = _shape_feature_mask( + formatted_feature_mask, inputs_tuple ) + assert ( + isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1 + ), "Ablations per evaluation must be at least 1." 
+ + with torch.no_grad(): + baselines = _tensorize_baseline(inputs_tuple, baselines) + num_examples = inputs_tuple[0].shape[0] + + total_features = _get_max_feature_index(reshaped_feature_mask) + 1 + + if show_progress: + attr_progress = progress( + desc=f"{self.get_name()} attribution", + total=self._get_n_evaluations( + total_features, n_samples, perturbations_per_eval + ) + + 1, # add 1 for the initial eval + ) + attr_progress.update(0) + + initial_eval: Future[Tensor] = self._strict_run_forward_future( + self.forward_func, baselines, target, additional_forward_args + ) + + if show_progress: + attr_progress.update() + + prev_result_tuple: Future[ + Tuple[Tensor, Tensor, Size, List[Tensor], bool] + ] = initial_eval.then( + lambda inp=initial_eval: self._initialEvalToPrevResultsTuple( # type: ignore # noqa: E501 line too long + inp, + num_examples, + perturbations_per_eval, + reshaped_feature_mask, + inputs_tuple, + ) + ) + + iter_count = 0 + # Iterate for number of samples, generate a permutation of the features + # and evaluate the incremental increase for each feature. 
+ for feature_permutation in self.permutation_generator( + total_features, n_samples + ): + prev_result_tuple = prev_result_tuple.then( + lambda inp=prev_result_tuple: self._setPrevResultsToInitialEval(inp) # type: ignore # noqa: E501 line too long + ) + + iter_count += 1 + for ( + current_inputs, + current_add_args, + current_target, + current_masks, + ) in self._perturbation_generator( + inputs_tuple, + additional_forward_args, + target, + baselines, + reshaped_feature_mask, + feature_permutation, + perturbations_per_eval, + ): + if sum(torch.sum(mask).item() for mask in current_masks) == 0: + warnings.warn( + "Feature mask is missing some integers between 0 and " + "num_features, for optimal performance, make sure each" + " consecutive integer corresponds to a feature.", + stacklevel=1, + ) + # modified_eval dimensions: 1D tensor with length + # equal to #num_examples * #features in batch + modified_eval = self._strict_run_forward_future( + self.forward_func, + current_inputs, + current_target, + current_add_args, + ) + if show_progress: + attr_progress.update() + + assert isinstance(modified_eval, torch.Future), ( + "when using futures method, modified_eval should have " + f"Future type rather than {type(modified_eval)}" + ) + eval_futs: Future[ + List[ + Future[ + Union[ + Tuple[Tensor, Tensor, Size, List[Tensor], bool], + Tensor, + ] + ] + ] + ] = collect_all([prev_result_tuple, modified_eval]) + + prev_result_tuple = eval_futs.then( + lambda evals=eval_futs, masks=current_masks: self._evalFutToPrevResultsTuple( # type: ignore # noqa: E501 line too long + evals, num_examples, inputs_tuple, masks + ) + ) + + if show_progress: + attr_progress.close() + + # Divide total attributions by number of random permutations and return + # formatted attributions. 
+ formatted_attr: Future[Union[Tensor, tuple[Tensor, ...]]] = ( + prev_result_tuple.then( + lambda inp=prev_result_tuple: self._prevResultTupleToFormattedAttr( # type: ignore # noqa: E501 line too long + inp, iter_count, is_inputs_tuple + ) + ) + ) + # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got + # `Tuple[Tensor, ...]`. + return formatted_attr # type: ignore + + def _initialEvalToPrevResultsTuple( + self, + initial_eval: Future[Tensor], + num_examples: int, + perturbations_per_eval: int, + reshaped_feature_mask: TensorOrTupleOfTensorsGeneric, + inputs_tuple: Tuple[Tensor, ...], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """Since the initial eval is a Future, it is easier to bundle the prev_result, + agg_output_mode, output_shape, and total_attrib together + as Shapley Value Feature Attributions are being calculated""" + try: + initial_eval_processed = initial_eval.value() + prev_result = initial_eval_processed + if not isinstance(initial_eval_processed, Tensor): + raise AssertionError( + "initial_eval_to_processed_initial_eval_fut: " + "initial_eval should be a Tensor" + ) + agg_output_mode = _find_output_mode_and_verify( + initial_eval_processed, + num_examples, + perturbations_per_eval, + reshaped_feature_mask, + allow_multi_outputs=True, + ) + output_shape = initial_eval_processed.shape + total_attrib: List[Tensor] = [ + torch.zeros( + tuple(output_shape) + tuple(input.shape[1:]), + dtype=torch.float, + device=inputs_tuple[0].device, + ) + for input in inputs_tuple + ] + result = ( + initial_eval_processed, + prev_result, + output_shape, + total_attrib, + agg_output_mode, + ) + except ShapleyValueFutureError as e: + raise ShapleyValueFutureError( + "_initial_eval_to_prev_results_tuple func failed" + ) from e + return result + + def _setPrevResultsToInitialEval( + self, + processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """At the beginning 
of each feature permutation, the prev_results is + reset to the initial eval, and this method helps set that up""" + (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) = ( + processed_initial_eval.value() + ) + prev_results = initial_eval + return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) + + def _evalFutToPrevResultsTuple( + self, + eval_futs: Future[ + List[ + Union[ + Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]], + Future[Tensor], + ] + ] + ], + num_examples: int, + inputs_tuple: Tuple[Tensor, ...], + current_masks: Tuple[Tensor, ...], + ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]: + """Helper method responsible for calculating + eval differences between the modified eval and prev_results + Tensor and storing them in total_attrib. Returns prev_results_tuple + with modified total_attrib and prev_results""" + prev_results_tuple = eval_futs.value()[0].value() + modified_eval = eval_futs.value()[1].value() + if not isinstance(modified_eval, Tensor) or not isinstance( + prev_results_tuple, tuple + ): + raise ShapleyValueFutureError( + "_eval_fut_to_prev_results_tuple func failed due to type mismatch" + ) + ( + initial_eval, + prev_results, + output_shape, + total_attrib, + agg_output_mode, + ) = prev_results_tuple + if agg_output_mode: + eval_diff = modified_eval - prev_results + prev_results = modified_eval + else: + # when perturb_per_eval > 1, every num_examples stands for + # one perturb. Since the perturbs are from a consecutive + # permutation, each diff of a perturb is its eval minus + # the eval of the previous perturb + + all_eval = torch.cat((prev_results, modified_eval), dim=0) + eval_diff = all_eval[num_examples:] - all_eval[:-num_examples] + prev_results = all_eval[-num_examples:] + + for j in range(len(total_attrib)): + # format eval_diff to shape + # (n_perturb, *output_shape, 1,.. 
1) + # where n_perturb may not be perturb_per_eval + # Append n_input_feature dim of 1 to make the tensor + # have the same dim as the mask tensor. + formatted_eval_diff = eval_diff.reshape( + (-1,) + tuple(output_shape) + (len(inputs_tuple[j].shape) - 1) * (1,) + ) + + # mask in shape (n_perturb, *mask_shape_broadcastable_to_input) + # reshape to + # ( + # n_perturb, + # *broadcastable_to_output_shape + # *broadcastable_to_input_feature_shape + # ) + cur_mask = current_masks[j] + cur_mask = cur_mask.reshape( + tuple(cur_mask.shape[:2]) + + (len(output_shape) - 1) * (1,) + + tuple(cur_mask.shape[2:]) + ) + + # aggregate n_perturb + cur_attr = (formatted_eval_diff * cur_mask.float()).sum(dim=0) + # (*output_shape, *input_feature_shape) + total_attrib[j] += cur_attr + + result = ( + initial_eval, + prev_results, + output_shape, + total_attrib, + agg_output_mode, + ) + return result + + def _prevResultTupleToFormattedAttr( + self, + prev_result_tuple: Future[ + Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool] + ], + iter_count: int, + is_inputs_tuple: bool, + ) -> Union[Tensor, Tuple[Tensor, ...]]: + """Helper method to format total_attrib, which is a + list of tensors, into formatted attributions, which + are either a single tensor or a tuple of tensors""" + + ( + _, + _, + _, + total_attrib, + _, + ) = prev_result_tuple.value() + attrib = tuple( + tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib + ) + formatted_attr = _format_output(is_inputs_tuple, attrib) + return formatted_attr + def _perturbation_generator( self, inputs: Tuple[Tensor, ...], @@ -574,6 +877,39 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor: # ref: https://github.com/pytorch/pytorch/pull/21215 return torch.tensor([forward_output], dtype=cast(dtype, output_type)) + # pyre-fixme[2]: Parameter must be annotated. 
+ def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]: + """ + A temp wrapper for global _run_forward util to force + forward outputtype assertion & conversion, but takes + into account the Future tensor type + """ + + def process_strict_run_forward(fut: Future[Tensor]) -> Tensor: + output = fut.value() + if isinstance(output, Tensor): + # format scalar to shape (1) so we can always + # assume non-empty output_shape + if not output.shape: + output = output.reshape(1) + return output + output_type = type(output) + assert output_type is int or output_type is float, ( + "the return of forward_func must be a Future of tensor, int, or float," + f" received: {output_type}" + ) + output = torch.tensor([output], dtype=cast(dtype, output_type)) + return output + + forward_output = _run_forward(*args, **kwargs) + assert isinstance(forward_output, torch.Future), ( + "The return type of forward_func must be a Future" + f" received: {type(forward_output)}" + ) + + return_output = forward_output.then(process_strict_run_forward) + return return_output + class ShapleyValues(ShapleyValueSampling): """ diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 8c4685f752..8d22adafc3 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.futures import Future """ @no_type_check annotation is applied to type-hinted models to avoid errors @@ -477,6 +478,61 @@ def forward( return lin2_out +class BasicModel_MultiLayer_with_Future(nn.Module): + # This model is used to test the case where the model returns a future + def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None: + super().__init__() + # Linear 0 is simply identity transform + self.multi_input_module = multi_input_module + self.linear0 = nn.Linear(3, 3) + self.linear0.weight = 
nn.Parameter(torch.eye(3)) + self.linear0.bias = nn.Parameter(torch.zeros(3)) + self.linear1 = nn.Linear(3, 4) + self.linear1.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + + self.linear1_alt = nn.Linear(3, 4) + self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3)) + self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0])) + self.multi_relu = MultiRelu(inplace=inplace) + self.relu = nn.ReLU(inplace=inplace) + + self.linear2 = nn.Linear(4, 2) + self.linear2.weight = nn.Parameter(torch.ones(2, 4)) + self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0])) + + @no_type_check + # pyre-fixme[3]: Return type must be annotated. + def forward( + self, + x: Tensor, + add_input: Optional[Tensor] = None, + multidim_output: bool = False, + ): + input = x if add_input is None else x + add_input + lin0_out = self.linear0(input) + lin1_out = self.linear1(lin0_out) + if self.multi_input_module: + relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input)) + relu_out = relu_out1 + relu_out2 + # relu is not used when multi_input_module set to True, + # so this is to set an unused layer intentionally for testing + # and it won't be part of return + self.relu(lin1_out) + else: + relu_out = self.relu(lin1_out) + # pyre-fixme [29]: `typing.Type[Future]` is not a function + result = Future() + lin2_out = self.linear2(relu_out) + if multidim_output: + stack_mid = torch.stack((lin2_out, 2 * lin2_out), dim=2) + result.set_result(torch.stack((stack_mid, 4 * stack_mid), dim=3)) + return result + else: + result.set_result(lin2_out) + return result + + class BasicModelBoolInput(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/test_shapley.py b/tests/attr/test_shapley.py index 976adc55f2..b88b478b60 100644 --- a/tests/attr/test_shapley.py +++ b/tests/attr/test_shapley.py @@ -14,6 +14,7 @@ from captum.testing.helpers.basic_models import ( 
BasicModel_MultiLayer, BasicModel_MultiLayer_MultiInput, + BasicModel_MultiLayer_with_Future, BasicModelBoolInput, ) @@ -30,6 +31,17 @@ def test_simple_shapley_sampling(self) -> None: n_samples=250, ) + def test_simple_shapley_sampling_future(self) -> None: + net = BasicModel_MultiLayer_with_Future() + inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) + self._shapley_test_assert_future( + net, + inp, + [[76.66666, 196.66666, 116.66666]], + perturbations_per_eval=(1, 2, 3), + n_samples=250, + ) + def test_simple_shapley_sampling_with_mask(self) -> None: net = BasicModel_MultiLayer() inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True) @@ -388,15 +400,6 @@ def test_shapley_sampling_with_mask_and_show_progress(self, mock_stderr) -> None mock_stderr.seek(0) mock_stderr.truncate(0) - def test_futures_not_implemented(self) -> None: - net = BasicModel_MultiLayer() - - attributions = None - shapley_samp = ShapleyValueSampling(net) - with self.assertRaises(NotImplementedError): - attributions = shapley_samp.attribute_future() - self.assertEqual(attributions, None) - def _single_input_one_sample_batch_scalar_shapley_assert( self, func: Callable ) -> None: @@ -514,6 +517,53 @@ def _shapley_test_assert( self, attributions, expected_attr, mode="max", delta=0.001 ) + def _shapley_test_assert_future( + self, + model: Callable, + test_input: TensorOrTupleOfTensorsGeneric, + expected_attr, + feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, + additional_input: Any = None, + perturbations_per_eval: Tuple[int, ...] 
= (1,), + baselines: BaselineType = None, + target: Union[None, int] = 0, + n_samples: int = 100, + delta: float = 1.0, + # leaving this false as it is not supported for future + test_true_shapley: bool = False, + show_progress: bool = False, + ) -> None: + for batch_size in perturbations_per_eval: + shapley_samp = ShapleyValueSampling(model) + attributions = shapley_samp.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + n_samples=n_samples, + show_progress=show_progress, + ) + attributions.wait() + assertTensorTuplesAlmostEqual( + self, attributions.value(), expected_attr, delta=delta, mode="max" + ) + if test_true_shapley: + shapley_val = ShapleyValues(model) + attributions = shapley_val.attribute( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + show_progress=show_progress, + ) + assertTensorTuplesAlmostEqual( + self, attributions, expected_attr, mode="max", delta=0.001 + ) + if __name__ == "__main__": unittest.main()