Skip to content

Commit

Permalink
Improve testing (#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
heinzll authored Dec 20, 2024
1 parent fb47de6 commit ccc0b99
Show file tree
Hide file tree
Showing 51 changed files with 929 additions and 276 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
- adds the `upset_plot` function to the `plot` module to visualize higher-order interactions [#290](https://github.com/mmschlk/shapiq/issues/290)
- adds support for IsoForest models to explainer and tree explainer [#278](https://github.com/mmschlk/shapiq/issues/278)
- adds support for sub-selection of players in the interaction values data class [#276](https://github.com/mmschlk/shapiq/issues/276) which allows retrieving interaction values for a subset of players
- refactors game theory computations like `ExactComputer`, `MoebiusConverter`, `core`, among others into the `game_theory` module to be more modular and flexible [#258](https://github.com/mmschlk/shapiq/issues/258)
- improves quality of the tests by adding many more semantic tests to the different interaction indices and computations [#285](https://github.com/mmschlk/shapiq/pull/285)

### v1.1.1 (2024-11-13)

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ tqdm==4.67.1
torch==2.5.1
torchvision==0.20.1
transformers==4.46.3
tensorflow==2.18.0
tf-keras==2.18.0
xgboost==2.1.3
numpy==1.26.4
requests==2.32.3
Expand Down
6 changes: 3 additions & 3 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
# dataset functions
from .datasets import load_adult_census, load_bike_sharing, load_california_housing

# exact computer classes
from .exact import ExactComputer

# explainer classes
from .explainer import Explainer, TabularExplainer, TreeExplainer

# exact computer classes
from .game_theory.exact import ExactComputer

# game classes
# imputer classes
from .games import BaselineImputer, ConditionalImputer, Game, MarginalImputer
Expand Down
12 changes: 6 additions & 6 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

import numpy as np

from shapiq.approximator.sampling import CoalitionSampler
from shapiq.indices import (
from ..approximator.sampling import CoalitionSampler
from ..game_theory.indices import (
AVAILABLE_INDICES_FOR_APPROXIMATION,
get_computation_index,
is_empty_value_the_baseline,
is_index_aggregated,
)
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import generate_interaction_lookup
from ..interaction_values import InteractionValues
from ..utils.sets import generate_interaction_lookup

__all__ = [
"Approximator",
Expand Down Expand Up @@ -318,7 +318,7 @@ def aggregate_interaction_values(
Returns:
The aggregated interaction values.
"""
from ..aggregation import aggregate_interaction_values
from shapiq.game_theory.aggregation import aggregate_interaction_values

if player_set is not None:
raise NotImplementedError(
Expand All @@ -339,6 +339,6 @@ def aggregate_to_one_dimension(
Returns:
tuple[np.ndarray, np.ndarray]: The positive and negative aggregated values.
"""
from ..aggregation import aggregate_to_one_dimension
from shapiq.game_theory.aggregation import aggregate_to_one_dimension

return aggregate_to_one_dimension(interaction_values)
4 changes: 2 additions & 2 deletions shapiq/approximator/marginals/owen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import numpy as np

from shapiq.approximator._base import Approximator
from shapiq.interaction_values import InteractionValues
from ...interaction_values import InteractionValues
from .._base import Approximator


class OwenSamplingSV(Approximator):
Expand Down
4 changes: 2 additions & 2 deletions shapiq/approximator/marginals/stratified.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import numpy as np

from shapiq.approximator._base import Approximator
from shapiq.interaction_values import InteractionValues
from ...interaction_values import InteractionValues
from .._base import Approximator


class StratifiedSamplingSV(Approximator):
Expand Down
8 changes: 4 additions & 4 deletions shapiq/approximator/montecarlo/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import numpy as np
from scipy.special import binom, factorial

from shapiq.approximator._base import Approximator
from shapiq.indices import AVAILABLE_INDICES_MONTE_CARLO
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset
from ...game_theory.indices import AVAILABLE_INDICES_MONTE_CARLO
from ...interaction_values import InteractionValues
from ...utils.sets import powerset
from .._base import Approximator


class MonteCarlo(Approximator):
Expand Down
6 changes: 3 additions & 3 deletions shapiq/approximator/permutation/stii.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import numpy as np
import scipy as sp

from shapiq.approximator._base import Approximator
from shapiq.interaction_values import InteractionValues
from shapiq.utils import get_explicit_subsets, powerset
from ...interaction_values import InteractionValues
from ...utils import get_explicit_subsets, powerset
from .._base import Approximator


class PermutationSamplingSTII(Approximator):
Expand Down
4 changes: 2 additions & 2 deletions shapiq/approximator/permutation/sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import numpy as np

from shapiq.approximator._base import Approximator
from shapiq.interaction_values import InteractionValues
from ...interaction_values import InteractionValues
from .._base import Approximator


class PermutationSamplingSV(Approximator):
Expand Down
8 changes: 4 additions & 4 deletions shapiq/approximator/regression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import numpy as np
from scipy.special import bernoulli, binom

from shapiq.approximator._base import Approximator
from shapiq.indices import AVAILABLE_INDICES_REGRESSION
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset
from ...game_theory.indices import AVAILABLE_INDICES_REGRESSION
from ...interaction_values import InteractionValues
from ...utils.sets import powerset
from .._base import Approximator


class Regression(Approximator):
Expand Down
2 changes: 1 addition & 1 deletion shapiq/approximator/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from scipy.special import binom

from shapiq.utils.sets import powerset
from ..utils.sets import powerset


class CoalitionSampler:
Expand Down
4 changes: 2 additions & 2 deletions shapiq/explainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import numpy as np

from shapiq.explainer.utils import get_explainers, get_predict_function_and_model_type, print_class
from shapiq.interaction_values import InteractionValues
from ..explainer.utils import get_explainers, get_predict_function_and_model_type, print_class
from ..interaction_values import InteractionValues


class Explainer:
Expand Down
8 changes: 4 additions & 4 deletions shapiq/explainer/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from shapiq.approximator import (
from ..approximator import (
SHAPIQ,
SVARMIQ,
InconsistentKernelSHAPIQ,
Expand All @@ -17,9 +17,9 @@
RegressionFSII,
UnbiasedKernelSHAP,
)
from shapiq.approximator._base import Approximator
from shapiq.explainer._base import Explainer
from shapiq.interaction_values import InteractionValues
from ..approximator._base import Approximator
from ..explainer._base import Explainer
from ..interaction_values import InteractionValues

APPROXIMATOR_CONFIGURATIONS = {
"regression": {
Expand Down
3 changes: 1 addition & 2 deletions shapiq/explainer/tree/conversion/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import pandas as pd

from shapiq.utils.types import Model

from ....utils.types import Model
from ..base import TreeModel


Expand Down
53 changes: 35 additions & 18 deletions shapiq/explainer/tree/conversion/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from typing import Optional

import numpy as np
from sklearn.ensemble._iforest import _average_path_length

from shapiq.utils import safe_isinstance
from shapiq.utils.types import Model

from ....utils import safe_isinstance
from ....utils.types import Model
from ..base import TreeModel


Expand Down Expand Up @@ -77,11 +75,20 @@ def convert_sklearn_tree(
)


def average_path_length(isolation_forest):
def average_path_length(isolation_forest: Model) -> float:
"""Compute the average path length of the isolation forest.
Args:
isolation_forest: The isolation forest model.
Returns:
The average path length of the isolation forest.
"""
from sklearn.ensemble._iforest import _average_path_length

max_samples = isolation_forest._max_samples
average_path_length = _average_path_length(
[max_samples]
) # NOTE: _average_path_length func is equivalent to equation 1 in Isolation Forest paper Liu2008
# NOTE: _average_path_length func is equivalent to equation 1 in Isolation Forest paper Liu2008
average_path_length = _average_path_length([max_samples])
return average_path_length


Expand All @@ -99,33 +106,27 @@ def convert_sklearn_isolation_forest(
scaling = 1.0 / len(tree_model.estimators_)

return [
# convert_isolation_tree_shap_isotree(tree, features, scaling=scaling)
convert_isolation_tree(tree, features, scaling=scaling)
for tree, features in zip(tree_model.estimators_, tree_model.estimators_features_)
]


def convert_isolation_tree(
tree_model: Model,
tree_features,
class_label: Optional[int] = None,
tree_features: np.ndarray,
scaling: float = 1.0,
average_path_length: float = 1.0, # TODO fix default value
) -> TreeModel:
"""Convert a single tree of a scikit-learn isolation forest to the format used by shapiq.
Args:
tree_model: The scikit-learn decision tree model to convert.
class_label: The class label of the model to explain. Only used for classification models.
Defaults to ``1``.
tree_features: The features used in the tree.
scaling: The scaling factor for the tree values.
Returns:
The converted decision tree model.
"""
output_type = "raw"
tree_values = tree_model.tree_.value.copy()
tree_values = tree_values.flatten()
features_updated, values_updated = isotree_value_traversal(
tree_model.tree_, tree_features, normalize=False, scaling=1.0
)
Expand All @@ -145,8 +146,24 @@ def convert_isolation_tree(


def isotree_value_traversal(
tree, tree_features, normalize=False, scaling=1.0, data=None, data_missing=None
):
tree: Model,
tree_features: np.ndarray,
normalize: bool = False,
scaling: float = 1.0,
) -> tuple[np.ndarray, np.ndarray]:
"""Traverse the tree and calculate the average path length for each node.
Args:
tree: The tree to traverse.
tree_features: The features used in the tree.
normalize: Whether to normalize the values.
scaling: The scaling factor for the values.
Returns:
The updated features and values.
"""
from sklearn.ensemble._iforest import _average_path_length

features = tree.feature.copy()
corrected_values = tree.value.copy()
if safe_isinstance(tree, "sklearn.tree._tree.Tree"):
Expand Down
3 changes: 1 addition & 2 deletions shapiq/explainer/tree/conversion/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import numpy as np
import pandas as pd

from shapiq.utils.types import Model

from ....utils.types import Model
from ..base import TreeModel


Expand Down
7 changes: 4 additions & 3 deletions shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import numpy as np

from shapiq.explainer._base import Explainer
from shapiq.interaction_values import InteractionValues

from ...interaction_values import InteractionValues
from .._base import Explainer
from .treeshapiq import TreeModel, TreeSHAPIQ
from .validation import validate_tree_model

Expand Down Expand Up @@ -77,6 +76,8 @@ def __init__(
self.baseline_value = self._compute_baseline_value()

def explain(self, x: np.ndarray) -> InteractionValues:
if len(x.shape) != 1:
raise TypeError("explain expects a single instance, not a batch.")
# run treeshapiq for all trees
interaction_values: list[InteractionValues] = []
for explainer in self._treeshapiq_explainers:
Expand Down
4 changes: 2 additions & 2 deletions shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
import scipy as sp

from ...aggregation import aggregate_interaction_values
from ...indices import get_computation_index
from ...game_theory.aggregation import aggregate_interaction_values
from ...game_theory.indices import get_computation_index
from ...interaction_values import InteractionValues
from ...utils.sets import generate_interaction_lookup, powerset
from .base import EdgeTree, TreeModel
Expand Down
1 change: 1 addition & 0 deletions shapiq/explainer/tree/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"lightgbm.sklearn.LGBMRegressor",
"lightgbm.sklearn.LGBMClassifier",
"lightgbm.basic.Booster",
# TODO: xgboost?
}


Expand Down
30 changes: 30 additions & 0 deletions shapiq/game_theory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Game-theoretic computations of shapiq.

This package re-exports the exact computer, the Moebius converter, the egalitarian least core,
the aggregation of interaction values, and helper functions for interaction indices.
"""

from .aggregation import aggregate_interaction_values
from .core import egalitarian_least_core
from .exact import ExactComputer, get_bernoulli_weights
from .indices import (
ALL_AVAILABLE_CONCEPTS,
get_computation_index,
index_generalizes_bv,
index_generalizes_sv,
is_empty_value_the_baseline,
is_index_aggregated,
)
from .moebius_converter import MoebiusConverter

# Names exported by ``from shapiq.game_theory import *``.
__all__ = [
"ExactComputer",
"aggregate_interaction_values",
"get_bernoulli_weights",
"ALL_AVAILABLE_CONCEPTS",
"index_generalizes_sv",
"index_generalizes_bv",
"get_computation_index",
"is_index_aggregated",
"is_empty_value_the_baseline",
"egalitarian_least_core",
"MoebiusConverter",
]
# TODO: complete this list
4 changes: 2 additions & 2 deletions shapiq/aggregation.py → shapiq/game_theory/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
import scipy as sp

from .interaction_values import InteractionValues
from .utils.sets import powerset
from ..interaction_values import InteractionValues
from ..utils.sets import powerset


def _change_index(index: str) -> str:
Expand Down
Loading

0 comments on commit ccc0b99

Please sign in to comment.