Skip to content

Commit

Permalink
Merge pull request #59 from mmschlk/tree-explainer-random-forrest
Browse files Browse the repository at this point in the history
Bugfix and Support for Random Forests in TreeExplainer
  • Loading branch information
mmschlk authored Mar 20, 2024
2 parents 4cb0b32 + acb9bae commit eb91bd8
Show file tree
Hide file tree
Showing 14 changed files with 348 additions and 342 deletions.
2 changes: 1 addition & 1 deletion shapiq/explainer/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def reduce_feature_complexity(self) -> None:
Feature '8' is 'renamed' to '2' such that in the internal representation a one-hot vector
(and matrices) of length 3 suffices to represent the feature indices.
"""
if self.n_features_in_tree < self.max_feature_id:
if self.n_features_in_tree < self.max_feature_id + 1:
new_feature_ids = set(range(self.n_features_in_tree))
mapping_old_new = {old_id: new_id for new_id, old_id in enumerate(self.feature_ids)}
mapping_new_old = {new_id: old_id for new_id, old_id in enumerate(self.feature_ids)}
Expand Down
33 changes: 14 additions & 19 deletions shapiq/explainer/tree/conversion/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
"""This module contains functions for converting scikit-learn decision trees to the format used by
shapiq."""

from typing import Optional, Union
from typing import Optional

import numpy as np
from explainer.tree.base import TreeModel

from shapiq.utils import safe_isinstance

try:
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
except ImportError:
pass
from shapiq.utils.types import Model


def convert_sklearn_forest(
tree_model: Union["RandomForestRegressor", "RandomForestClassifier"],
class_label: int = 0,
output_type: Optional[str] = None,
tree_model: Model,
class_label: Optional[int] = None,
output_type: str = "raw",
) -> list[TreeModel]:
"""Transforms a scikit-learn random forest to the format used by shapiq.
Expand All @@ -33,8 +28,6 @@ def convert_sklearn_forest(
The converted random forest model.
"""
scaling = 1.0 / len(tree_model.estimators_)
if not safe_isinstance(tree_model, "sklearn.ensemble.RandomForestClassifier"):
output_type = None
return [
convert_sklearn_tree(
tree, scaling=scaling, class_label=class_label, output_type=output_type
Expand All @@ -44,8 +37,8 @@ def convert_sklearn_forest(


def convert_sklearn_tree(
tree_model: Union["DecisionTreeRegressor", "DecisionTreeClassifier"],
class_label: int = 0,
tree_model: Model,
class_label: Optional[int] = None,
scaling: float = 1.0,
output_type: str = "raw",
) -> TreeModel:
Expand All @@ -63,14 +56,16 @@ def convert_sklearn_tree(
The converted decision tree model.
"""
tree_values = tree_model.tree_.value.copy() * scaling
# set class label if not given and model is a classifier
if safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier") and class_label is None:
class_label = 1

if class_label is not None:
# turn node values into probabilities
if len(tree_values.shape) == 3:
tree_values = tree_values / np.sum(tree_values, axis=2, keepdims=True)
tree_values = tree_values[:, 0, class_label]
else:
tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
tree_values = tree_values[:, class_label]
tree_values = tree_values[:, 0, :]
tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True)
tree_values = tree_values[:, class_label]
if output_type != "raw":
# TODO: Add support for logits output type
raise NotImplementedError("Only raw output types are currently supported.")
Expand Down
201 changes: 0 additions & 201 deletions shapiq/explainer/tree/conversion/xgboost.py

This file was deleted.

13 changes: 8 additions & 5 deletions shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module contains the TreeExplainer class making use of the TreeSHAPIQ algorithm for
computing any-order Shapley Interactions for tree ensembles."""
import copy
from typing import Any, Union
from typing import Any, Optional, Union

import numpy as np
from explainer._base import Explainer
Expand All @@ -17,17 +17,20 @@ def __init__(
model: Union[dict, TreeModel, Any],
max_order: int = 2,
min_order: int = 1,
class_label: Optional[int] = None,
output_type: str = "raw",
) -> None:
# validate and parse model
validated_model = _validate_model(model) # the parsed and validated model

validated_model = _validate_model(model, class_label=class_label, output_type=output_type)
self._trees: Union[TreeModel, list[TreeModel]] = copy.deepcopy(validated_model)
if not isinstance(self._trees, list):
self._trees = [self._trees]
self._n_trees = len(self._trees)

self._max_order = max_order
self._min_order = min_order
self._max_order: int = max_order
self._min_order: int = min_order
self._class_label: Optional[int] = class_label
self._output_type: str = output_type

# setup explainers for all trees
self._treeshapiq_explainers: list[TreeSHAPIQ] = [
Expand Down
11 changes: 5 additions & 6 deletions shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains the tree explainer implementation."""
import copy
from math import factorial
from typing import Any, Optional, Union

import numpy as np
from approximator import transforms_sii_to_ksii
Expand Down Expand Up @@ -30,7 +31,7 @@ class TreeSHAPIQ:

def __init__(
self,
model: TreeModel,
model: Union[dict, TreeModel, Any],
max_order: int = 2,
min_order: int = 1,
interaction_type: str = "k-SII",
Expand Down Expand Up @@ -507,24 +508,22 @@ def _get_N_cii(self, interpolated_poly, order) -> np.ndarray[float]:
)
return Ns

def _get_subset_weight_cii(self, t, order) -> float:
def _get_subset_weight_cii(self, t, order) -> Optional[float]:
# TODO: add docstring
if self._interaction_type == "STI":
return self._max_order / (
self._n_features_in_tree * binom(self._n_features_in_tree - 1, t)
)
elif self._interaction_type == "FSI":
if self._interaction_type == "FSI":
return (
factorial(2 * self._max_order - 1)
/ factorial(self._max_order - 1) ** 2
* factorial(self._max_order + t - 1)
* factorial(self._n_features_in_tree - t - 1)
/ factorial(self._n_features_in_tree + self._max_order - 1)
)
elif self._interaction_type == "BZF":
if self._interaction_type == "BZF":
return 1 / (2 ** (self._n_features_in_tree - order))
else:
raise ValueError("Interaction type not supported")

@staticmethod
def _get_N_id(D) -> np.ndarray[float]:
Expand Down
Loading

0 comments on commit eb91bd8

Please sign in to comment.