Skip to content

Commit

Permalink
added probabilistic values and fixed SV
Browse files Browse the repository at this point in the history
  • Loading branch information
FFmgll committed Apr 10, 2024
1 parent bdca127 commit 6ca29f0
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 65 deletions.
119 changes: 56 additions & 63 deletions shapiq/exact_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ExactComputer:
n: The number of players.
big_M: The infinite weight for KernelSHAP
n_interactions: A pre-computed numpy array containing the number of interactions up to the size of the index, e.g. n_interactions[4] is the number of all interactions up to size 4
computed_interactions: A dictionary that stores computations of different indices
computed: A dictionary that stores computations of different indices
game_fun: The callable game
baseline_value: The baseline value, i.e. the emptyset prediction
game_values: A numpy array containing the game evaluations of all subsets
Expand All @@ -37,7 +37,7 @@ def __init__(
self.n = len(N)
self.big_M = 10e7
self.n_interactions = self.get_n_interactions()
self.computed_interactions = {}
self.computed = {}
self.game_fun = game_fun
self.baseline_value, self.game_values, self.coalition_lookup = self.compute_game_values(
game_fun
Expand Down Expand Up @@ -78,16 +78,18 @@ def moebius_transform(self):
moebius_transform = np.zeros(2**self.n)
# compute the Moebius transform
coalition_lookup = {}
for i, S in enumerate(powerset(self.N)):
coalition_lookup[S] = i
for i, S in enumerate(powerset(self.N)):
s = len(S)
S_pos = coalition_lookup[S]
for T in powerset(S):
pos = self.coalition_lookup[T]
moebius_transform[S_pos] += (-1) ** (s - len(T)) * self.game_values[pos]

self.computed_interactions["Moebius"] = InteractionValues(
for interaction_pos, interaction in enumerate(powerset(self.N)):
coalition_lookup[interaction] = interaction_pos
for interaction in powerset(self.N):
interaction_size = len(interaction)
interaction_pos = coalition_lookup[interaction]
for coalition in powerset(interaction):
coalition_pos = self.coalition_lookup[coalition]
moebius_transform[interaction_pos] += (-1) ** (
interaction_size - len(coalition)
) * self.game_values[coalition_pos]

self.computed["Moebius"] = InteractionValues(
values=moebius_transform,
index="Moebius",
max_order=self.n,
Expand All @@ -96,7 +98,7 @@ def moebius_transform(self):
interaction_lookup=coalition_lookup,
estimated=False,
)
return copy.copy(self.computed_interactions["Moebius"])
return copy.copy(self.computed["Moebius"])

def base_weights(self, coalition_size: int, interaction_size: int, index: str):
"""Computes the weight of different indices in their common representation,
Expand Down Expand Up @@ -231,7 +233,7 @@ def get_base_weights(self, index: str, order: int):
)
return base_weights

def base_interactions(self, index: str, order: int):
def base_interaction(self, index: str, order: int):
"""Computes interactions based on representation with discrete derivatives, e.g. SII, BII
Args:
Expand Down Expand Up @@ -261,7 +263,7 @@ def base_interactions(self, index: str, order: int):
interaction_lookup[interaction] = i

# Transform into InteractionValues object
self.computed_interactions[index] = InteractionValues(
self.computed[index] = InteractionValues(
values=base_interaction_values,
index=index,
max_order=order,
Expand All @@ -272,17 +274,23 @@ def base_interactions(self, index: str, order: int):
baseline_value=self.baseline_value,
)

return copy.copy(self.computed_interactions[index])

def base_generalized_values(self, index: str, order: int):
"""Computes the Base Generalized Values according to the representation with marginal contributions, e.g. SGV, BGV, CGV
return copy.copy(self.computed[index])

def base_generalized_value(self, index: str, order: int):
"""
Computes Base Generalized Values, i.e. probabilistic generalized values that do not depend on the order
According to the underlying representation using marginal contributions from https://doi.org/10.1016/j.dam.2006.05.002
Currently covers:
- SGV: Shapley Generalized Value https://doi.org/10.1016/S0166-218X(00)00264-X
- BGV: Banzhaf Generalized Value https://doi.org/10.1016/S0166-218X(00)00264-X
- CHGV: Chaining Generalized Value https://doi.org/10.1016/j.dam.2006.05.002
Args:
index: The interaction index
order: The interaction order
order: The highest order of interactions
index: The generalized value index
Returns:
An InteractionValues object containing the base generalized values
An InteractionValues object containing generalized values
"""

base_generalized_values = np.zeros(self.n_interactions[order])
Expand All @@ -306,7 +314,7 @@ def base_generalized_values(self, index: str, order: int):
)

# Transform into InteractionValues object
self.computed_interactions[index] = InteractionValues(
self.computed[index] = InteractionValues(
values=base_generalized_values,
index=index,
max_order=order,
Expand All @@ -316,7 +324,7 @@ def base_generalized_values(self, index: str, order: int):
estimated=False,
)

return self.computed_interactions[index].__copy__()
return copy.copy(self.computed[index])

def base_aggregation(self, base_interactions: InteractionValues, order: int):
"""Transform Base Interactions into Interactions satisfying efficiency, e.g. SII to k-SII
Expand Down Expand Up @@ -615,7 +623,7 @@ def compute_jointSV(self, order):
)
return jointSV

def shapley_generalized_values(self, order: int, index: str) -> InteractionValues:
def shapley_generalized_value(self, order: int, index: str) -> InteractionValues:
"""
Computes Shapley Generalized Values, i.e. Generalized Values that satisfy efficiency
According to the underlying representation in https://doi.org/10.1016/j.dam.2006.05.002
Expand All @@ -632,29 +640,7 @@ def shapley_generalized_values(self, order: int, index: str) -> InteractionValue
if index == "JointSV":
shapley_generalized_value = self.compute_jointSV(order)

self.computed_interactions[index] = shapley_generalized_value
return copy.copy(shapley_generalized_value)

def shapleygeneralized_values(self, order: int, index: str) -> InteractionValues:
"""
Computes Shapley Generalized Values, i.e. probabilistic generalized values that do not depend on the order
According to the underlying representation using marginal contributions from https://doi.org/10.1016/j.dam.2006.05.002
Currently covers:
- SGV: Shapley Generalized Value https://doi.org/10.1016/S0166-218X(00)00264-X
- BGV: Banzhaf Generalized Value https://doi.org/10.1016/S0166-218X(00)00264-X
- CHGV: Chaining Generalized Value https://doi.org/10.1016/j.dam.2006.05.002
Args:
order: The highest order of interactions
index: The generalized value index
Returns:
An InteractionValues object containing generalized values
"""
if index == "JointSV":
shapley_generalized_value = self.compute_jointSV(order)

self.computed_interactions[index] = shapley_generalized_value
self.computed[index] = shapley_generalized_value
return copy.copy(shapley_generalized_value)

def shapley_interaction(self, order: int, index: str = "k-SII") -> InteractionValues:
Expand All @@ -676,19 +662,19 @@ def shapley_interaction(self, order: int, index: str = "k-SII") -> InteractionVa
"""
if index == "k-SII":
sii = self.base_interactions("SII", order)
self.computed_interactions["SII"] = sii
shapley_interactions = self.base_aggregation(sii, order)
sii = self.base_interaction("SII", order)
self.computed["SII"] = sii
shapley_interaction = self.base_aggregation(sii, order)
if index == "STII":
shapley_interactions = self.compute_stii(order)
shapley_interaction = self.compute_stii(order)
if index == "FSII":
shapley_interactions = self.compute_fsii(order)
shapley_interaction = self.compute_fsii(order)
if index == "kADD-SHAP":
shapley_interactions = self.compute_kadd_shap(order)
shapley_interaction = self.compute_kadd_shap(order)

self.computed_interactions[index] = shapley_interactions
self.computed[index] = shapley_interaction

return copy.copy(shapley_interactions)
return copy.copy(shapley_interaction)

def shapley_base_interaction(self, order: int, index: str) -> InteractionValues:
"""
Expand All @@ -707,11 +693,11 @@ def shapley_base_interaction(self, order: int, index: str) -> InteractionValues:
An InteractionValues object containing interaction values
"""
base_interactions = self.base_interactions(index, order)
self.computed_interactions[index] = base_interactions
return copy.copy(base_interactions)
base_interaction = self.base_interaction(index, order)
self.computed[index] = base_interaction
return copy.copy(base_interaction)

def probabilistic_values(self, index: str) -> InteractionValues:
def probabilistic_value(self, index: str) -> InteractionValues:
"""Computes common semi-values or probabilistic values, i.e. shapley values without efficiency axiom. These are special of interaction indices and generalized values for order = 1.
According to the underlying representation using marginal contributions, cf.
- semi-values https://doi.org/10.1287/moor.6.1.122
Expand All @@ -727,7 +713,14 @@ def probabilistic_values(self, index: str) -> InteractionValues:
Returns:
An InteractionValues object containing probabilistic values
"""

probabilistic_value = self.base_interactions(index, 1)
self.computed_interactions[index] = probabilistic_value
if index == "BV":
probabilistic_value = self.base_interaction(index="BII", order=1)
if index == "SV":
probabilistic_value = self.base_interaction(index="SII", order=1)
# Change emptyset value of SII to baseline value
probabilistic_value.baseline_value = self.baseline_value
probabilistic_value.values[
probabilistic_value.interaction_lookup[tuple()]
] = self.baseline_value
self.computed[index] = probabilistic_value
return copy.copy(probabilistic_value)
20 changes: 19 additions & 1 deletion shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,25 @@

from shapiq.utils import generate_interaction_lookup, powerset

AVAILABLE_INDICES = {"k-SII", "SII", "STI", "FSI", "STII", "FSII", "SV", "BV", "BZF", "Moebius"}
AVAILABLE_INDICES = {
"JointSV",
"SGV",
"BGV",
"CHGV",
"CHII",
"BII",
"kADD-SHAP",
"k-SII",
"SII",
"STI",
"FSI",
"STII",
"FSII",
"SV",
"BV",
"BZF",
"Moebius",
}


@dataclass
Expand Down
39 changes: 38 additions & 1 deletion tests/test_exact_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def test_exact_computer_on_soum():
for i in range(100):
n = np.random.randint(low=2, high=12)
n = np.random.randint(low=2, high=10)
N = set(range(n))
order = np.random.randint(low=1, high=min(n, 5))
n_basis_games = np.random.randint(low=1, high=100)
Expand All @@ -21,6 +21,10 @@ def test_exact_computer_on_soum():
# Compute via sparse Möbius representation
moebius_converter = MoebiusConverter(N, soum.moebius_coefficients)

moebius_transform = exact_computer.moebius_transform()
# Assert equality with ground truth Möbius coefficients from SOUM
assert np.sum((moebius_transform - soum.moebius_coefficients).values ** 2) < 10e-7

# Compare ground truth via MoebiusConverter with exact computation of ExactComputer
shapley_interactions_gt = {}
shapley_interactions_exact = {}
Expand All @@ -29,9 +33,42 @@ def test_exact_computer_on_soum():
order, index
)
shapley_interactions_exact[index] = exact_computer.shapley_interaction(order, index)
# Check equality with ground truth calculations from SOUM
assert (
np.sum(
(shapley_interactions_exact[index] - shapley_interactions_gt[index]).values ** 2
)
< 10e-7
)

index = "JointSV"
shapley_generalized_values = exact_computer.shapley_generalized_value(
order=order, index=index
)
# Assert efficiency
assert (np.sum(shapley_generalized_values.values) - predicted_value) ** 2 < 10e-7

index = "kADD-SHAP"
shapley_interactions_exact[index] = exact_computer.shapley_interaction(order, index)

base_interaction_indices = ["SII", "BII", "CHII"]
base_interactions = {}
for base_index in base_interaction_indices:
base_interactions[base_index] = exact_computer.shapley_base_interaction(
order=order, index=base_index
)

base_gv_indices = ["SGV", "BGV", "CHGV"]
base_gv = {}
for base_gv_index in base_gv_indices:
base_gv[base_gv_index] = exact_computer.base_generalized_value(
order=order, index=base_gv_index
)

probabilistic_values_indices = ["SV", "BV"]
probabilistic_values = {}
for pv_index in probabilistic_values_indices:
probabilistic_values[pv_index] = exact_computer.probabilistic_value(index=pv_index)

# Assert efficiency for SV
assert (np.sum(probabilistic_values["SV"].values) - predicted_value) ** 2 < 10e-7

0 comments on commit 6ca29f0

Please sign in to comment.