diff --git a/shapiq/explainer/_base.py b/shapiq/explainer/_base.py index 39f1e6af..37e79860 100644 --- a/shapiq/explainer/_base.py +++ b/shapiq/explainer/_base.py @@ -46,4 +46,4 @@ def explain(self, x: np.ndarray) -> InteractionValues: def predict(self, x: np.ndarray) -> np.ndarray: """Provides a unified prediction interface.""" - return self._predict_function(x) \ No newline at end of file + return self._predict_function(self.model, x) \ No newline at end of file diff --git a/shapiq/explainer/utils.py b/shapiq/explainer/utils.py index 5068c7cd..6ea77759 100644 --- a/shapiq/explainer/utils.py +++ b/shapiq/explainer/utils.py @@ -18,7 +18,7 @@ def get_predict_function_and_model_type(model, model_class): _model_type = "tabular" # default if callable(model): - _predict_function = model + _predict_function = predict_callable # sklearn if model_class in [ @@ -76,6 +76,10 @@ def get_predict_function_and_model_type(model, model_class): return _predict_function, _model_type +def predict_callable(m, d): + return m(d) + + def predict_default(m, d): return m.predict(d) diff --git a/shapiq/games/imputer/base.py b/shapiq/games/imputer/base.py index b159eaec..51f2ac9e 100644 --- a/shapiq/games/imputer/base.py +++ b/shapiq/games/imputer/base.py @@ -6,6 +6,7 @@ import numpy as np from ..base import Game +from ...explainer import utils class Imputer(Game): @@ -23,17 +24,25 @@ class Imputer(Game): @abstractmethod def __init__( self, - model: Callable[[np.ndarray], np.ndarray], + model, data: np.ndarray, categorical_features: list[int] = None, random_state: Optional[int] = None, ) -> None: - self._model = model - self._data = data - self._n_features = self._data.shape[1] + if callable(model): + self._predict_function = utils.predict_callable + else: # shapiq.Explainer + self._predict_function = model._predict_function + self.model = model + self.data = data + self._n_features = self.data.shape[1] self._cat_features: list = [] if categorical_features is None else categorical_features self._random_state = random_state self._rng = np.random.default_rng(self._random_state) # the normalization_value needs to be set in the subclass super().__init__(n_players=self._n_features, normalize=False) + + def predict(self, x: np.ndarray) -> np.ndarray: + """Provides a unified prediction interface.""" + return self._predict_function(self.model, x) \ No newline at end of file diff --git a/shapiq/games/imputer/marginal_imputer.py b/shapiq/games/imputer/marginal_imputer.py index 430d354b..fe26c077 100644 --- a/shapiq/games/imputer/marginal_imputer.py +++ b/shapiq/games/imputer/marginal_imputer.py @@ -36,7 +36,7 @@ class MarginalImputer(Imputer): def __init__( self, - model: Callable[[np.ndarray], np.ndarray], + model, data: np.ndarray, x: Optional[np.ndarray] = None, sample_replacements: bool = False, @@ -51,7 +51,7 @@ def __init__( self._sample_replacements = sample_replacements self._sample_size: int = sample_size self.replacement_data: np.ndarray = np.zeros((1, self._n_features)) # will be overwritten - self.init_background(self._data) + self.init_background(self.data) self._x: np.ndarray = np.zeros((1, self._n_features)) # will be overwritten @ fit if x is not None: self.fit(x) @@ -77,7 +77,7 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray: if not self._sample_replacements: replacement_data = np.tile(self.replacement_data, (n_subsets, 1)) data[~coalitions] = replacement_data[~coalitions] - outputs = self._model(data) + outputs = self.predict(data) else: # sampling from background returning array of shape (sample_size, n_subsets, n_features) replacement_data = self._sample_replacement_values(coalitions) @@ -85,26 +85,26 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray: for i in range(self._sample_size): replacements = replacement_data[i].reshape(n_subsets, self._n_features) data[~coalitions] = replacements[~coalitions] - outputs[i] = self._model(data) + outputs[i] = self.predict(data) outputs = np.mean(outputs, axis=0) # average over the samples return outputs - def init_background(self, x_background: np.ndarray) -> "MarginalImputer": + def init_background(self, data: np.ndarray) -> "MarginalImputer": """Initializes the imputer to the background data. Args: - x_background: The background data to use for the imputer. The shape of the array must + data: The background data to use for the imputer. The shape of the array must be (n_samples, n_features). Returns: The initialized imputer. """ if self._sample_replacements: - self.replacement_data = x_background + self.replacement_data = data else: self.replacement_data = np.zeros((1, self._n_features), dtype=object) for feature in range(self._n_features): - feature_column = x_background[:, feature] + feature_column = data[:, feature] if feature in self._cat_features: # get mode for categorical features counts = np.unique(feature_column, return_counts=True) @@ -157,10 +157,10 @@ def _calc_empty_prediction(self) -> float: The empty prediction. """ if self._sample_replacements: - shuffled_background = self._rng.permutation(self._data) - empty_predictions = self._model(shuffled_background) + shuffled_background = self._rng.permutation(self.data) + empty_predictions = self.predict(shuffled_background) empty_prediction = float(np.mean(empty_predictions)) return empty_prediction - empty_prediction = self._model(self.replacement_data) + empty_prediction = self.predict(self.replacement_data) empty_prediction = float(empty_prediction) return empty_prediction diff --git a/shapiq/games/tabular.py b/shapiq/games/tabular.py index 09a14555..9683556d 100644 --- a/shapiq/games/tabular.py +++ b/shapiq/games/tabular.py @@ -17,11 +17,11 @@ class LocalExplanation(Game): of the data point (for more information see `MarginalImputer`). Args: - data: The background data used to fit the imputer. Should be a 2d matrix of shape - (n_samples, n_features). model: The model to explain as a callable function expecting data points as input and returning the model's predictions. The input should be a 2d matrix of shape (n_samples, n_features) and the output a 1d matrix of shape (n_samples). + data: The background data used to fit the imputer. Should be a 2d matrix of shape + (n_samples, n_features). x: The data point to explain. Can be an index of the background data or a 1d matrix of shape (n_features). random_state: The random state to use for the imputer. Defaults to `None`. @@ -56,25 +56,25 @@ class LocalExplanation(Game): def __init__( self, - data: np.ndarray, model: Callable[[np.ndarray], np.ndarray], + data: np.ndarray, x: Union[np.ndarray, int], random_state: Optional[int] = None, normalize: bool = True, ) -> None: # set attributes - self._model = model - self._data = data + self.model = model + self.data = data # set explanation point if isinstance(x, int): - x = self._data[x] + x = self.data[x] self.x = x # init the imputer which serves as the workhorse of this Game self._imputer = MarginalImputer( - model=self._model, - data=self._data, + model=self.model, + data=self.data, x=x, random_state=random_state, normalize=False, diff --git a/tests/test_abstract_classes.py b/tests/test_abstract_classes.py index 478bae7e..a8d455d7 100644 --- a/tests/test_abstract_classes.py +++ b/tests/test_abstract_classes.py @@ -39,8 +39,8 @@ def test_imputer(): model = lambda x: x data = np.asarray([[1, 2, 3], [4, 5, 6]]) imputer = concreter(Imputer)(model, data) - assert imputer._model == model - assert np.all(imputer._data == data) + assert imputer.model == model + assert np.all(imputer.data == data) assert imputer._n_features == 3 assert imputer._cat_features == [] assert imputer._random_state is None