Skip to content

Commit

Permalink
fix predict in imputer
Browse files Browse the repository at this point in the history
`x_background` -> `data`
  • Loading branch information
hbaniecki committed Apr 8, 2024
1 parent a8f9810 commit b3920ca
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 27 deletions.
2 changes: 1 addition & 1 deletion shapiq/explainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def explain(self, x: np.ndarray) -> InteractionValues:

def predict(self, x: np.ndarray) -> np.ndarray:
"""Provides a unified prediction interface."""
return self._predict_function(x)
return self._predict_function(self.model, x)
6 changes: 5 additions & 1 deletion shapiq/explainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_predict_function_and_model_type(model, model_class):
_model_type = "tabular" # default

if callable(model):
_predict_function = model
_predict_function = predict_callable

# sklearn
if model_class in [
Expand Down Expand Up @@ -76,6 +76,10 @@ def get_predict_function_and_model_type(model, model_class):
return _predict_function, _model_type


def predict_callable(m, d):
    """Predict by invoking the model directly as a function.

    Args:
        m: The model; it is called with the data as its only argument.
        d: The data handed to the model.

    Returns:
        Whatever ``m(d)`` returns.
    """
    prediction = m(d)
    return prediction


def predict_default(m, d):
    """Predict through the model's ``predict`` method.

    Args:
        m: The model; it must expose a ``predict`` attribute that can be called
            with the data.
        d: The data handed to ``m.predict``.

    Returns:
        Whatever ``m.predict(d)`` returns.
    """
    predict = m.predict
    return predict(d)

Expand Down
17 changes: 13 additions & 4 deletions shapiq/games/imputer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

from ..base import Game
from ...explainer import utils


class Imputer(Game):
Expand All @@ -23,17 +24,25 @@ class Imputer(Game):
@abstractmethod
def __init__(
self,
model: Callable[[np.ndarray], np.ndarray],
model,
data: np.ndarray,
categorical_features: list[int] = None,
random_state: Optional[int] = None,
) -> None:
self._model = model
self._data = data
self._n_features = self._data.shape[1]
if callable(model):
self._predict_function = utils.predict_callable
else: # shapiq.Explainer
self._predict_function = model._predict_function
self.model = model
self.data = data
self._n_features = self.data.shape[1]
self._cat_features: list = [] if categorical_features is None else categorical_features
self._random_state = random_state
self._rng = np.random.default_rng(self._random_state)

# the normalization_value needs to be set in the subclass
super().__init__(n_players=self._n_features, normalize=False)

def predict(self, x: np.ndarray) -> np.ndarray:
    """Provides a unified prediction interface.

    The prediction function stored on the instance is called with the stored
    model as its first argument and ``x`` as its second.

    Args:
        x: The input passed on to the prediction function.

    Returns:
        The result of the stored prediction function.
    """
    predict_function = self._predict_function
    return predict_function(self.model, x)
22 changes: 11 additions & 11 deletions shapiq/games/imputer/marginal_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MarginalImputer(Imputer):

def __init__(
self,
model: Callable[[np.ndarray], np.ndarray],
model,
data: np.ndarray,
x: Optional[np.ndarray] = None,
sample_replacements: bool = False,
Expand All @@ -51,7 +51,7 @@ def __init__(
self._sample_replacements = sample_replacements
self._sample_size: int = sample_size
self.replacement_data: np.ndarray = np.zeros((1, self._n_features)) # will be overwritten
self.init_background(self._data)
self.init_background(self.data)
self._x: np.ndarray = np.zeros((1, self._n_features)) # will be overwritten @ fit
if x is not None:
self.fit(x)
Expand All @@ -77,34 +77,34 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray:
if not self._sample_replacements:
replacement_data = np.tile(self.replacement_data, (n_subsets, 1))
data[~coalitions] = replacement_data[~coalitions]
outputs = self._model(data)
outputs = self.predict(data)
else:
# sampling from background returning array of shape (sample_size, n_subsets, n_features)
replacement_data = self._sample_replacement_values(coalitions)
outputs = np.zeros((self._sample_size, n_subsets))
for i in range(self._sample_size):
replacements = replacement_data[i].reshape(n_subsets, self._n_features)
data[~coalitions] = replacements[~coalitions]
outputs[i] = self._model(data)
outputs[i] = self.predict(data)
outputs = np.mean(outputs, axis=0) # average over the samples
return outputs

def init_background(self, x_background: np.ndarray) -> "MarginalImputer":
def init_background(self, data: np.ndarray) -> "MarginalImputer":
"""Initializes the imputer to the background data.
Args:
x_background: The background data to use for the imputer. The shape of the array must
data: The background data to use for the imputer. The shape of the array must
be (n_samples, n_features).
Returns:
The initialized imputer.
"""
if self._sample_replacements:
self.replacement_data = x_background
self.replacement_data = data
else:
self.replacement_data = np.zeros((1, self._n_features), dtype=object)
for feature in range(self._n_features):
feature_column = x_background[:, feature]
feature_column = data[:, feature]
if feature in self._cat_features:
# get mode for categorical features
counts = np.unique(feature_column, return_counts=True)
Expand Down Expand Up @@ -157,10 +157,10 @@ def _calc_empty_prediction(self) -> float:
The empty prediction.
"""
if self._sample_replacements:
shuffled_background = self._rng.permutation(self._data)
empty_predictions = self._model(shuffled_background)
shuffled_background = self._rng.permutation(self.data)
empty_predictions = self.predict(shuffled_background)
empty_prediction = float(np.mean(empty_predictions))
return empty_prediction
empty_prediction = self._model(self.replacement_data)
empty_prediction = self.predict(self.replacement_data)
empty_prediction = float(empty_prediction)
return empty_prediction
16 changes: 8 additions & 8 deletions shapiq/games/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ class LocalExplanation(Game):
of the data point (for more information see `MarginalImputer`).
Args:
data: The background data used to fit the imputer. Should be a 2d matrix of shape
(n_samples, n_features).
model: The model to explain as a callable function expecting data points as input and
returning the model's predictions. The input should be a 2d matrix of shape
(n_samples, n_features) and the output a 1d matrix of shape (n_samples).
data: The background data used to fit the imputer. Should be a 2d matrix of shape
(n_samples, n_features).
x: The data point to explain. Can be an index of the background data or a 1d matrix
of shape (n_features).
random_state: The random state to use for the imputer. Defaults to `None`.
Expand Down Expand Up @@ -56,25 +56,25 @@ class LocalExplanation(Game):

def __init__(
self,
data: np.ndarray,
model: Callable[[np.ndarray], np.ndarray],
data: np.ndarray,
x: Union[np.ndarray, int],
random_state: Optional[int] = None,
normalize: bool = True,
) -> None:
# set attributes
self._model = model
self._data = data
self.model = model
self.data = data

# set explanation point
if isinstance(x, int):
x = self._data[x]
x = self.data[x]
self.x = x

# init the imputer which serves as the workhorse of this Game
self._imputer = MarginalImputer(
model=self._model,
data=self._data,
model=self.model,
data=self.data,
x=x,
random_state=random_state,
normalize=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_abstract_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test_imputer():
model = lambda x: x
data = np.asarray([[1, 2, 3], [4, 5, 6]])
imputer = concreter(Imputer)(model, data)
assert imputer._model == model
assert np.all(imputer._data == data)
assert imputer.model == model
assert np.all(imputer.data == data)
assert imputer._n_features == 3
assert imputer._cat_features == []
assert imputer._random_state is None
Expand Down

0 comments on commit b3920ca

Please sign in to comment.