Skip to content

Commit

Permalink
Merge pull request #51 from mmschlk/26-add-treeshap-iq-explainer
Browse files Browse the repository at this point in the history
adds TreeExplainer with TreeSHAP-IQ
  • Loading branch information
mmschlk authored Mar 18, 2024
2 parents 596acd3 + e5f98bd commit 5478906
Show file tree
Hide file tree
Showing 51 changed files with 2,357 additions and 418 deletions.
32 changes: 17 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ Explain your models with Shapley interaction values like the k-SII values:
```python
# train a model
from sklearn.ensemble import RandomForestRegressor

model = RandomForestRegressor(n_estimators=50, random_state=42)
model.fit(x_train, y_train)

# explain with k-SII interaction scores
from shapiq import InteractionExplainer
explainer = InteractionExplainer(
from shapiq import TabularExplainer

explainer = TabularExplainer(
model=model.predict,
background_data=x_train,
index="k-SII",
Expand All @@ -88,19 +90,19 @@ explainer = InteractionExplainer(
interaction_values = explainer.explain(x_explain, budget=2000)
print(interaction_values)

>>> InteractionValues(
>>> index=k-SII, max_order=2, min_order=1, estimated=True, estimation_budget=2000,
>>> values={
>>> (0,): -91.0403, # main effect for feature 0
>>> (1,): 4.1264, # main effect for feature 1
>>> (2,): -0.4724, # main effect for feature 2
>>> ...
>>> (0, 1): -0.8073, # 2-way interaction for feature 0 and 1
>>> (0, 2): 2.469, # 2-way interaction for feature 0 and 2
>>> ...
>>> (10, 11): 0.4057 # 2-way interaction for feature 10 and 11
>>> }
>>> )
>>> InteractionValues(
>>> index=k-SII, max_order=2, min_order=1, estimated=True, estimation_budget=2000,
>>> values={
>>> (0,): -91.0403, # main effect for feature 0
>>> (1,): 4.1264, # main effect for feature 1
>>> (2,): -0.4724, # main effect for feature 2
>>> ...
>>> (0, 1): -0.8073, # 2-way interaction for feature 0 and 1
>>> (0, 2): 2.469, # 2-way interaction for feature 0 and 2
>>> ...
>>> (10, 11): 0.4057 # 2-way interaction for feature 10 and 11
>>> }
>>> )
```

## 📊 Visualize your Interactions
Expand Down
12 changes: 6 additions & 6 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@
from .datasets import load_bike

# explainer classes
from .explainer import InteractionExplainer
from .explainer import TabularExplainer, TreeExplainer

# game classes
from .games import DummyGame
from .interaction_values import InteractionValues

# plotting functions
from .plot import network_plot, stacked_bar_plot

