Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nSII and SII Regression Estimator #22

Merged
merged 6 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[run]
source = shapiq
omit = *tests* *venv*
omit = *tests* *venv* *docs* *examples*
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ repos:
name: ruff
language: python
types: [python]
entry: ruff
entry: ruff --fix
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 8 additions & 1 deletion shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from __version__ import __version__

# approximator classes
from .approximator import PermutationSamplingSII, PermutationSamplingSTI, RegressionFSI, ShapIQ
from .approximator import (
PermutationSamplingSII,
PermutationSamplingSTI,
RegressionSII,
RegressionFSI,
ShapIQ,
)

# explainer classes
from .explainer import Explainer
Expand All @@ -31,6 +37,7 @@
"ShapIQ",
"PermutationSamplingSII",
"PermutationSamplingSTI",
"RegressionSII",
"RegressionFSI",
# explainers
"Explainer",
Expand Down
6 changes: 5 additions & 1 deletion shapiq/approximator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""This module contains the approximators to estimate the Shapley interaction values."""
from ._base import convert_nsii_into_one_dimension, transforms_sii_to_nsii # TODO add to tests
from .permutation.sii import PermutationSamplingSII
from .permutation.sti import PermutationSamplingSTI
from .regression import RegressionFSI
from .regression import RegressionSII, RegressionFSI
from .shapiq import ShapIQ

__all__ = [
"PermutationSamplingSII",
"PermutationSamplingSTI",
"RegressionFSI",
"RegressionSII",
"ShapIQ",
"transforms_sii_to_nsii",
"convert_nsii_into_one_dimension",
]
247 changes: 222 additions & 25 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,25 @@
import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Callable, Optional, Union

import numpy as np
from scipy.special import binom
from scipy.special import binom, bernoulli
from utils import get_explicit_subsets, powerset, split_subsets_budget

AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"}


# Names exported by this module on a star-import (``from ._base import *``).
__all__ = [
"InteractionValues",
"Approximator",
"ShapleySamplingMixin",
"NShapleyMixin",
"transforms_sii_to_nsii",
"convert_nsii_into_one_dimension",
]


@dataclass
class InteractionValues:
"""This class contains the interaction values as estimated by an approximator.
Expand Down Expand Up @@ -46,14 +56,9 @@ def __post_init__(self) -> None:
f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'."
)
if self.interaction_lookup is None:
self.interaction_lookup = {
interaction: i
for i, interaction in enumerate(
powerset(
range(self.n_players), min_size=self.min_order, max_size=self.max_order
)
)
}
self.interaction_lookup = _generate_interaction_lookup(
self.n_players, self.min_order, self.max_order
)

