diff --git a/shapiq/explainer/tree/base.py b/shapiq/explainer/tree/base.py index eb98321d..1d3afa56 100644 --- a/shapiq/explainer/tree/base.py +++ b/shapiq/explainer/tree/base.py @@ -109,7 +109,7 @@ def reduce_feature_complexity(self) -> None: Feature '8' is 'renamed' to '2' such that in the internal representation a one-hot vector (and matrices) of length 3 suffices to represent the feature indices. """ - if self.n_features_in_tree < self.max_feature_id: + if self.n_features_in_tree < self.max_feature_id + 1: new_feature_ids = set(range(self.n_features_in_tree)) mapping_old_new = {old_id: new_id for new_id, old_id in enumerate(self.feature_ids)} mapping_new_old = {new_id: old_id for new_id, old_id in enumerate(self.feature_ids)} diff --git a/shapiq/explainer/tree/conversion/sklearn.py b/shapiq/explainer/tree/conversion/sklearn.py index 8367cc72..5d7ec2a3 100644 --- a/shapiq/explainer/tree/conversion/sklearn.py +++ b/shapiq/explainer/tree/conversion/sklearn.py @@ -1,24 +1,19 @@ """This module contains functions for converting scikit-learn decision trees to the format used by shapiq.""" -from typing import Optional, Union +from typing import Optional import numpy as np from explainer.tree.base import TreeModel from shapiq.utils import safe_isinstance - -try: - from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor - from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor -except ImportError: - pass +from shapiq.utils.types import Model def convert_sklearn_forest( - tree_model: Union["RandomForestRegressor", "RandomForestClassifier"], - class_label: int = 0, - output_type: Optional[str] = None, + tree_model: Model, + class_label: Optional[int] = None, + output_type: str = "raw", ) -> list[TreeModel]: """Transforms a scikit-learn random forest to the format used by shapiq. @@ -33,8 +28,6 @@ def convert_sklearn_forest( The converted random forest model. """ scaling = 1.0 / len(tree_model.estimators_) - if not safe_isinstance(tree_model, "sklearn.ensemble.RandomForestClassifier"): - output_type = None return [ convert_sklearn_tree( tree, scaling=scaling, class_label=class_label, output_type=output_type @@ -44,8 +37,8 @@ def convert_sklearn_forest( def convert_sklearn_tree( - tree_model: Union["DecisionTreeRegressor", "DecisionTreeClassifier"], - class_label: int = 0, + tree_model: Model, + class_label: Optional[int] = None, scaling: float = 1.0, output_type: str = "raw", ) -> TreeModel: @@ -63,14 +56,16 @@ def convert_sklearn_tree( The converted decision tree model. """ tree_values = tree_model.tree_.value.copy() * scaling + # set class label if not given and model is a classifier + if safe_isinstance(tree_model, "sklearn.tree.DecisionTreeClassifier") and class_label is None: + class_label = 1 + if class_label is not None: # turn node values into probabilities if len(tree_values.shape) == 3: - tree_values = tree_values / np.sum(tree_values, axis=2, keepdims=True) - tree_values = tree_values[:, 0, class_label] - else: - tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True) - tree_values = tree_values[:, class_label] + tree_values = tree_values[:, 0, :] + tree_values = tree_values / np.sum(tree_values, axis=1, keepdims=True) + tree_values = tree_values[:, class_label] if output_type != "raw": # TODO: Add support for logits output type raise NotImplementedError("Only raw output types are currently supported.") diff --git a/shapiq/explainer/tree/conversion/xgboost.py b/shapiq/explainer/tree/conversion/xgboost.py deleted file mode 100644 index bab9f4a9..00000000 --- a/shapiq/explainer/tree/conversion/xgboost.py +++ /dev/null @@ -1,201 +0,0 @@ -import copy -import struct -from typing import Union - -import numpy as np -import scipy -from explainer.tree.base import TreeModel -from packaging import version - -from shapiq.utils import safe_isinstance - -try: - from xgboost.sklearn import XGBClassifier, XGBRegressor -except ImportError: - pass - - -def convert_xgboost_trees( - tree_model: Union["XGBClassifier", "XGBRegressor"], -) -> list[TreeModel]: - if safe_isinstance(tree_model, "xgboost.sklearn.XGBRegressor"): - model = tree_model.get_booster() - xgb_loader = XGBTreeModelLoader(model) - trees = xgb_loader.get_trees() - return copy.deepcopy(trees) - if safe_isinstance(tree_model, "xgboost.sklearn.XGBClassifier"): - raise NotImplementedError("XGBoost Classifier not implemented yet.") - - -class XGBTreeModelLoader: - """This loads an XGBoost model directly from a raw memory dump. - - We can't use the JSON dump because due to numerical precision issues those - tree can actually be wrong when feature values land almost on a threshold. - """ - - def __init__(self, xgb_model): - # new in XGBoost 1.1, 'binf' is appended to the buffer - self.buf = xgb_model.save_raw() - if self.buf.startswith(b"binf"): - self.buf = self.buf[4:] - self.pos = 0 - - # load the model parameters - self.base_score = self.read("f") - self.num_feature = self.read("I") - self.num_class = self.read("i") - self.contain_extra_attrs = self.read("i") - self.contain_eval_metrics = self.read("i") - self.read_arr("i", 29) # reserved - self.name_obj_len = self.read("Q") - self.name_obj = self.read_str(self.name_obj_len) - self.name_gbm_len = self.read("Q") - self.name_gbm = self.read_str(self.name_gbm_len) - - # new in XGBoost 1.0 is that the base_score is saved untransformed (https://github.com/dmlc/xgboost/pull/5101) - # so we have to transform it depending on the objective - import xgboost - - if version.parse(xgboost.__version__).major >= 1: - if self.name_obj in ["binary:logistic", "reg:logistic"]: - self.base_score = scipy.special.logit(self.base_score) # pylint: disable=no-member - - assert self.name_gbm == "gbtree", ( - "Only the 'gbtree' model type is supported, not '%s'!" % self.name_gbm - ) - - # load the gbtree specific parameters - self.num_trees = self.read("i") - self.num_roots = self.read("i") - self.num_feature = self.read("i") - self.pad_32bit = self.read("i") - self.num_pbuffer_deprecated = self.read("Q") - self.num_output_group = self.read("i") - self.size_leaf_vector = self.read("i") - self.read_arr("i", 32) # reserved - - # load each tree - self.num_roots = np.zeros(self.num_trees, dtype=np.int32) - self.num_nodes = np.zeros(self.num_trees, dtype=np.int32) - self.num_deleted = np.zeros(self.num_trees, dtype=np.int32) - self.max_depth = np.zeros(self.num_trees, dtype=np.int32) - self.num_feature = np.zeros(self.num_trees, dtype=np.int32) - self.size_leaf_vector = np.zeros(self.num_trees, dtype=np.int32) - self.node_parents = [] - self.node_cleft = [] - self.node_cright = [] - self.node_sindex = [] - self.node_info = [] - self.loss_chg = [] - self.sum_hess = [] - self.base_weight = [] - self.leaf_child_cnt = [] - for i in range(self.num_trees): - # load the per-tree params - self.num_roots[i] = self.read("i") - self.num_nodes[i] = self.read("i") - self.num_deleted[i] = self.read("i") - self.max_depth[i] = self.read("i") - self.num_feature[i] = self.read("i") - self.size_leaf_vector[i] = self.read("i") - - # load the nodes - self.read_arr("i", 31) # reserved - self.node_parents.append(np.zeros(self.num_nodes[i], dtype=np.int32)) - self.node_cleft.append(np.zeros(self.num_nodes[i], dtype=np.int32)) - self.node_cright.append(np.zeros(self.num_nodes[i], dtype=np.int32)) - self.node_sindex.append(np.zeros(self.num_nodes[i], dtype=np.uint32)) - self.node_info.append(np.zeros(self.num_nodes[i], dtype=np.float32)) - for j in range(self.num_nodes[i]): - self.node_parents[-1][j] = self.read("i") - self.node_cleft[-1][j] = self.read("i") - self.node_cright[-1][j] = self.read("i") - self.node_sindex[-1][j] = self.read("I") - self.node_info[-1][j] = self.read("f") - - # load the stat nodes - self.loss_chg.append(np.zeros(self.num_nodes[i], dtype=np.float32)) - self.sum_hess.append(np.zeros(self.num_nodes[i], dtype=np.float32)) - self.base_weight.append(np.zeros(self.num_nodes[i], dtype=np.float32)) - self.leaf_child_cnt.append(np.zeros(self.num_nodes[i], dtype=int)) - for j in range(self.num_nodes[i]): - self.loss_chg[-1][j] = self.read("f") - self.sum_hess[-1][j] = self.read("f") - self.base_weight[-1][j] = self.read("f") - self.leaf_child_cnt[-1][j] = self.read("i") - - def get_trees(self, data=None, data_missing=None): - shape = (self.num_trees, self.num_nodes.max()) - self.children_default = np.zeros(shape, dtype=int) - self.features = np.zeros(shape, dtype=int) - self.thresholds = np.zeros(shape, dtype=np.float32) - self.values = np.zeros((shape[0], shape[1], 1), dtype=np.float32) - trees = [] - for i in range(self.num_trees): - for j in range(self.num_nodes[i]): - if np.right_shift(self.node_sindex[i][j], np.uint32(31)) != 0: - self.children_default[i, j] = self.node_cleft[i][j] - else: - self.children_default[i, j] = self.node_cright[i][j] - self.features[i, j] = self.node_sindex[i][j] & ( - (np.uint32(1) << np.uint32(31)) - np.uint32(1) - ) - if self.node_cleft[i][j] >= 0: - # Xgboost uses < for thresholds where shap uses <= - # Move the threshold down by the smallest possible increment - self.thresholds[i, j] = np.nextafter(self.node_info[i][j], -np.float32(np.inf)) - else: - self.values[i, j] = self.node_info[i][j] - - size = len(self.node_cleft[i]) - trees.append( - TreeModel( - children_left=self.node_cleft[i], - children_right=self.node_cright[i], - features=self.features[i, :size], - thresholds=self.thresholds[i, :size], - values=self.values[i, :size].flatten(), - node_sample_weight=self.sum_hess[i], - ) - ) - return trees - - def read(self, dtype): - size = struct.calcsize(dtype) - val = struct.unpack(dtype, self.buf[self.pos : self.pos + size])[0] - self.pos += size - return val - - def read_arr(self, dtype, n_items): - format = "%d%s" % (n_items, dtype) - size = struct.calcsize(format) - val = struct.unpack(format, self.buf[self.pos : self.pos + size])[0] - self.pos += size - return val - - def read_str(self, size): - val = self.buf[self.pos : self.pos + size].decode("utf-8") - self.pos += size - return val - - def print_info(self): - print("--- global parmeters ---") - print("base_score =", self.base_score) - print("num_feature =", self.num_feature) - print("num_class =", self.num_class) - print("contain_extra_attrs =", self.contain_extra_attrs) - print("contain_eval_metrics =", self.contain_eval_metrics) - print("name_obj_len =", self.name_obj_len) - print("name_obj =", self.name_obj) - print("name_gbm_len =", self.name_gbm_len) - print("name_gbm =", self.name_gbm) - print() - print("--- gbtree specific parameters ---") - print("num_trees =", self.num_trees) - print("num_roots =", self.num_roots) - print("num_feature =", self.num_feature) - print("pad_32bit =", self.pad_32bit) - print("num_pbuffer_deprecated =", self.num_pbuffer_deprecated) - print("num_output_group =", self.num_output_group) - print("size_leaf_vector =", self.size_leaf_vector) diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index 387b6e12..bfaf0ded 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -1,7 +1,7 @@ """This module contains the TreeExplainer class making use of the TreeSHAPIQ algorithm for computing any-order Shapley Interactions for tree ensembles.""" import copy -from typing import Any, Union +from typing import Any, Optional, Union import numpy as np from explainer._base import Explainer @@ -17,17 +17,20 @@ def __init__( model: Union[dict, TreeModel, Any], max_order: int = 2, min_order: int = 1, + class_label: Optional[int] = None, + output_type: str = "raw", ) -> None: # validate and parse model - validated_model = _validate_model(model) # the parsed and validated model - + validated_model = _validate_model(model, class_label=class_label, output_type=output_type) self._trees: Union[TreeModel, list[TreeModel]] = copy.deepcopy(validated_model) if not isinstance(self._trees, list): self._trees = [self._trees] self._n_trees = len(self._trees) - self._max_order = max_order - self._min_order = min_order + self._max_order: int = max_order + self._min_order: int = min_order + self._class_label: Optional[int] = class_label + self._output_type: str = output_type # setup explainers for all trees self._treeshapiq_explainers: list[TreeSHAPIQ] = [ diff --git a/shapiq/explainer/tree/treeshapiq.py b/shapiq/explainer/tree/treeshapiq.py index 1a3da2fb..d254abb8 100644 --- a/shapiq/explainer/tree/treeshapiq.py +++ b/shapiq/explainer/tree/treeshapiq.py @@ -1,6 +1,7 @@ """This module contains the tree explainer implementation.""" import copy from math import factorial +from typing import Any, Optional, Union import numpy as np from approximator import transforms_sii_to_ksii @@ -30,7 +31,7 @@ class TreeSHAPIQ: def __init__( self, - model: TreeModel, + model: Union[dict, TreeModel, Any], max_order: int = 2, min_order: int = 1, interaction_type: str = "k-SII", @@ -507,13 +508,13 @@ def _get_N_cii(self, interpolated_poly, order) -> np.ndarray[float]: ) return Ns - def _get_subset_weight_cii(self, t, order) -> float: + def _get_subset_weight_cii(self, t, order) -> Optional[float]: # TODO: add docstring if self._interaction_type == "STI": return self._max_order / ( self._n_features_in_tree * binom(self._n_features_in_tree - 1, t) ) - elif self._interaction_type == "FSI": + if self._interaction_type == "FSI": return ( factorial(2 * self._max_order - 1) / factorial(self._max_order - 1) ** 2 @@ -521,10 +522,8 @@ def _get_subset_weight_cii(self, t, order) -> float: * factorial(self._n_features_in_tree - t - 1) / factorial(self._n_features_in_tree + self._max_order - 1) ) - elif self._interaction_type == "BZF": + if self._interaction_type == "BZF": return 1 / (2 ** (self._n_features_in_tree - order)) - else: - raise ValueError("Interaction type not supported") @staticmethod def _get_N_id(D) -> np.ndarray[float]: diff --git a/shapiq/explainer/tree/validation.py b/shapiq/explainer/tree/validation.py index 5af979dd..68f969a5 100644 --- a/shapiq/explainer/tree/validation.py +++ b/shapiq/explainer/tree/validation.py @@ -1,20 +1,22 @@ """This module contains conversion functions for the tree explainer implementation.""" -from typing import Any, Optional +from typing import Any, Optional, Union from shapiq.utils import safe_isinstance from .base import TreeModel -from .conversion.sklearn import convert_sklearn_tree +from .conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree SUPPORTED_MODELS = { "sklearn.tree.DecisionTreeRegressor", "sklearn.tree.DecisionTreeClassifier", + "sklearn.ensemble.RandomForestClassifier", + "sklearn.ensemble.RandomForestRegressor", } def _validate_model( model: Any, class_label: Optional[int] = None, output_type: str = "raw" -) -> TreeModel: +) -> Union[TreeModel, list[TreeModel]]: """Validate the model. Args: @@ -25,13 +27,21 @@ def _validate_model( Returns: The validated model and the model function. """ - if isinstance(model, TreeModel): + # tree model (is already in the correct format) + if type(model).__name__ == "TreeModel": return model + # dict as model is parsed to TreeModel (the dict needs to have the correct format and names) + if type(model).__name__ == "dict": + return TreeModel(**model) + # sklearn decision trees if safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") or safe_isinstance( model, "sklearn.tree.DecisionTreeClassifier" ): - if safe_isinstance(model, "sklearn.tree.DecisionTreeClassifier") and class_label is None: - class_label = 1 - return convert_sklearn_tree(model, class_label=class_label) - else: - raise TypeError("Unsupported model type." f"Supported models are: {SUPPORTED_MODELS}") + return convert_sklearn_tree(model, class_label=class_label, output_type=output_type) + # sklearn random forests + if safe_isinstance(model, "sklearn.ensemble.RandomForestRegressor") or safe_isinstance( + model, "sklearn.ensemble.RandomForestClassifier" + ): + return convert_sklearn_forest(model, class_label=class_label, output_type=output_type) + # unsupported model + raise TypeError("Unsupported model type." f"Supported models are: {SUPPORTED_MODELS}") diff --git a/shapiq/utils/__init__.py b/shapiq/utils/__init__.py index 21f88890..f96e5e4c 100644 --- a/shapiq/utils/__init__.py +++ b/shapiq/utils/__init__.py @@ -1,6 +1,6 @@ """This module contains utility functions for the shapiq package.""" -from .modules import safe_isinstance +from .modules import safe_isinstance, try_import from .sets import ( generate_interaction_lookup, get_explicit_subsets, @@ -18,4 +18,5 @@ "generate_interaction_lookup", # modules "safe_isinstance", + "try_import", ] diff --git a/shapiq/utils/modules.py b/shapiq/utils/modules.py index 3aa4c257..ad658b7c 100644 --- a/shapiq/utils/modules.py +++ b/shapiq/utils/modules.py @@ -1,5 +1,7 @@ +import importlib import sys -from typing import Any, Union +from types import ModuleType +from typing import Any, Optional, Union def safe_isinstance(obj: Any, class_path_str: Union[str, list[str], tuple[str]]) -> bool: @@ -55,3 +57,23 @@ def safe_isinstance(obj: Any, class_path_str: Union[str, list[str], tuple[str]]) return True return False + + +def try_import(name: str, package: Optional[str] = None) -> Optional[ModuleType]: + """ + Try to import a module and return None if it fails. + + Note: + Solution adapted from [stack overflow](https://stackoverflow.com/a/53241197). + + Args: + name: The name of the module to import. + package: The package to import the module from. + + Returns: + The imported module or None if the import fails. + """ + try: + return importlib.import_module(name, package=package) + except ImportError: + return None diff --git a/shapiq/utils/types.py b/shapiq/utils/types.py new file mode 100644 index 00000000..b7773d05 --- /dev/null +++ b/shapiq/utils/types.py @@ -0,0 +1,5 @@ +"""This module contains all custom types used in the shapiq package.""" +from typing import TypeVar + +# Model type for all machine learning models +Model = TypeVar("Model") diff --git a/tests/conftest.py b/tests/conftest.py index 53edf2ff..cb5fd379 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,14 +2,15 @@ If it becomes too large, it can be split into multiple files like here: https://gist.github.com/peterhurford/09f7dcda0ab04b95c026c60fa49c2a68 """ - +import numpy as np import pytest from sklearn.datasets import make_regression, make_classification from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier +from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier @pytest.fixture -def dt_reg_model(): +def dt_reg_model() -> DecisionTreeRegressor: """Return a simple decision tree model.""" X, y = make_regression(n_samples=100, n_features=7, random_state=42) model = DecisionTreeRegressor(random_state=42, max_depth=3) @@ -18,7 +19,7 @@ def dt_reg_model(): @pytest.fixture -def dt_clf_model(): +def dt_clf_model() -> DecisionTreeClassifier: """Return a simple decision tree model.""" X, y = make_classification( n_samples=100, @@ -35,14 +36,40 @@ def dt_clf_model(): @pytest.fixture -def background_reg_data(): +def rf_reg_model() -> RandomForestRegressor: + """Return a simple random forest model.""" + X, y = make_regression(n_samples=100, n_features=7, random_state=42) + model = RandomForestRegressor(random_state=42, max_depth=3, n_estimators=3) + model.fit(X, y) + return model + + +@pytest.fixture +def rf_clf_model() -> RandomForestClassifier: + """Return a simple random forest model.""" + X, y = make_classification( + n_samples=100, + n_features=7, + random_state=42, + n_classes=3, + n_informative=7, + n_repeated=0, + n_redundant=0, + ) + model = RandomForestClassifier(random_state=42, max_depth=3, n_estimators=3) + model.fit(X, y) + return model + + +@pytest.fixture +def background_reg_data() -> tuple[np.ndarray, np.ndarray]: """Return a simple background dataset.""" X, y = make_regression(n_samples=100, n_features=7, random_state=42) return X @pytest.fixture -def background_clf_data(): +def background_clf_data() -> tuple[np.ndarray, np.ndarray]: """Return a simple background dataset.""" X, y = make_classification( n_samples=100, diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py index 0cabd046..fbcd857d 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py @@ -1,98 +1,56 @@ -"""This module contains all tests for the TreeExplainer class of the shapiq package.""" +"""This test module contains all tests for the tree explainer module of the shapiq package.""" import numpy as np import pytest -from shapiq.explainer.tree import TreeModel -from shapiq.explainer.tree import TreeSHAPIQ +from shapiq.explainer.tree import TreeExplainer -def test_init(dt_clf_model, background_clf_data): - """Test the initialization of the TreeExplainer class.""" - explainer = TreeSHAPIQ(model=dt_clf_model, max_order=1, interaction_type="SII", verbose=True) +def test_decision_tree_classifier(dt_clf_model, background_clf_data): + """Test TreeExplainer with a simple decision tree classifier.""" + explainer = TreeExplainer(model=dt_clf_model, max_order=2, min_order=1) x_explain = background_clf_data[0] - _ = explainer.explain(x_explain) + explanation = explainer.explain(x_explain) + + assert type(explanation).__name__ == "InteractionValues" # check correct return type + + # check init with class label + _ = TreeExplainer(model=dt_clf_model, max_order=2, min_order=1, class_label=0) - explainer = TreeSHAPIQ(model=dt_clf_model, max_order=1, interaction_type="k-SII") - x_explain = background_clf_data[0] - _ = explainer.explain(x_explain) assert True + # check with invalid output type + with pytest.raises(NotImplementedError): + _ = TreeExplainer( + model=dt_clf_model, max_order=2, min_order=1, output_type="invalid_output_type" + ) + + +def test_decision_tree_regression(dt_reg_model, background_reg_data): + """Test TreeExplainer with a simple decision tree regressor.""" + explainer = TreeExplainer(model=dt_reg_model, max_order=2, min_order=1) + + x_explain = background_reg_data[0] + explanation = explainer.explain(x_explain) + + assert type(explanation).__name__ == "InteractionValues" # check correct return type + + +def test_random_forrest_regression(rf_reg_model, background_reg_data): + """Test TreeExplainer with a simple decision tree regressor.""" + explainer = TreeExplainer(model=rf_reg_model, max_order=2, min_order=1) + + x_explain = background_reg_data[0] + explanation = explainer.explain(x_explain) + + assert type(explanation).__name__ == "InteractionValues" # check correct return type -@pytest.mark.parametrize( - "index, expected", - [ - ( - "SII", - { - (0,): -10.18947368, - (1,): -13.31052632, - (2,): 3.0, - (0, 1): -11.77894737, - (0, 2): -6.0, - (1, 2): 0, - }, - ), - ( - "BZF", - { - (0,): -10.18947368, - (1,): -13.31052632, - (2,): 3.0, - (0, 1): -11.77894737, - (0, 2): -6.0, - (1, 2): 0, - }, - ), - ( - "FSI", - { - (0,): -39.45789474, - (1,): -45.82105263, - (2,): 6.0, - (0, 1): -11.77894737, - (0, 2): -6.0, - (1, 2): 0, - }, - ), - ( - "STI", - { - (0,): -20.37894737, - (1,): -26.62105263, - (2,): 6.0, - (0, 1): -11.77894737, - (0, 2): -6.0, - (1, 2): 0, - }, - ), - ], -) -def test_manual_tree(index: str, expected: dict): - # manual values for a tree to test against the original treeshapiq implementation - - children_left = np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]) - children_right = np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]) - features = np.asarray([0, 1, 0, -2, -2, -2, 2, -2, -2]) - thresholds = np.asarray([0, 0, -0.5, -2, -2, -2, 0, -2, -2]) - node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]) - values = np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40]) - - x_explain = np.asarray([-1, -0.5, 1, 0]) - - tree_model = TreeModel( - children_left=children_left, - children_right=children_right, - features=features, - thresholds=thresholds, - node_sample_weight=node_sample_weight, - values=values, - ) - - explainer = TreeSHAPIQ(model=tree_model, max_order=2, interaction_type=index) +def test_random_forrest_classification(rf_clf_model, background_clf_data): + """Test TreeExplainer with a simple decision tree regressor.""" + explainer = TreeExplainer(model=rf_clf_model, max_order=2, min_order=1) + + x_explain = background_clf_data[0] explanation = explainer.explain(x_explain) - print(explanation) - for key, value in expected.items(): - assert np.isclose(explanation[key], value, atol=1e-5) + assert type(explanation).__name__ == "InteractionValues" # check correct return type diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py index 4e5ce0b3..08a009be 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py @@ -1,11 +1,11 @@ """This test module collects all tests for the conversions of the supported tree models for the TreeExplainer class.""" - import numpy as np -from shapiq import safe_isinstance + +from shapiq.utils import safe_isinstance from shapiq.explainer.tree.base import TreeModel -from explainer.tree.conversion.sklearn import convert_sklearn_tree +from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_tree, convert_sklearn_forest def test_tree_model_init(): @@ -37,25 +37,90 @@ def test_tree_model_init(): assert np.all(tree_model["children_left"] == np.array([1, 2, -1, -1, -1])) -def test_sklean_conversion(dt_reg_model, dt_clf_model): +def test_edge_tree_init(): + """Tests the initialization of the EdgeTree class.""" + + from explainer.tree.conversion.edges import create_edge_tree + + # setup test data (same as in test_manual_tree of test_tree_treeshapiq.py) + children_left = np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]) + children_right = np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]) + features = np.asarray([0, 1, 0, -2, -2, -2, 2, -2, -2]) + thresholds = np.asarray([0, 0, -0.5, -2, -2, -2, 0, -2, -2]) + node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]) + values = np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40]) + + tree_model = TreeModel( + children_left=children_left, + children_right=children_right, + features=features, + thresholds=thresholds, + node_sample_weight=node_sample_weight, + values=values, + ) + + max_feature_id = tree_model.max_feature_id + n_nodes = tree_model.n_nodes + + interaction_update_positions = { + 1: {0: np.array([0]), 1: np.array([1]), 2: np.array([2])}, + 2: {0: np.array([0, 1]), 1: np.array([0, 2]), 2: np.array([1, 2])}, + } + + edge_tree = create_edge_tree( + children_left=tree_model.children_left, + children_right=tree_model.children_right, + features=tree_model.features, + node_sample_weight=tree_model.node_sample_weight, + values=tree_model.values, + max_interaction=1, + n_features=max_feature_id + 1, + n_nodes=n_nodes, + subset_updates_pos_store=interaction_update_positions, + ) + + assert safe_isinstance(edge_tree, ["explainer.tree.base.EdgeTree"]) + + # check if edge_tree can be accessed via __getitem__ + assert edge_tree["parents"] is not None + + +def test_sklean_dt_conversion(dt_reg_model, dt_clf_model): """Test the conversion of a scikit-learn decision tree model.""" # test regression model - class_path_str = ["explainer.tree.base.TreeModel"] + tree_model_class_path_str = ["explainer.tree.base.TreeModel"] tree_model = convert_sklearn_tree(dt_reg_model) - assert safe_isinstance(tree_model, class_path_str) + assert safe_isinstance(tree_model, tree_model_class_path_str) assert tree_model.empty_prediction is not None # test scaling tree_model = convert_sklearn_tree(dt_reg_model, scaling=0.5) - assert safe_isinstance(tree_model, class_path_str) + assert safe_isinstance(tree_model, tree_model_class_path_str) assert tree_model.empty_prediction is not None # test classification model with class label tree_model = convert_sklearn_tree(dt_clf_model, class_label=0) - assert safe_isinstance(tree_model, class_path_str) + assert safe_isinstance(tree_model, tree_model_class_path_str) assert tree_model.empty_prediction is not None # test classification model without class label tree_model = convert_sklearn_tree(dt_clf_model) - assert safe_isinstance(tree_model, class_path_str) + assert safe_isinstance(tree_model, tree_model_class_path_str) assert tree_model.empty_prediction is not None + + +def test_skleanr_rf_conversion(rf_clf_model, rf_reg_model): + """Test the conversion of a scikit-learn random forest model.""" + tree_model_class_path_str = ["explainer.tree.base.TreeModel"] + + # test the regression model + tree_model = convert_sklearn_forest(rf_reg_model) + assert type(tree_model) is list + assert safe_isinstance(tree_model[0], tree_model_class_path_str) + assert tree_model[0].empty_prediction is not None + + # test the classification model + tree_model = convert_sklearn_forest(rf_clf_model) + assert type(tree_model) is list + assert safe_isinstance(tree_model[0], tree_model_class_path_str) + assert tree_model[0].empty_prediction is not None diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py new file mode 100644 index 00000000..b1717ced --- /dev/null +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py @@ -0,0 +1,110 @@ +"""This module contains all tests for the TreeExplainer class of the shapiq package.""" +import numpy as np +import pytest + +from shapiq.explainer.tree import TreeModel, TreeSHAPIQ + + +def test_init(dt_clf_model, background_clf_data): + """Test the initialization of the TreeExplainer class.""" + explainer = TreeSHAPIQ(model=dt_clf_model, max_order=1, interaction_type="SII", verbose=True) + + x_explain = background_clf_data[0] + _ = explainer.explain(x_explain) + + explainer = TreeSHAPIQ(model=dt_clf_model, max_order=1, interaction_type="k-SII") + x_explain = background_clf_data[0] + _ = explainer.explain(x_explain) + + # test with dict input as tree + tree_model = { + "children_left": np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]), + "children_right": np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]), + "features": np.asarray([0, 1, 0, -2, -2, -2, 2, -2, -2]), + "thresholds": np.asarray([0, 0, -0.5, -2, -2, -2, 0, -2, -2]), + "node_sample_weight": np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]), + "values": np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40]), + } + explainer = TreeSHAPIQ(model=tree_model, max_order=1, interaction_type="SII") + x_explain = np.asarray([-1, -0.5, 1, 0]) + _ = explainer.explain(x_explain) + + assert True + + +@pytest.mark.parametrize( + "index, expected", + [ + ( + "SII", + { + (0,): -10.18947368, + (1,): -13.31052632, + (2,): 3.0, + (0, 1): -11.77894737, + (0, 2): -6.0, + (1, 2): 0, + }, + ), + ( + "BZF", + { + (0,): -10.18947368, + (1,): -13.31052632, + (2,): 3.0, + (0, 1): -11.77894737, + (0, 2): -6.0, + (1, 2): 0, + }, + ), + ( + "FSI", + { + (0,): -39.45789474, + (1,): -45.82105263, + (2,): 6.0, + (0, 1): -11.77894737, + (0, 2): -6.0, + (1, 2): 0, + }, + ), + ( + "STI", + { + (0,): -20.37894737, + (1,): -26.62105263, + (2,): 6.0, + (0, 1): -11.77894737, + (0, 2): -6.0, + (1, 2): 0, + }, + ), + ], +) +def test_manual_tree(index: str, expected: dict): + # manual values for a tree to test against the original treeshapiq implementation + children_left = np.asarray([1, 2, 3, -1, -1, -1, 7, -1, -1]) + children_right = np.asarray([6, 5, 4, -1, -1, -1, 8, -1, -1]) + features = np.asarray([0, 1, 0, -2, -2, -2, 2, -2, -2]) + thresholds = np.asarray([0, 0, -0.5, -2, -2, -2, 0, -2, -2]) + node_sample_weight = np.asarray([100, 50, 38, 15, 23, 12, 50, 20, 30]) + values = np.asarray([110, 105, 95, 20, 50, 100, 75, 10, 40]) + + x_explain = np.asarray([-1, -0.5, 1, 0]) + + tree_model = TreeModel( + children_left=children_left, + children_right=children_right, + features=features, + thresholds=thresholds, + node_sample_weight=node_sample_weight, + values=values, + ) + + explainer = TreeSHAPIQ(model=tree_model, max_order=2, interaction_type=index) + + explanation = explainer.explain(x_explain) + print(explanation) + + for key, value in expected.items(): + assert np.isclose(explanation[key], value, atol=1e-5) diff --git a/tests/tests_utils/test_utils_modules.py b/tests/tests_utils/test_utils_modules.py index d36d9b52..05e7116b 100644 --- a/tests/tests_utils/test_utils_modules.py +++ b/tests/tests_utils/test_utils_modules.py @@ -1,11 +1,12 @@ """This test module contains tests for utils.modules.""" import pytest -from shapiq.utils import safe_isinstance +from shapiq.utils import safe_isinstance, try_import from sklearn.tree import DecisionTreeRegressor def test_safe_isinstance(): + """Test the safe_isinstance function.""" model = DecisionTreeRegressor() assert safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") @@ -19,3 +20,14 @@ def test_safe_isinstance(): safe_isinstance(model, None) assert not safe_isinstance(model, "my.made.up.module") assert not safe_isinstance(model, ["sklearn.ensemble.DecisionTreeRegressor"]) + + +def test_try_import(): + """Tests the try_import function.""" + + # Test with a package that exists + try_import("DecisionTreeClassifier", package="sklearn.tree") + # check if the module is imported in the current environment namespace + + # Test with a package that does not exist + try_import("DecisionTreeClassifier", package="my.made.up.module")