Skip to content

Commit

Permalink
pins emptyset value in CHII of order 0 to the baseline value
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Dec 20, 2024
1 parent ccc0b99 commit 6a7eb7f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
9 changes: 9 additions & 0 deletions shapiq/game_theory/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
like interaction indices or generalized values."""

import copy
import warnings
from typing import Callable, Optional, Union

import numpy as np
Expand Down Expand Up @@ -405,6 +406,14 @@ def base_interaction(self, index: str, order: int) -> InteractionValues:
for i, interaction in enumerate(powerset(self._grand_coalition_set, max_size=order)):
interaction_lookup[interaction] = i

# CHII is un-defined for empty set
if index == "CHII" and () in interaction_lookup:
warnings.warn(
f"CHII is not defined for the empty set. Setting to the baseline value "
f"{self.baseline_value}."
)
base_interaction_values[interaction_lookup[()]] = self.baseline_value

# Transform into InteractionValues object and store in computed dictionary
base_interaction = InteractionValues(
values=base_interaction_values,
Expand Down
14 changes: 13 additions & 1 deletion tests/tests_game_theory/test_exact_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _interaction(arr: np.ndarray): # dtype bool
("BV", 1),
("SII", 2),
("BII", 2),
# ("CHII", 2), # TODO: fix this
("CHII", 2),
("Co-Moebius", 2),
("SGV", 2),
("BGV", 2),
Expand Down Expand Up @@ -239,6 +239,18 @@ def permutation_game(X: np.ndarray):
assert (value - perm_interaction_values[perm_coalition]) < 10e-7


def test_warning_cii():
"""Checks weather a warning is raised for the CHII index and min_order = 0."""
n = 5
soum = SOUM(n, n_basis_games=10)
exact_computer = ExactComputer(n_players=n, game_fun=soum)
with pytest.warns(UserWarning):
exact_computer("CHII", 0)

# check that warning is not raised for min_order > 0
exact_computer("CHII", 1)


@pytest.mark.parametrize(
"index, order",
[
Expand Down

0 comments on commit 6a7eb7f

Please sign in to comment.