diff --git a/.coveragerc b/.coveragerc index 159758d3..aad397a2 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,3 @@ [run] source = shapiq -omit = *tests* *venv* \ No newline at end of file +omit = *tests* *venv* *docs* *examples* \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe3f42e8..1fb01a29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,4 +21,4 @@ repos: name: ruff language: python types: [python] - entry: ruff + entry: ruff --fix diff --git a/docs/source/_static/logo_shapiq_light.png b/docs/source/_static/logo_shapiq_light.png new file mode 100644 index 00000000..e69de29b diff --git a/shapiq/__init__.py b/shapiq/__init__.py index 38adc37a..2369cf86 100644 --- a/shapiq/__init__.py +++ b/shapiq/__init__.py @@ -4,7 +4,13 @@ from __version__ import __version__ # approximator classes -from .approximator import PermutationSamplingSII, PermutationSamplingSTI, RegressionFSI, ShapIQ +from .approximator import ( + PermutationSamplingSII, + PermutationSamplingSTI, + RegressionSII, + RegressionFSI, + ShapIQ, +) # explainer classes from .explainer import Explainer @@ -31,6 +37,7 @@ "ShapIQ", "PermutationSamplingSII", "PermutationSamplingSTI", + "RegressionSII", "RegressionFSI", # explainers "Explainer", diff --git a/shapiq/approximator/__init__.py b/shapiq/approximator/__init__.py index d77b0a2a..dbc73d19 100644 --- a/shapiq/approximator/__init__.py +++ b/shapiq/approximator/__init__.py @@ -1,12 +1,16 @@ """This module contains the approximators to estimate the Shapley interaction values.""" +from ._base import convert_nsii_into_one_dimension, transforms_sii_to_nsii # TODO add to tests from .permutation.sii import PermutationSamplingSII from .permutation.sti import PermutationSamplingSTI -from .regression import RegressionFSI +from .regression import RegressionSII, RegressionFSI from .shapiq import ShapIQ __all__ = [ "PermutationSamplingSII", "PermutationSamplingSTI", "RegressionFSI", + "RegressionSII", "ShapIQ", + "transforms_sii_to_nsii", + "convert_nsii_into_one_dimension", ] diff --git a/shapiq/approximator/_base.py b/shapiq/approximator/_base.py index d6fc45c4..66d1559d 100644 --- a/shapiq/approximator/_base.py +++ b/shapiq/approximator/_base.py @@ -2,15 +2,25 @@ import copy from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, Optional, Union import numpy as np -from scipy.special import binom +from scipy.special import binom, bernoulli from utils import get_explicit_subsets, powerset, split_subsets_budget AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"} +__all__ = [ + "InteractionValues", + "Approximator", + "ShapleySamplingMixin", + "NShapleyMixin", + "transforms_sii_to_nsii", + "convert_nsii_into_one_dimension", +] + + @dataclass class InteractionValues: """This class contains the interaction values as estimated by an approximator. @@ -46,14 +56,9 @@ def __post_init__(self) -> None: f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'." ) if self.interaction_lookup is None: - self.interaction_lookup = { - interaction: i - for i, interaction in enumerate( - powerset( - range(self.n_players), min_size=self.min_order, max_size=self.max_order - ) - ) - } + self.interaction_lookup = _generate_interaction_lookup( + self.n_players, self.min_order, self.max_order + ) def __repr__(self) -> str: """Returns the representation of the InteractionValues object.""" @@ -201,13 +206,10 @@ def __init__( self.top_order: bool = top_order self.max_order: int = max_order self.min_order: int = self.max_order if self.top_order else 1 - self.iteration_cost: Optional[int] = None - self._interaction_lookup: dict[tuple[int], int] = { - interaction: i - for i, interaction in enumerate( - powerset(self.N, min_size=self.min_order, max_size=self.max_order) - ) - } + self.iteration_cost: int = 1 # default value, can be overwritten by subclasses + self._interaction_lookup = _generate_interaction_lookup( + self.n, self.min_order, self.max_order + ) self._random_state: Optional[int] = random_state self._rng: Optional[np.random.Generator] = np.random.default_rng(seed=self._random_state) @@ -253,7 +255,11 @@ def _order_iterator(self) -> range: return range(self.min_order, self.max_order + 1) def _finalize_result( - self, result, estimated: bool = True, budget: Optional[int] = None + self, + result, + estimated: bool = True, + budget: Optional[int] = None, + index: Optional[str] = None, ) -> InteractionValues: """Finalizes the result dictionary. @@ -261,16 +267,19 @@ def _finalize_result( result: The result dictionary. estimated: Whether the interaction values are estimated or not. Defaults to True. budget: The budget used for the estimation. Defaults to None. + index: The interaction index estimated. Available indices are 'SII', 'nSII', 'STI', and + 'FSI'. Defaults to None (i.e., the index of the approximator is used). Returns: The interaction values. """ - # create InteractionValues object + if index is None: + index = self.index return InteractionValues( values=result, estimated=estimated, estimation_budget=budget, - index=self.index, + index=index, min_order=self.min_order, max_order=self.max_order, n_players=self.n, @@ -336,8 +345,12 @@ def __hash__(self) -> int: """Returns the hash of the Approximator object.""" return hash((self.n, self.max_order, self.index, self.top_order, self._random_state)) + @property + def interaction_lookup(self): + return self._interaction_lookup + -class ShapleySamplingMixin: +class ShapleySamplingMixin(ABC): """Mixin class for the computation of Shapley weights. Provides the common functionality for regression-based approximators like @@ -345,7 +358,9 @@ class ShapleySamplingMixin: and the corresponding sampling weights for the KernelSHAP-like estimation approaches. """ - def _init_ksh_sampling_weights(self) -> np.ndarray[float]: + def _init_ksh_sampling_weights( + self: Union[Approximator, "ShapleySamplingMixin"] + ) -> np.ndarray[float]: """Initializes the weights for sampling subsets. The sampling weights are of size n + 1 and indexed by the size of the subset. The edges @@ -354,13 +369,16 @@ def _init_ksh_sampling_weights(self) -> np.ndarray[float]: Returns: The weights for sampling subsets of size s in shape (n + 1,). """ + weight_vector = np.zeros(shape=self.n - 1, dtype=float) for subset_size in range(1, self.n): weight_vector[subset_size - 1] = (self.n - 1) / (subset_size * (self.n - subset_size)) sampling_weight = (np.asarray([0] + [*weight_vector] + [0])) / sum(weight_vector) return sampling_weight - def _get_ksh_subset_weights(self, subsets: np.ndarray[bool]) -> np.ndarray[float]: + def _get_ksh_subset_weights( + self: Union[Approximator, "ShapleySamplingMixin"], subsets: np.ndarray[bool] + ) -> np.ndarray[float]: """Computes the KernelSHAP regression weights for the given subsets. The weights for the subsets of size s are set to ksh_weights[s] / binom(n, s). The weights @@ -385,7 +403,7 @@ def _get_ksh_subset_weights(self, subsets: np.ndarray[bool]) -> np.ndarray[float return weights def _sample_subsets( - self, + self: Union[Approximator, "ShapleySamplingMixin"], budget: int, sampling_weights: np.ndarray[float], replacement: bool = False, @@ -439,7 +457,10 @@ def _sample_subsets( return subset_matrix def _generate_shapley_dataset( - self, budget: int, pairing: bool = True, replacement: bool = False + self: Union[Approximator, "ShapleySamplingMixin"], + budget: int, + pairing: bool = True, + replacement: bool = False, ) -> tuple[np.ndarray[bool], bool, int]: """Generates the two-part dataset containing explicit and sampled subsets. @@ -499,3 +520,179 @@ def _generate_shapley_dataset( ) n_explicit_subsets += 2 # add empty and full set return all_subsets, estimation_flag, n_explicit_subsets + + +class NShapleyMixin: + """Mixin class for the computation of n-Shapley values from SII estimators. + + Provides the common functionality for SII-based approximators like `PermutationSamplingSII` or + `ShapIQ` for SII to transform their interaction scores into nSII values. The nSII values are + proposed in this `paper`_. + """ + + def transforms_sii_to_nsii( + self: Approximator, + sii_values: Union[np.ndarray[float], InteractionValues], + ) -> Union[np.ndarray[float], InteractionValues]: + """Transforms the SII values into nSII values. + + Args: + sii_values: The SII values to transform. Can be either a numpy array or an + InteractionValues object. The output will be of the same type. + + Returns: + The nSII values in the same format as the input. + """ + return transforms_sii_to_nsii( + sii_values=sii_values, + approximator=self, + ) + + +def transforms_sii_to_nsii( + sii_values: Union[np.ndarray[float], InteractionValues], + *, + approximator: Optional[Approximator] = None, + n: Optional[int] = None, + max_order: Optional[int] = None, + interaction_lookup: Optional[dict] = None, +) -> Union[np.ndarray[float], InteractionValues]: + """Transforms the SII values into nSII values. + + Args: + sii_values: The SII values to transform. Can be either a numpy array or an + InteractionValues object. The output will be of the same type. + approximator: The approximator used to estimate the SII values. If provided, meta + information for the transformation is taken from the approximator. Defaults to None. + n: The number of players. Required if `approximator` is not provided. Defaults to None. + max_order: The maximum order of the approximation. Required if `approximator` is not + provided. Defaults to None. + interaction_lookup: A dictionary that maps interactions to their index in the values + vector. If `interaction_lookup` is not provided, it is computed from the `n_players` + and the `max_order` parameters. Defaults to `None`. + + Returns: + The nSII values in the same format as the input. + """ + if isinstance(sii_values, InteractionValues): + n_sii_values = _calculate_nsii_from_sii( + sii_values.values, + sii_values.n_players, + sii_values.max_order, + sii_values.interaction_lookup, + ) + return InteractionValues( + values=n_sii_values, + index="nSII", + max_order=sii_values.max_order, + min_order=sii_values.min_order, + n_players=sii_values.n_players, + interaction_lookup=sii_values.interaction_lookup, + estimated=sii_values.estimated, + estimation_budget=sii_values.estimation_budget, + ) + elif approximator is not None: + return _calculate_nsii_from_sii( + sii_values, approximator.n, approximator.max_order, approximator.interaction_lookup + ) + elif n is not None and max_order is not None: + if interaction_lookup is None: + interaction_lookup = _generate_interaction_lookup(n, 1, max_order) + return _calculate_nsii_from_sii(sii_values, n, max_order, interaction_lookup) + else: + raise ValueError( + "If the SII values are not provided as InteractionValues, the approximator " + "or the number of players and the maximum order of the approximation must be " + "provided." + ) + + +def _calculate_nsii_from_sii( + sii_values: np.ndarray[float], + n: int, + max_order: int, + interaction_lookup: Optional[dict] = None, +) -> np.ndarray[float]: + """Calculates the nSII values from the SII values. + + Args: + sii_values: The SII values to transform. + n: The number of players. + max_order: The maximum order of the approximation. + interaction_lookup: A dictionary that maps interactions to their index in the values + vector. If `interaction_lookup` is not provided, it is computed from the `n_players`, + `min_order`, and `max_order` parameters. Defaults to `None`. + + Returns: + The nSII values. + """ + # compute nSII values from SII values + bernoulli_numbers = bernoulli(max_order) + nsii_values = np.zeros_like(sii_values) + # all subsets S with 1 <= |S| <= max_order + for subset in powerset(set(range(n)), min_size=1, max_size=max_order): + interaction_index = interaction_lookup[subset] + interaction_size = len(subset) + n_sii_value = sii_values[interaction_index] + # go over all subsets T of length |S| + 1, ..., n that contain S + for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order): + if set(subset).issubset(T): + effect_index = interaction_lookup[T] # get the index of T + effect_value = sii_values[effect_index] # get the effect of T + bernoulli_factor = bernoulli_numbers[len(T) - interaction_size] + n_sii_value += bernoulli_factor * effect_value + nsii_values[interaction_index] = n_sii_value + return nsii_values + + +def convert_nsii_into_one_dimension( + n_sii_values: InteractionValues, +) -> tuple[np.ndarray[float], np.ndarray[float]]: + """Converts the nSII values into one-dimensional values. + + Args: + n_sii_values: The nSII values to convert. + + Returns: + The positive and negative one-dimensional values. + """ + if n_sii_values.index != "nSII": + raise ValueError( + "Only nSII values can be converted into one-dimensional nSII values. Please use the " + "transforms_sii_to_nsii method to convert SII values into nSII values." + ) + max_order = n_sii_values.max_order + min_order = n_sii_values.min_order + n = n_sii_values.n_players + + pos_nsii_values = np.zeros(shape=(n,), dtype=float) + neg_nsii_values = np.zeros(shape=(n,), dtype=float) + + for subset in powerset(set(range(n)), min_size=min_order, max_size=max_order): + n_sii_value = n_sii_values[subset] / len(subset) # distribute uniformly + for player in subset: + if n_sii_value >= 0: + pos_nsii_values[player] += n_sii_value + else: + neg_nsii_values[player] += n_sii_value + return pos_nsii_values, neg_nsii_values + + +def _generate_interaction_lookup(n: int, min_order: int, max_order: int) -> dict[tuple[int], int]: + """Generates a lookup dictionary for interactions. + + Args: + n: The number of players. + min_order: The minimum order of the approximation. + max_order: The maximum order of the approximation. + + Returns: + A dictionary that maps interactions to their index in the values vector. + """ + interaction_lookup = { + interaction: i + for i, interaction in enumerate( + powerset(set(range(n)), min_size=min_order, max_size=max_order) + ) + } + return interaction_lookup diff --git a/shapiq/approximator/permutation/sii.py b/shapiq/approximator/permutation/sii.py index d4809519..09b527b7 100644 --- a/shapiq/approximator/permutation/sii.py +++ b/shapiq/approximator/permutation/sii.py @@ -2,11 +2,11 @@ from typing import Callable, Optional import numpy as np -from approximator._base import Approximator, InteractionValues +from approximator._base import Approximator, InteractionValues, NShapleyMixin from utils import powerset -class PermutationSamplingSII(Approximator): +class PermutationSamplingSII(Approximator, NShapleyMixin): """Permutation Sampling approximator for the SII (and nSII) index. Args: @@ -56,10 +56,13 @@ def __init__( self, n: int, max_order: int, + index: str = "SII", top_order: bool = False, random_state: Optional[int] = None, ) -> None: - super().__init__(n, max_order, "SII", top_order, random_state) + if index not in ["SII", "nSII"]: + raise ValueError(f"Invalid index {index}. Must be either 'SII' or 'nSII'.") + super().__init__(n, max_order, index, top_order, random_state) self.iteration_cost: int = self._compute_iteration_cost() def _compute_iteration_cost(self) -> int: @@ -150,4 +153,7 @@ def approximate( # compute mean of interactions result = np.divide(result, counts, out=result, where=counts != 0) + if self.index == "nSII": + result: np.ndarray[float] = self.transforms_sii_to_nsii(result) + return self._finalize_result(result, budget=used_budget, estimated=True) diff --git a/shapiq/approximator/regression/__init__.py b/shapiq/approximator/regression/__init__.py index 9c2c6cbf..203525b2 100644 --- a/shapiq/approximator/regression/__init__.py +++ b/shapiq/approximator/regression/__init__.py @@ -1,5 +1,6 @@ """This module contains the regression-based approximators to estimate Shapley interaction values. """ +from .sii import RegressionSII from .fsi import RegressionFSI -__all__ = ["RegressionFSI"] +__all__ = ["RegressionSII", "RegressionFSI"] diff --git a/shapiq/approximator/regression/_base.py b/shapiq/approximator/regression/_base.py new file mode 100644 index 00000000..838ffe0d --- /dev/null +++ b/shapiq/approximator/regression/_base.py @@ -0,0 +1,234 @@ +"""This module contains the regression algorithms to estimate FSI and SII scores.""" +from typing import Callable, Optional + +import numpy as np +from approximator._base import Approximator, InteractionValues, ShapleySamplingMixin +from scipy.special import binom, bernoulli + +from utils import powerset, get_explicit_subsets + +AVAILABLE_INDICES_REGRESSION = ["FSI", "SII"] + + +class Regression(Approximator, ShapleySamplingMixin): + """Estimates the InteractionScores values using the weighted least square approach. + + Args: + n: The number of players. + max_order: The interaction order of the approximation. + random_state: The random state of the estimator. Defaults to `None`. + + Attributes: + n: The number of players. + N: The set of players (starting from 0 to n - 1). + max_order: The interaction order of the approximation. + min_order: The minimum order of the approximation. For FSI, min_order is equal to 1. + iteration_cost: The cost of a single iteration of the regression FSI. + + Example: + >>> from games import DummyGame + >>> from approximator import RegressionSII + >>> game = DummyGame(n=5, interaction=(1, 2)) + >>> approximator = RegressionSII(n=5, max_order=2) + >>> approximator.approximate(budget=100, game=game) + InteractionValues( + index=FSI, order=2, estimated=False, estimation_budget=32, + values={ + (0,): 0.2, + (1,): 0.2, + (2,): 0.2, + (3,): 0.2, + (4,): 0.2, + (0, 1): 0, + (0, 2): 0, + (0, 3): 0, + (0, 4): 0, + (1, 2): 1.0, + (1, 3): 0, + (1, 4): 0, + (2, 3): 0, + (2, 4): 0, + (3, 4): 0 + } + ) + """ + + def __init__( + self, + n: int, + max_order: int, + index: str = "FSI", + random_state: Optional[int] = None, + ) -> None: + if index not in AVAILABLE_INDICES_REGRESSION: + raise ValueError( + f"Index {index} not available for regression. Choose from " + f"{AVAILABLE_INDICES_REGRESSION}." + ) + super().__init__( + n, max_order=max_order, index=index, top_order=False, random_state=random_state + ) + self.iteration_cost: int = 1 + self._bernoulli_numbers = bernoulli(self.n) # used for SII + + def approximate( + self, + budget: int, + game: Callable[[np.ndarray], np.ndarray], + batch_size: Optional[int] = None, + replacement: bool = False, + pairing: bool = True, + ) -> InteractionValues: + """Approximates the interaction values. + + Args: + budget: The budget of the approximation (how many times the game is queried). The game + is always queried for the empty and full set (`budget += 2`). + game: The game to be approximated. + batch_size: The batch size for the approximation. Defaults to `None`. If `None` the + batch size is set to the approximation budget. + replacement: Whether to sample subsets with replacement (`True`) or without replacement + (`False`). Defaults to `False`. + pairing: Whether to use the pairing sampling strategy or not. If paired sampling + (`True`) is used a subset is always paired with its complement subset and sampled + together. This may increase approximation quality. Defaults to `True`. + + Returns: + The interaction values. + + Raises: + np.linalg.LinAlgError: If the regression fails. + """ + # validate input parameters + batch_size = budget + 2 if batch_size is None else batch_size + used_budget = 0 + + # generate the dataset containing explicit and sampled subsets + all_subsets, estimation_flag, n_explicit_subsets = self._generate_shapley_dataset( + budget, pairing, replacement + ) + n_subsets = all_subsets.shape[0] + + # calculate the number of iterations and the last batch size + n_iterations, last_batch_size = self._calc_iteration_count( + n_subsets, batch_size, iteration_cost=self.iteration_cost + ) + + # get the fsi representation of the subsets + regression_weights = self._get_ksh_subset_weights(all_subsets) # W(|S|) + + # if SII is used regression_subsets needs to be changed + if self.index == "SII": + regression_subsets, num_players = self._get_sii_subset_representation(all_subsets) # A + else: + regression_subsets, num_players = self._get_fsi_subset_representation(all_subsets) # A + + # initialize the regression variables + game_values: np.ndarray[float] = np.zeros(shape=(n_subsets,), dtype=float) # \nu(S) + result: np.ndarray[float] = np.zeros(shape=(num_players,), dtype=float) + + # main regression loop computing the FSI values + for iteration in range(1, n_iterations + 1): + batch_size = batch_size if iteration != n_iterations else last_batch_size + batch_index = (iteration - 1) * batch_size + + # query the game for the batch of subsets + batch_subsets = all_subsets[batch_index : batch_index + batch_size] + game_values[batch_index : batch_index + batch_size] = game(batch_subsets) + + # compute the FSI values up to now + A = regression_subsets[0 : batch_index + batch_size] + B = game_values[0 : batch_index + batch_size] + W = regression_weights[0 : batch_index + batch_size] + W = np.sqrt(np.diag(W)) + Aw = np.dot(W, A) + Bw = np.dot(W, B) + + result = np.linalg.lstsq(Aw, Bw, rcond=None)[0] # \phi_i + + used_budget += batch_size + + return self._finalize_result(result, budget=used_budget, estimated=estimation_flag) + + def _get_fsi_subset_representation( + self, all_subsets: np.ndarray[bool] + ) -> tuple[np.ndarray[bool], int]: + """Transforms a subset matrix into the FSI representation. + + The FSI representation is a matrix of shape (n_subsets, num_players) where each interaction + up to the maximum order is an individual player. + + Args: + all_subsets: subset matrix in shape (n_subsets, n). + + Returns: + FSI representation of the subset matrix in shape (n_subsets, num_players) and the number + of players. + """ + n_subsets = all_subsets.shape[0] + num_players = sum(int(binom(self.n, order)) for order in range(1, self.max_order + 1)) + regression_subsets = np.zeros(shape=(n_subsets, num_players), dtype=bool) + for interaction_index, interaction in enumerate( + powerset(self.N, min_size=1, max_size=self.max_order) + ): + regression_subsets[:, interaction_index] = all_subsets[:, interaction].all(axis=1) + return regression_subsets, num_players + + def _get_sii_subset_representation( + self, all_subsets: np.ndarray[bool] + ) -> tuple[np.ndarray[bool], int]: + """Transforms a subset matrix into the SII representation. + + The SII representation is a matrix of shape (n_subsets, num_players) where each interaction + up to the maximum order is an individual player. + + Args: + all_subsets: subset matrix in shape (n_subsets, n). + + Returns: + SII representation of the subset matrix in shape (n_subsets, num_players) and the number + of players. + """ + n_subsets = all_subsets.shape[0] + num_players = sum(int(binom(self.n, order)) for order in range(1, self.max_order + 1)) + regression_subsets = np.zeros(shape=(n_subsets, num_players), dtype=float) + for interaction_index, interaction in enumerate( + powerset(self.N, min_size=1, max_size=self.max_order) + ): + intersection_size = np.sum(all_subsets[:, interaction], axis=1) + r_prime = np.full(shape=(n_subsets,), fill_value=len(interaction)) + weights = self._get_bernoulli_weights(intersection_size, r_prime) + regression_subsets[:, interaction_index] = weights + return regression_subsets, num_players + + def _get_bernoulli_weight(self, intersection_size: int, r_prime: int) -> float: + """Calculates the Bernoulli weights for the SII. + + Args: + intersection_size: The orders of the interactions. + r_prime: The orders of the interactions. + + Returns: + The Bernoulli weights. + """ + weight = 0 + for l in range(1, intersection_size + 1): + weight += binom(intersection_size, l) * self._bernoulli_numbers[r_prime - l] + return weight + + def _get_bernoulli_weights( + self, intersection_size: np.ndarray[int], r_prime: np.ndarray[int] + ) -> np.ndarray[float]: + """Calculates the Bernoulli weights for the SII. + + Args: + intersection_size: The orders of the interactions. + r_prime: The orders of the interactions. + + Returns: + The Bernoulli weights. + """ + weights = np.zeros(shape=(intersection_size.shape[0],), dtype=float) + for index, (intersection_size_i, r_prime_i) in enumerate(zip(intersection_size, r_prime)): + weights[index] = self._get_bernoulli_weight(intersection_size_i, r_prime_i) + return weights diff --git a/shapiq/approximator/regression/fsi.py b/shapiq/approximator/regression/fsi.py index 639020a5..4d8fb56f 100644 --- a/shapiq/approximator/regression/fsi.py +++ b/shapiq/approximator/regression/fsi.py @@ -1,13 +1,11 @@ -"""This module contains the regression algorithms to estimate FSI scores.""" -from typing import Callable, Optional +"""Regression with Faithful Shapley Interaction (FSI) index approximation.""" +from typing import Optional -import numpy as np -from approximator._base import Approximator, InteractionValues, ShapleySamplingMixin -from scipy.special import binom -from utils import powerset +from ._base import Regression +from .._base import NShapleyMixin -class RegressionFSI(Approximator, ShapleySamplingMixin): +class RegressionFSI(Regression, NShapleyMixin): """Estimates the FSI values using the weighted least square approach. Args: @@ -26,14 +24,14 @@ class RegressionFSI(Approximator, ShapleySamplingMixin): >>> from games import DummyGame >>> from approximator import RegressionFSI >>> game = DummyGame(n=5, interaction=(1, 2)) - >>> approximator = RegressionFSI(n=5, max_order=2) + >>> approximator = RegressionFsi(n=5, max_order=2) >>> approximator.approximate(budget=100, game=game) InteractionValues( index=FSI, order=2, estimated=False, estimation_budget=32, values={ (0,): 0.2, - (1,): 0.2, - (2,): 0.2, + (1,): 0.7, + (2,): 0.7, (3,): 0.2, (4,): 0.2, (0, 1): 0, @@ -50,110 +48,5 @@ class RegressionFSI(Approximator, ShapleySamplingMixin): ) """ - def __init__( - self, - n: int, - max_order: int, - random_state: Optional[int] = None, - ) -> None: - super().__init__( - n, max_order=max_order, index="FSI", top_order=False, random_state=random_state - ) - self.iteration_cost: int = 1 - - def approximate( - self, - budget: int, - game: Callable[[np.ndarray], np.ndarray], - batch_size: Optional[int] = None, - replacement: bool = False, - pairing: bool = True, - ) -> InteractionValues: - """Approximates the interaction values. - - Args: - budget: The budget of the approximation (how many times the game is queried). The game - is always queried for the empty and full set (`budget += 2`). - game: The game to be approximated. - batch_size: The batch size for the approximation. Defaults to `None`. If `None` the - batch size is set to the approximation budget. - replacement: Whether to sample subsets with replacement (`True`) or without replacement - (`False`). Defaults to `False`. - pairing: Whether to use the pairing sampling strategy or not. If paired sampling - (`True`) is used a subset is always paired with its complement subset and sampled - together. This may increase approximation quality. Defaults to `True`. - - Returns: - The interaction values. - - Raises: - np.linalg.LinAlgError: If the regression fails. - """ - # validate input parameters - batch_size = budget + 2 if batch_size is None else batch_size - used_budget = 0 - - # generate the dataset containing explicit and sampled subsets - all_subsets, estimation_flag, n_explicit_subsets = self._generate_shapley_dataset( - budget, pairing, replacement - ) - n_subsets = all_subsets.shape[0] - - # calculate the number of iterations and the last batch size - n_iterations, last_batch_size = self._calc_iteration_count( - n_subsets, batch_size, iteration_cost=self.iteration_cost - ) - - # get the fsi representation of the subsets - regression_subsets, num_players = self._get_fsi_subset_representation(all_subsets) # S, m - regression_weights = self._get_ksh_subset_weights(all_subsets) # W(|S|) - - # initialize the regression variables - game_values: np.ndarray[float] = np.zeros(shape=(n_subsets,), dtype=float) # \nu(S) - fsi_values: np.ndarray[float] = np.zeros(shape=(num_players,), dtype=float) - - # main regression loop computing the FSI values - for iteration in range(1, n_iterations + 1): - batch_size = batch_size if iteration != n_iterations else last_batch_size - batch_index = (iteration - 1) * batch_size - - # query the game for the batch of subsets - batch_subsets = all_subsets[batch_index : batch_index + batch_size] - game_values[batch_index : batch_index + batch_size] = game(batch_subsets) - - # compute the FSI values up to now - A = regression_subsets[0 : batch_index + batch_size] - B = game_values[0 : batch_index + batch_size] - W = regression_weights[0 : batch_index + batch_size] - W = np.sqrt(np.diag(W)) - Aw = np.dot(W, A) - Bw = np.dot(W, B) - - fsi_values = np.linalg.lstsq(Aw, Bw, rcond=None)[0] # \phi_i - - used_budget += batch_size - - return self._finalize_result(fsi_values, budget=used_budget, estimated=estimation_flag) - - def _get_fsi_subset_representation( - self, all_subsets: np.ndarray[bool] - ) -> tuple[np.ndarray[bool], int]: - """Transforms a subset matrix into the FSI representation. - - The FSI representation is a matrix of shape (n_subsets, num_players) where each interaction - up to the maximum order is an individual player. - - Args: - all_subsets: subset matrix in shape (n_subsets, n). - - Returns: - FSI representation of the subset matrix in shape (n_subsets, num_players). - """ - n_subsets = all_subsets.shape[0] - num_players = sum(int(binom(self.n, order)) for order in range(1, self.max_order + 1)) - regression_subsets = np.zeros(shape=(n_subsets, num_players), dtype=bool) - for interaction_index, interaction in enumerate( - powerset(self.N, min_size=1, max_size=self.max_order) - ): - regression_subsets[:, interaction_index] = all_subsets[:, interaction].all(axis=1) - return regression_subsets, num_players + def __init__(self, n: int, max_order: int, random_state: Optional[int] = None): + super().__init__(n, max_order, index="FSI", random_state=random_state) diff --git a/shapiq/approximator/regression/sii.py b/shapiq/approximator/regression/sii.py new file mode 100644 index 00000000..7a7a8535 --- /dev/null +++ b/shapiq/approximator/regression/sii.py @@ -0,0 +1,53 @@ +"""Regression with Shapley interaction index (SII) approximation.""" +from typing import Optional + +from ._base import Regression +from .._base import NShapleyMixin + + +class RegressionSII(Regression, NShapleyMixin): + """Estimates the SII values using the weighted least square approach. + + Args: + n: The number of players. + max_order: The interaction order of the approximation. + random_state: The random state of the estimator. Defaults to `None`. + + Attributes: + n: The number of players. + N: The set of players (starting from 0 to n - 1). + max_order: The interaction order of the approximation. + min_order: The minimum order of the approximation. For the regression estimator, min_order + is equal to 1. + iteration_cost: The cost of a single iteration of the regression SII. + + Example: + >>> from games import DummyGame + >>> from approximator import RegressionSII + >>> game = DummyGame(n=5, interaction=(1, 2)) + >>> approximator = RegressionSII(n=5, max_order=2) + >>> approximator.approximate(budget=100, game=game) + InteractionValues( + index=SII, order=2, estimated=False, estimation_budget=32, + values={ + (0,): 0.2, + (1,): 0.7, + (2,): 0.7, + (3,): 0.2, + (4,): 0.2, + (0, 1): 0, + (0, 2): 0, + (0, 3): 0, + (0, 4): 0, + (1, 2): 1.0, + (1, 3): 0, + (1, 4): 0, + (2, 3): 0, + (2, 4): 0, + (3, 4): 0 + } + ) + """ + + def __init__(self, n: int, max_order: int, random_state: Optional[int] = None): + super().__init__(n, max_order, index="SII", random_state=random_state) diff --git a/shapiq/approximator/shapiq/shapiq.py b/shapiq/approximator/shapiq/shapiq.py index 1f584d06..c118de9a 100644 --- a/shapiq/approximator/shapiq/shapiq.py +++ b/shapiq/approximator/shapiq/shapiq.py @@ -3,13 +3,13 @@ from typing import Callable, Optional import numpy as np -from approximator._base import Approximator, InteractionValues, ShapleySamplingMixin +from approximator._base import Approximator, InteractionValues, ShapleySamplingMixin, NShapleyMixin from utils import powerset AVAILABLE_INDICES_SHAPIQ = {"SII, STI, FSI, nSII"} -class ShapIQ(Approximator, ShapleySamplingMixin): +class ShapIQ(Approximator, ShapleySamplingMixin, NShapleyMixin): """The ShapIQ estimator. Args: @@ -146,6 +146,9 @@ def approximate( result_sampled = np.divide(result_sampled, counts, out=result_sampled, where=counts != 0) result = result_explicit + result_sampled + if self.index == "nSII": + result: np.ndarray[float] = self.transforms_sii_to_nsii(result) + return self._finalize_result(result, budget=used_budget, estimated=estimation_flag) def _sii_weight_kernel(self, subset_size: int, interaction_size: int) -> float: @@ -221,7 +224,7 @@ def _weight_kernel(self, subset_size: int, interaction_size: int) -> float: Returns: float: The weight for the interaction type. """ - if self.index == "SII" or self.index == "nSII": + if self.index == "SII" or self.index == "nSII": # in both cases return SII kernel return self._sii_weight_kernel(subset_size, interaction_size) elif self.index == "STI": return self._sti_weight_kernel(subset_size, interaction_size) diff --git a/shapiq/utils/sets.py b/shapiq/utils/sets.py index 6b7548cb..bc3caef1 100644 --- a/shapiq/utils/sets.py +++ b/shapiq/utils/sets.py @@ -175,9 +175,3 @@ def get_explicit_subsets(n: int, subset_sizes: list[int]) -> np.ndarray[bool]: subset_matrix[subset_index, subset] = True subset_index += 1 return subset_matrix - - -if __name__ == "__main__": - import doctest - - doctest.testmod() diff --git a/tests/test_approximator_nsii_estimation.py b/tests/test_approximator_nsii_estimation.py new file mode 100644 index 00000000..ce311c96 --- /dev/null +++ b/tests/test_approximator_nsii_estimation.py @@ -0,0 +1,64 @@ +"""Tests the approximiation of nSII values with PermutationSamplingSII and ShapIQ.""" +import numpy as np +import pytest + +from approximator import convert_nsii_into_one_dimension, transforms_sii_to_nsii +from shapiq import DummyGame, PermutationSamplingSII, ShapIQ + + +@pytest.mark.parametrize( + "sii_approximator, nsii_approximator", + [ + ( + PermutationSamplingSII(7, 2, "SII", False, random_state=42), + PermutationSamplingSII(7, 2, "nSII", False, random_state=42), + ), + (ShapIQ(7, 2, "SII", False, random_state=42), ShapIQ(7, 2, "nSII", False, random_state=42)), + ], +) +def test_nsii_estimation(sii_approximator, nsii_approximator): + """Tests the approximation of nSII values with PermutationSamplingSII and ShapIQ.""" + n = 7 + max_order = 2 + interaction = (1, 2) + game = DummyGame(n, interaction) + # sii_approximator = PermutationSamplingSII(n, max_order, "SII", False, random_state=42) + sii_estimates = sii_approximator.approximate(1_000, game, batch_size=None) + # nsii_approximator = PermutationSamplingSII(n, max_order, "nSII", False, random_state=42) + nsii_estimates = nsii_approximator.approximate(1_000, game, batch_size=None) + assert sii_estimates != nsii_estimates + assert nsii_estimates.index == "nSII" + + n_sii_transformed = nsii_approximator.transforms_sii_to_nsii(sii_estimates) + assert n_sii_transformed.index == "nSII" + assert n_sii_transformed == nsii_estimates # check weather transform and estimation are equal + + # nSII values for player 1 and 2 should be approximately 0.1429 and the interaction 1.0 + assert nsii_estimates[(1,)] == pytest.approx(0.1429, 0.4) + assert nsii_estimates[(2,)] == pytest.approx(0.1429, 0.4) + assert nsii_estimates[(1, 2)] == pytest.approx(1.0, 0.2) + + # check efficiency + efficiency = np.sum(nsii_estimates.values) + assert efficiency == pytest.approx(2.0, 0.01) + + # check one dim transform + pos_nsii_values, neg_nsii_values = convert_nsii_into_one_dimension(nsii_estimates) + assert pos_nsii_values.shape == (n,) and neg_nsii_values.shape == (n,) + assert np.all(pos_nsii_values >= 0) and np.all(neg_nsii_values <= 0) + sum_of_both = np.sum(pos_nsii_values) + np.sum(neg_nsii_values) + assert sum_of_both == pytest.approx(efficiency, 0.01) + assert sum_of_both != pytest.approx(0.0, 0.01) + + with pytest.raises(ValueError): + _ = convert_nsii_into_one_dimension(sii_estimates) + + # check transforms_sii_to_nsii function + transformed = transforms_sii_to_nsii(sii_estimates) + assert transformed.index == "nSII" + transformed = transforms_sii_to_nsii(sii_estimates.values, approximator=sii_approximator) + assert isinstance(transformed, np.ndarray) + transformed = transforms_sii_to_nsii(sii_estimates.values, n=n, max_order=max_order) + assert isinstance(transformed, np.ndarray) + with pytest.raises(ValueError): + _ = transforms_sii_to_nsii(sii_estimates.values) diff --git a/tests/test_approximator_permutation_sii.py b/tests/test_approximator_permutation_sii.py index c5a6b00e..9d2824cb 100644 --- a/tests/test_approximator_permutation_sii.py +++ b/tests/test_approximator_permutation_sii.py @@ -10,24 +10,30 @@ @pytest.mark.parametrize( - "n, max_order, top_order, expected", + "n, max_order, top_order, index, expected", [ - (3, 1, True, 6), - (3, 1, False, 6), - (3, 2, True, 8), - (3, 2, False, 14), - (10, 3, False, 120), + (3, 1, True, "SII", 6), + (3, 1, False, "SII", 6), + (3, 2, True, "SII", 8), + (3, 2, False, "SII", 14), + (10, 3, False, "SII", 120), + (10, 3, False, "nSII", 120), + (10, 3, False, "something", 120), # expected to fail with ValueError ], ) -def test_initialization(n, max_order, top_order, expected): +def test_initialization(n, max_order, top_order, index, expected): """Tests the initialization of the PermutationSamplingSII approximator.""" - approximator = PermutationSamplingSII(n, max_order, top_order) + if index == "something": + with pytest.raises(ValueError): + _ = PermutationSamplingSII(n, max_order, index, top_order) + return + approximator = PermutationSamplingSII(n, max_order, index, top_order) assert approximator.n == n assert approximator.max_order == max_order assert approximator.top_order == top_order assert approximator.min_order == (max_order if top_order else 1) assert approximator.iteration_cost == expected - assert approximator.index == "SII" + assert approximator.index == index approximator_copy = copy(approximator) approximator_deepcopy = deepcopy(approximator) @@ -55,7 +61,7 @@ def test_approximate(n, max_order, top_order, budget, batch_size): """Tests the approximation of the PermutationSamplingSII approximator.""" interaction = (1, 2) game = DummyGame(n, interaction) - approximator = PermutationSamplingSII(n, max_order, top_order, random_state=42) + approximator = PermutationSamplingSII(n, max_order, "SII", top_order, random_state=42) sii_estimates = approximator.approximate(budget, game, batch_size=batch_size) assert isinstance(sii_estimates, InteractionValues) assert sii_estimates.max_order == max_order @@ -66,9 +72,9 @@ def test_approximate(n, max_order, top_order, budget, batch_size): # check that the estimates are correct if not top_order: - # for order 1 player 1 and 2 are the most important with 0.7 - assert sii_estimates[(1,)] == pytest.approx(0.7, 0.5) # quite a large interval - assert sii_estimates[(2,)] == pytest.approx(0.7, 0.5) + # for order 1 player 1 and 2 are the most important with 0.6429 + assert sii_estimates[(1,)] == pytest.approx(0.6429, 0.4) # quite a large interval + assert sii_estimates[(2,)] == pytest.approx(0.6429, 0.4) # for order 2 the interaction between player 1 and 2 is the most important assert sii_estimates[(1, 2)] == pytest.approx(1.0, 0.2) diff --git a/tests/test_approximator_regression_sii.py b/tests/test_approximator_regression_sii.py new file mode 100644 index 00000000..e6b223b2 --- /dev/null +++ b/tests/test_approximator_regression_sii.py @@ -0,0 +1,79 @@ +"""This test module contains all tests regarding the SII regression approximator.""" +from copy import deepcopy, copy + +import numpy as np +import pytest + +from approximator._base import InteractionValues +from approximator.regression._base import Regression +from approximator.regression import RegressionSII +from games import DummyGame + + +@pytest.mark.parametrize( + "n, max_order", + [ + (3, 1), + (3, 1), + (3, 2), + (3, 2), + (7, 2), # used in subsequent tests + (10, 3), + ], +) +def test_initialization(n, max_order): + """Tests the initialization of the Regression approximator for SII.""" + approximator = RegressionSII(n, max_order) + assert approximator.n == n + assert approximator.max_order == max_order + assert approximator.top_order is False + assert approximator.min_order == 1 + assert approximator.iteration_cost == 1 + assert approximator.index == "SII" + + approximator_copy = copy(approximator) + approximator_deepcopy = deepcopy(approximator) + approximator_deepcopy.index = "something" + assert approximator_copy == approximator # check that the copy is equal + assert approximator_deepcopy != approximator # check that the deepcopy is not equal + approximator_string = str(approximator) + assert repr(approximator) == approximator_string + assert hash(approximator) == hash(approximator_copy) + assert hash(approximator) != hash(approximator_deepcopy) + with pytest.raises(ValueError): + _ = approximator == 1 + with pytest.raises(ValueError): + _ = Regression(n, max_order, index="something") + + +@pytest.mark.parametrize( + "n, max_order, budget, batch_size", [(7, 2, 380, 100), (7, 2, 380, None), (7, 2, 100, None)] +) +def test_approximate(n, max_order, budget, batch_size): + """Tests the approximation of the Regression approximator for SII.""" + interaction = (1, 2) + game = DummyGame(n, interaction) + approximator = RegressionSII(n, max_order, random_state=42) + sii_estimates = approximator.approximate(budget, game, batch_size=batch_size) + assert isinstance(sii_estimates, InteractionValues) + assert sii_estimates.max_order == max_order + assert sii_estimates.min_order == 1 + + # check that the budget is respected + assert game.access_counter <= budget + 2 + + # check that the estimates are correct + # for order 1 player 1 and 2 are the most important with 0.6429 + assert sii_estimates[(1,)] == pytest.approx(0.6429, 0.4) # quite a large interval + assert sii_estimates[(2,)] == pytest.approx(0.6429, 0.4) + + # for order 2 the interaction between player 1 and 2 is the most important + assert sii_estimates[(1, 2)] == pytest.approx(1.0, 0.2) + + # check efficiency + efficiency = np.sum(sii_estimates.values[:n]) + assert efficiency == pytest.approx(2.0, 0.01) + + # try covert to nSII + nsii_estimates = approximator.transforms_sii_to_nsii(sii_estimates) + assert nsii_estimates.index == "nSII" diff --git a/tests/test_approximator_shapiq.py b/tests/test_approximator_shapiq.py index e7dd93fe..02321089 100644 --- a/tests/test_approximator_shapiq.py +++ b/tests/test_approximator_shapiq.py @@ -69,7 +69,8 @@ def test_approximate_fsi(n, max_order, budget, batch_size): @pytest.mark.parametrize( - "n, max_order, top_order, budget, batch_size", [(7, 2, False, 100, None), (7, 2, True, 100, 10)] + "n, max_order, top_order, budget, batch_size", + [(7, 2, False, 100, None), (7, 2, True, 100, 10), (7, 2, False, 300, None)], ) def test_approximate_sii(n, max_order, top_order, budget, batch_size): """Tests the approximation of the ShapIQ SII approximation.""" @@ -88,9 +89,13 @@ def test_approximate_sii(n, max_order, top_order, budget, batch_size): assert estimates[interaction] == pytest.approx(1.0, 0.4) if not top_order: - # for order 1 (min_order) the interaction between player 1 and 2 is the most important (0.7) - assert estimates[(1,)] == pytest.approx(0.7, 0.4) - assert estimates[(2,)] == pytest.approx(0.7, 0.4) + # for order 1 (min_order) the interaction between 1 and 2 is the most important (0.6429) + if budget <= 2**n: + assert estimates[(1,)] == pytest.approx(0.6429, 0.4) + assert estimates[(2,)] == pytest.approx(0.6429, 0.4) + else: + assert estimates[(1,)] == pytest.approx(0.6429, 0.0001) + assert estimates[(2,)] == pytest.approx(0.6429, 0.0001) # check efficiency assert np.sum(estimates.values[:n]) == pytest.approx(2.0, 0.4) diff --git a/tests/test_integration_import_all.py b/tests/test_integration_import_all.py index 9d69ccb2..b8c84933 100644 --- a/tests/test_integration_import_all.py +++ b/tests/test_integration_import_all.py @@ -57,10 +57,17 @@ def test_approximator_imports(): from shapiq.approximator import ( PermutationSamplingSII, PermutationSamplingSTI, + RegressionSII, RegressionFSI, ShapIQ, ) - from shapiq import ShapIQ, PermutationSamplingSII, PermutationSamplingSTI, RegressionFSI + from shapiq import ( + ShapIQ, + PermutationSamplingSII, + PermutationSamplingSTI, + RegressionSII, + RegressionFSI, + ) assert True