# public utils functions
from .utils import ( # sets.py # tree.py
get_conditional_sample_weights,
get_explicit_subsets,
get_parent_array,
powerset,
safe_isinstance,
split_subsets_budget,
Expand All @@ -36,14 +35,17 @@
__all__ = [
# version
"__version__",
# base
"InteractionValues",
# approximators
"ShapIQ",
"PermutationSamplingSII",
"PermutationSamplingSTI",
"RegressionSII",
"RegressionFSI",
# explainers
"InteractionExplainer",
"TabularExplainer",
"TreeExplainer",
# games
"DummyGame",
# plots
Expand All @@ -53,8 +55,6 @@
"powerset",
"get_explicit_subsets",
"split_subsets_budget",
"get_conditional_sample_weights",
"get_parent_array",
"safe_isinstance",
# datasets
"load_bike",
Expand Down
10 changes: 6 additions & 4 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import Callable, Optional

import numpy as np
from approximator._config import AVAILABLE_INDICES
from approximator._interaction_values import InteractionValues
from approximator._utils import _generate_interaction_lookup
from interaction_values import InteractionValues

from shapiq.utils.sets import generate_interaction_lookup

from ._config import AVAILABLE_INDICES

__all__ = [
"Approximator",
Expand Down Expand Up @@ -65,7 +67,7 @@ def __init__(
self.max_order: int = max_order
self.min_order: int = self.max_order if self.top_order else 1
self.iteration_cost: int = 1 # default value, can be overwritten by subclasses
self._interaction_lookup = _generate_interaction_lookup(
self._interaction_lookup = generate_interaction_lookup(
self.n, self.min_order, self.max_order
)
self._random_state: Optional[int] = random_state
Expand Down
145 changes: 0 additions & 145 deletions shapiq/approximator/_interaction_values.py

This file was deleted.

21 changes: 0 additions & 21 deletions shapiq/approximator/_utils.py

This file was deleted.

16 changes: 9 additions & 7 deletions shapiq/approximator/k_sii.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from typing import Optional, Union

import numpy as np
from approximator._base import Approximator
from approximator._interaction_values import InteractionValues
from approximator._utils import _generate_interaction_lookup
from interaction_values import InteractionValues
from scipy.special import bernoulli

from shapiq.utils import powerset
from shapiq.approximator._base import Approximator
from shapiq.utils import generate_interaction_lookup, powerset


class KShapleyMixin:
Expand Down Expand Up @@ -86,7 +85,7 @@ def transforms_sii_to_ksii(
)
elif n is not None and max_order is not None:
if interaction_lookup is None:
interaction_lookup = _generate_interaction_lookup(n, 1, max_order)
interaction_lookup = generate_interaction_lookup(n, 1, max_order)
return _calculate_ksii_from_sii(sii_values, n, max_order, interaction_lookup)
else:
raise ValueError(
Expand Down Expand Up @@ -120,9 +119,12 @@ def _calculate_ksii_from_sii(
nsii_values = np.zeros_like(sii_values)
# all subsets S with 1 <= |S| <= max_order
for subset in powerset(set(range(n)), min_size=1, max_size=max_order):
interaction_index = interaction_lookup[subset]
interaction_size = len(subset)
ksii_value = sii_values[interaction_index]
try:
interaction_index = interaction_lookup[subset]
ksii_value = sii_values[interaction_index]
except KeyError:
continue # a zero value is not scaled # TODO: verify this
# go over all subsets T of length |S| + 1, ..., n that contain S
for T in powerset(set(range(n)), min_size=interaction_size + 1, max_size=max_order):
if set(subset).issubset(T):
Expand Down
2 changes: 1 addition & 1 deletion shapiq/approximator/permutation/sii.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import numpy as np
from approximator._base import Approximator
from approximator._interaction_values import InteractionValues
from approximator.k_sii import KShapleyMixin
from interaction_values import InteractionValues
from utils import powerset


Expand Down
2 changes: 1 addition & 1 deletion shapiq/approximator/permutation/sti.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from approximator._base import Approximator
from approximator._interaction_values import InteractionValues
from interaction_values import InteractionValues
from scipy.special import binom
from utils import get_explicit_subsets, powerset

Expand Down
2 changes: 1 addition & 1 deletion shapiq/approximator/regression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import numpy as np
from approximator._base import Approximator
from approximator._interaction_values import InteractionValues
from approximator.sampling import ShapleySamplingMixin
from interaction_values import InteractionValues
from scipy.special import bernoulli, binom
from utils import powerset

Expand Down
2 changes: 1 addition & 1 deletion shapiq/approximator/shapiq/shapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import numpy as np
from approximator._base import Approximator
from approximator._interaction_values import InteractionValues
from approximator.k_sii import KShapleyMixin
from approximator.sampling import ShapleySamplingMixin
from interaction_values import InteractionValues
from utils import powerset

AVAILABLE_INDICES_SHAPIQ = {"SII", "STI", "FSI", "k-SII"}
Expand Down
5 changes: 3 additions & 2 deletions shapiq/explainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains the explainer for the shapiq package."""

from .interaction import InteractionExplainer

from .tabular import TabularExplainer
from .tree import TreeExplainer

__all__ = ["InteractionExplainer", "TreeExplainer"]
__all__ = ["TabularExplainer", "TreeExplainer"]
Loading

0 comments on commit 5478906

Please sign in to comment.