Skip to content

Commit

Permalink
Add test coverage for futures in ShapleyValueSampling 2/n (#1490)
Browse files Browse the repository at this point in the history
Summary:

This diffs adds more testing coverage for attribute_future to ShapleyValueSampling unit tests that handle shapley sampling with boolean inputs

Reviewed By: cyrjano

Differential Revision: D68230069
  • Loading branch information
jjuncho authored and facebook-github-bot committed Jan 22, 2025
1 parent 4ef915b commit 0484226
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 19 deletions.
16 changes: 16 additions & 0 deletions captum/testing/helpers/basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,22 @@ def forward(
return result


class BasicModelBoolInput_with_Future(nn.Module):
def __init__(self) -> None:
super().__init__()
self.mod = BasicModel_MultiLayer_with_Future()

# pyre-fixme[3]: Return type must be annotated.
def forward(
self,
x: Tensor,
add_input: Optional[Tensor] = None,
mult: float = 10.0,
):
assert x.dtype is torch.bool, "Input must be boolean"
return self.mod(x.float() * mult, add_input)


class BasicModelBoolInput(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand Down
62 changes: 43 additions & 19 deletions tests/attr/test_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BasicModel_MultiLayer_MultiInput,
BasicModel_MultiLayer_with_Future,
BasicModelBoolInput,
BasicModelBoolInput_with_Future,
)
from parameterized import parameterized
from torch.futures import Future
Expand Down Expand Up @@ -66,28 +67,51 @@ def test_simple_shapley_sampling_with_mask(self, use_future) -> None:
perturbations_per_eval=(1, 2, 3),
)

def test_simple_shapley_sampling_boolean(self) -> None:
net = BasicModelBoolInput()
@parameterized.expand([True, False])
def test_simple_shapley_sampling_boolean(self, use_future) -> None:
inp = torch.tensor([[True, False, True]])
self._shapley_test_assert(
net,
inp,
[[35.0, 35.0, 35.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
perturbations_per_eval=(1, 2, 3),
)
if use_future:
net_fut = BasicModelBoolInput_with_Future()
self._shapley_test_assert_future(
net_fut,
inp,
[[35.0, 35.0, 35.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
perturbations_per_eval=(1, 2, 3),
)
else:
net = BasicModelBoolInput()
self._shapley_test_assert(
net,
inp,
[[35.0, 35.0, 35.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
perturbations_per_eval=(1, 2, 3),
)

def test_simple_shapley_sampling_boolean_with_baseline(self) -> None:
net = BasicModelBoolInput()
@parameterized.expand([True, False])
def test_simple_shapley_sampling_boolean_with_baseline(self, use_future) -> None:
inp = torch.tensor([[True, False, True]])
self._shapley_test_assert(
net,
inp,
[[-40.0, -40.0, 0.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
baselines=True,
perturbations_per_eval=(1, 2, 3),
)
if use_future:
net_fut = BasicModelBoolInput_with_Future()
self._shapley_test_assert_future(
net_fut,
inp,
[[-40.0, -40.0, 0.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
baselines=True,
perturbations_per_eval=(1, 2, 3),
)
else:
net = BasicModelBoolInput()
self._shapley_test_assert(
net,
inp,
[[-40.0, -40.0, 0.0]],
feature_mask=torch.tensor([[0, 0, 1]]),
baselines=True,
perturbations_per_eval=(1, 2, 3),
)

@parameterized.expand([True, False])
def test_simple_shapley_sampling_with_baselines(self, use_future) -> None:
Expand Down

0 comments on commit 0484226

Please sign in to comment.