def __repr__(self) -> str:
"""Returns the representation of the InteractionValues object."""
Expand Down Expand Up @@ -201,13 +206,10 @@ def __init__(
self.top_order: bool = top_order
self.max_order: int = max_order
self.min_order: int = self.max_order if self.top_order else 1
self.iteration_cost: Optional[int] = None
self._interaction_lookup: dict[tuple[int], int] = {
interaction: i
for i, interaction in enumerate(
powerset(self.N, min_size=self.min_order, max_size=self.max_order)
)
}
self.iteration_cost: int = 1 # default value, can be overwritten by subclasses
self._interaction_lookup = _generate_interaction_lookup(
self.n, self.min_order, self.max_order
)
self._random_state: Optional[int] = random_state
self._rng: Optional[np.random.Generator] = np.random.default_rng(seed=self._random_state)

Expand Down Expand Up @@ -253,24 +255,31 @@ def _order_iterator(self) -> range:
return range(self.min_order, self.max_order + 1)

def _finalize_result(
self, result, estimated: bool = True, budget: Optional[int] = None
self,
result,
estimated: bool = True,
budget: Optional[int] = None,
index: Optional[str] = None,
) -> InteractionValues:
"""Finalizes the result dictionary.

Args:
result: The result dictionary.
estimated: Whether the interaction values are estimated or not. Defaults to True.
budget: The budget used for the estimation. Defaults to None.
index: The interaction index estimated. Available indices are 'SII', 'nSII', 'STI', and
'FSI'. Defaults to None (i.e., the index of the approximator is used).

Returns:
The interaction values.
"""
# create InteractionValues object
if index is None:
index = self.index
return InteractionValues(
values=result,
estimated=estimated,
estimation_budget=budget,
index=self.index,
index=index,
min_order=self.min_order,
max_order=self.max_order,
n_players=self.n,
Expand Down Expand Up @@ -336,16 +345,22 @@ def __hash__(self) -> int:
"""Returns the hash of the Approximator object."""
return hash((self.n, self.max_order, self.index, self.top_order, self._random_state))

@property
def interaction_lookup(self) -> dict[tuple[int], int]:
    """Returns the lookup that maps each interaction to its index in the values vector."""
    return self._interaction_lookup


class ShapleySamplingMixin:
class ShapleySamplingMixin(ABC):
"""Mixin class for the computation of Shapley weights.

Provides the common functionality for regression-based approximators like
:class:`~shapiq.approximators.RegressionFSI`. The class offers computation of Shapley weights
and the corresponding sampling weights for the KernelSHAP-like estimation approaches.
"""

def _init_ksh_sampling_weights(self) -> np.ndarray[float]:
def _init_ksh_sampling_weights(
self: Union[Approximator, "ShapleySamplingMixin"]
) -> np.ndarray[float]:
"""Initializes the weights for sampling subsets.

The sampling weights are of size n + 1 and indexed by the size of the subset. The edges
Expand All @@ -354,13 +369,16 @@ def _init_ksh_sampling_weights(self) -> np.ndarray[float]:
Returns:
The weights for sampling subsets of size s in shape (n + 1,).
"""

weight_vector = np.zeros(shape=self.n - 1, dtype=float)
for subset_size in range(1, self.n):
weight_vector[subset_size - 1] = (self.n - 1) / (subset_size * (self.n - subset_size))
sampling_weight = (np.asarray([0] + [*weight_vector] + [0])) / sum(weight_vector)
return sampling_weight

def _get_ksh_subset_weights(self, subsets: np.ndarray[bool]) -> np.ndarray[float]:
def _get_ksh_subset_weights(
self: Union[Approximator, "ShapleySamplingMixin"], subsets: np.ndarray[bool]
) -> np.ndarray[float]:
"""Computes the KernelSHAP regression weights for the given subsets.

The weights for the subsets of size s are set to ksh_weights[s] / binom(n, s). The weights
Expand All @@ -385,7 +403,7 @@ def _get_ksh_subset_weights(self, subsets: np.ndarray[bool]) -> np.ndarray[float
return weights

def _sample_subsets(
self,
self: Union[Approximator, "ShapleySamplingMixin"],
budget: int,
sampling_weights: np.ndarray[float],
replacement: bool = False,
Expand Down Expand Up @@ -439,7 +457,10 @@ def _sample_subsets(
return subset_matrix

def _generate_shapley_dataset(
self, budget: int, pairing: bool = True, replacement: bool = False
self: Union[Approximator, "ShapleySamplingMixin"],
budget: int,
pairing: bool = True,
replacement: bool = False,
) -> tuple[np.ndarray[bool], bool, int]:
"""Generates the two-part dataset containing explicit and sampled subsets.

Expand Down Expand Up @@ -499,3 +520,179 @@ def _generate_shapley_dataset(
)
n_explicit_subsets += 2 # add empty and full set
return all_subsets, estimation_flag, n_explicit_subsets


class NShapleyMixin:
    """Mixin that converts SII estimates of an approximator into n-Shapley values (nSII).

    Meant to be combined with SII-based approximators such as `PermutationSamplingSII` or
    `ShapIQ` (for SII) so that their interaction scores can be turned into nSII values. The nSII
    values are proposed in this `paper <https://proceedings.mlr.press/v206/bordt23a>`_.
    """

    def transforms_sii_to_nsii(
        self: Approximator,
        sii_values: Union[np.ndarray[float], InteractionValues],
    ) -> Union[np.ndarray[float], InteractionValues]:
        """Transforms the SII values into nSII values.

        Delegates to the module-level function of the same name and hands over this
        approximator as the source of the meta information.

        Args:
            sii_values: The SII values to transform, either as a numpy array or as an
                InteractionValues object.

        Returns:
            The nSII values in the same format as the input.
        """
        # inside a method body the bare name resolves to the module-level function,
        # not to this method
        nsii_values = transforms_sii_to_nsii(sii_values, approximator=self)
        return nsii_values


def transforms_sii_to_nsii(
    sii_values: Union[np.ndarray[float], InteractionValues],
    *,
    approximator: Optional[Approximator] = None,
    n: Optional[int] = None,
    max_order: Optional[int] = None,
    interaction_lookup: Optional[dict] = None,
) -> Union[np.ndarray[float], InteractionValues]:
    """Transforms the SII values into nSII values.

    The meta information needed for the transformation is taken from the first source that is
    available, in this order: the `InteractionValues` input itself, the `approximator`, or the
    explicit `n` / `max_order` (/ `interaction_lookup`) arguments.

    Args:
        sii_values: The SII values to transform. Can be either a numpy array or an
            InteractionValues object. The output will be of the same type.
        approximator: The approximator used to estimate the SII values. If provided, its `n`,
            `max_order` and `interaction_lookup` are used. Defaults to None.
        n: The number of players. Required if `sii_values` is an array and `approximator` is
            not provided. Defaults to None.
        max_order: The maximum order of the approximation. Required if `sii_values` is an
            array and `approximator` is not provided. Defaults to None.
        interaction_lookup: A dictionary that maps interactions to their index in the values
            vector. Only consulted on the `n` / `max_order` path; if it is None there, a lookup
            for the orders 1 to `max_order` is generated. Defaults to None.

    Returns:
        The nSII values in the same format as the input.

    Raises:
        ValueError: If `sii_values` is not an InteractionValues object and neither the
            approximator nor both `n` and `max_order` are provided.
    """
    # case 1: the InteractionValues object carries all meta information itself
    if isinstance(sii_values, InteractionValues):
        transformed = _calculate_nsii_from_sii(
            sii_values.values,
            sii_values.n_players,
            sii_values.max_order,
            sii_values.interaction_lookup,
        )
        return InteractionValues(
            values=transformed,
            index="nSII",
            max_order=sii_values.max_order,
            min_order=sii_values.min_order,
            n_players=sii_values.n_players,
            interaction_lookup=sii_values.interaction_lookup,
            estimated=sii_values.estimated,
            estimation_budget=sii_values.estimation_budget,
        )

    # case 2: plain array, meta information comes from the approximator
    if approximator is not None:
        return _calculate_nsii_from_sii(
            sii_values, approximator.n, approximator.max_order, approximator.interaction_lookup
        )

    # case 3: plain array, meta information must be given explicitly
    if n is None or max_order is None:
        raise ValueError(
            "If the SII values are not provided as InteractionValues, the approximator "
            "or the number of players and the maximum order of the approximation must be "
            "provided."
        )
    if interaction_lookup is None:
        interaction_lookup = _generate_interaction_lookup(n, 1, max_order)
    return _calculate_nsii_from_sii(sii_values, n, max_order, interaction_lookup)


def _calculate_nsii_from_sii(
    sii_values: np.ndarray[float],
    n: int,
    max_order: int,
    interaction_lookup: Optional[dict] = None,
) -> np.ndarray[float]:
    """Calculates the nSII values from the SII values.

    For every interaction S with 1 <= |S| <= max_order, the nSII value is the SII value of S
    plus, for every superset T of S with |S| < |T| <= max_order, the SII value of T weighted
    by the Bernoulli number with index |T| - |S|.

    Args:
        sii_values: The SII values to transform, indexed according to `interaction_lookup`.
        n: The number of players.
        max_order: The maximum order of the approximation.
        interaction_lookup: A dictionary that maps interactions to their index in the values
            vector. If `interaction_lookup` is not provided, it is generated for all
            interactions of order 1 to `max_order` among the `n` players. Defaults to `None`.

    Returns:
        The nSII values, in an array shaped like `sii_values`.
    """
    if interaction_lookup is None:
        # honour the documented default instead of failing on `None[subset]`
        interaction_lookup = _generate_interaction_lookup(n, 1, max_order)

    # B_0, ..., B_max_order (scipy uses the convention B_1 = -1/2)
    bernoulli_numbers = bernoulli(max_order)
    nsii_values = np.zeros_like(sii_values)
    # all subsets S with 1 <= |S| <= max_order
    for subset in powerset(set(range(n)), min_size=1, max_size=max_order):
        interaction_index = interaction_lookup[subset]
        interaction_size = len(subset)
        subset_members = set(subset)  # built once per S, reused for every candidate T
        n_sii_value = sii_values[interaction_index]
        # go over all subsets T of length |S| + 1, ..., max_order that contain S
        for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order):
            if subset_members.issubset(T):
                effect_index = interaction_lookup[T]  # get the index of T
                effect_value = sii_values[effect_index]  # get the effect of T
                bernoulli_factor = bernoulli_numbers[len(T) - interaction_size]
                n_sii_value += bernoulli_factor * effect_value
        nsii_values[interaction_index] = n_sii_value
    return nsii_values


def convert_nsii_into_one_dimension(
    n_sii_values: InteractionValues,
) -> tuple[np.ndarray[float], np.ndarray[float]]:
    """Converts the nSII values into one-dimensional (per-player) values.

    The value of every interaction is split uniformly among the players it contains. Positive
    and negative shares are accumulated separately for each player.

    Args:
        n_sii_values: The nSII values to convert. Its index must be 'nSII'.

    Returns:
        Two arrays of shape (n_players,): the summed non-negative shares and the summed negative
            shares of each player.

    Raises:
        ValueError: If the index of `n_sii_values` is not 'nSII'.
    """
    if n_sii_values.index != "nSII":
        raise ValueError(
            "Only nSII values can be converted into one-dimensional nSII values. Please use the "
            "transforms_sii_to_nsii method to convert SII values into nSII values."
        )

    n_players = n_sii_values.n_players
    positive_part = np.zeros(n_players, dtype=float)
    negative_part = np.zeros(n_players, dtype=float)

    interactions = powerset(
        set(range(n_players)),
        min_size=n_sii_values.min_order,
        max_size=n_sii_values.max_order,
    )
    for interaction in interactions:
        # every member of the interaction receives the same share
        share = n_sii_values[interaction] / len(interaction)
        target = positive_part if share >= 0 else negative_part
        for member in interaction:
            target[member] += share
    return positive_part, negative_part


def _generate_interaction_lookup(n: int, min_order: int, max_order: int) -> dict[tuple[int], int]:
    """Generates a lookup dictionary for interactions.

    The interactions are enumerated in the order in which `powerset` yields them for the
    players 0, ..., n - 1; the position in that enumeration becomes the index.

    Args:
        n: The number of players.
        min_order: The minimum order of the approximation.
        max_order: The maximum order of the approximation.

    Returns:
        A dictionary that maps interactions to their index in the values vector.
    """
    players = set(range(n))
    lookup = {}
    for position, interaction in enumerate(
        powerset(players, min_size=min_order, max_size=max_order)
    ):
        lookup[interaction] = position
    return lookup
Loading