diff --git a/docs/source/conf.py b/docs/source/conf.py index a87ca1d5..b84204ef 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3,22 +3,24 @@ # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -from sphinx.builders.html import StandaloneHTMLBuilder import os import sys + +from sphinx.builders.html import StandaloneHTMLBuilder + sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../../shapiq")) import shapiq # -- Read the Docs --------------------------------------------------------------------------------- -master_doc = 'index' +master_doc = "index" # -- Project information --------------------------------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'shapiq' -copyright = '2023, the shapiq developers' -author = 'Maximilian Muschalik and Fabian Fumagalli' +project = "shapiq" +copyright = "2023, the shapiq developers" +author = "Maximilian Muschalik and Fabian Fumagalli" release = shapiq.__version__ version = shapiq.__version__ @@ -34,15 +36,15 @@ "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.autosummary", - 'sphinx_copybutton', + "sphinx_copybutton", "sphinx.ext.viewcode", "sphinx.ext.autosectionlabel", "sphinx_autodoc_typehints", "sphinx_toolbox.more_autodoc.autoprotocol", ] -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] source_suffix = { ".rst": "restructuredtext", @@ -59,9 +61,9 @@ # -- Options for HTML output ----------------------------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'furo' +html_theme = "furo" html_static_path = ["_static"] -html_favicon = '_static/shapiq.ico' +html_favicon = "_static/shapiq.ico" pygments_dark_style = "monokai" html_theme_options = { "sidebar_hide_name": True, @@ -83,19 +85,22 @@ # -- Autodoc --------------------------------------------------------------------------------------- autosummary_generate = True autodoc_default_options = { - 'show-inheritance': True, - 'members': True, - 'member-order': 'groupwise', - 'special-members': '__call__', - 'undoc-members': True, - 'exclude-members': '__weakref__' + "show-inheritance": True, + "members": True, + "member-order": "groupwise", + "special-members": "__call__", + "undoc-members": True, + "exclude-members": "__weakref__", } -autoclass_content = 'class' +autoclass_content = "class" autodoc_inherit_docstrings = False # -- Images ---------------------------------------------------------------------------------------- StandaloneHTMLBuilder.supported_image_types = [ - "image/svg+xml", "image/gif", "image/png", "image/jpeg" + "image/svg+xml", + "image/gif", + "image/png", + "image/jpeg", ] # -- Copy Paste Button ----------------------------------------------------------------------------- # Ignore >>> when copying code diff --git a/setup.py b/setup.py index 5d715049..6887fd69 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ -import setuptools -import io -import os import codecs +import os + +import setuptools NAME = "shapiq" DESCRIPTION = "SHAPley Interaction Quantification (SHAP-IQ) for Explainable AI" @@ -13,18 +13,21 @@ work_directory = os.path.abspath(os.path.dirname(__file__)) + # https://packaging.python.org/guides/single-sourcing-package-version/ def read(rel_path): - with codecs.open(os.path.join(work_directory, rel_path), 'r') as fp: + with codecs.open(os.path.join(work_directory, rel_path), "r") as fp: return fp.read() - + + def get_version(rel_path): for line in read(rel_path).splitlines(): - if line.startswith('__version__'): + if line.startswith("__version__"): delimiter = '"' if '"' in line else "'" return line.split(delimiter)[1] -with io.open(os.path.join(work_directory, "README.md"), encoding="utf-8") as f: + +with open(os.path.join(work_directory, "README.md"), encoding="utf-8") as f: long_description = "\n" + f.read() base_packages = ["numpy", "scipy", "pandas", "tqdm"] diff --git a/shapiq/approximator/_base.py b/shapiq/approximator/_base.py index 7bfaa27c..6aa7aa2e 100644 --- a/shapiq/approximator/_base.py +++ b/shapiq/approximator/_base.py @@ -5,9 +5,9 @@ import numpy as np +from shapiq.approximator._config import AVAILABLE_INDICES from shapiq.interaction_values import InteractionValues from shapiq.utils.sets import generate_interaction_lookup -from shapiq.approximator._config import AVAILABLE_INDICES __all__ = [ "Approximator", diff --git a/shapiq/approximator/permutation/sti.py b/shapiq/approximator/permutation/sti.py index aa4cb22d..218a3917 100644 --- a/shapiq/approximator/permutation/sti.py +++ b/shapiq/approximator/permutation/sti.py @@ -77,7 +77,9 @@ def approximate( counts: np.ndarray[int] = self._init_result(dtype=int) # compute all lower order interactions if budget allows it - lower_order_cost = sum(int(sp.special.binom(self.n, s)) for s in range(self.min_order, self.max_order)) + lower_order_cost = sum( + int(sp.special.binom(self.n, s)) for s in range(self.min_order, self.max_order) + ) if self.max_order > 1 and budget >= lower_order_cost: budget -= lower_order_cost used_budget += lower_order_cost diff --git a/shapiq/approximator/regression/__init__.py b/shapiq/approximator/regression/__init__.py index b03c5842..36e4d12f 100644 --- a/shapiq/approximator/regression/__init__.py +++ b/shapiq/approximator/regression/__init__.py @@ -1,5 +1,4 @@ -"""This module contains the regression-based approximators to estimate Shapley interaction values. -""" +"""This module contains the regression-based approximators to estimate Shapley interaction values.""" from .fsi import RegressionFSI from .sii import RegressionSII diff --git a/shapiq/approximator/regression/_base.py b/shapiq/approximator/regression/_base.py index 7b966282..0d1fa78c 100644 --- a/shapiq/approximator/regression/_base.py +++ b/shapiq/approximator/regression/_base.py @@ -169,7 +169,9 @@ def _get_fsi_subset_representation( of players. """ n_subsets = all_subsets.shape[0] - num_players = sum(int(sp.special.binom(self.n, order)) for order in range(1, self.max_order + 1)) + num_players = sum( + int(sp.special.binom(self.n, order)) for order in range(1, self.max_order + 1) + ) regression_subsets = np.zeros(shape=(n_subsets, num_players), dtype=bool) for interaction_index, interaction in enumerate( powerset(self.N, min_size=1, max_size=self.max_order) @@ -193,7 +195,9 @@ def _get_sii_subset_representation( of players. """ n_subsets = all_subsets.shape[0] - num_players = sum(int(sp.special.binom(self.n, order)) for order in range(1, self.max_order + 1)) + num_players = sum( + int(sp.special.binom(self.n, order)) for order in range(1, self.max_order + 1) + ) regression_subsets = np.zeros(shape=(n_subsets, num_players), dtype=float) for interaction_index, interaction in enumerate( powerset(self.N, min_size=1, max_size=self.max_order) @@ -216,7 +220,9 @@ def _get_bernoulli_weight(self, intersection_size: int, r_prime: int) -> float: """ weight = 0 for size in range(1, intersection_size + 1): - weight += sp.special.binom(intersection_size, size) * self._bernoulli_numbers[r_prime - size] + weight += ( + sp.special.binom(intersection_size, size) * self._bernoulli_numbers[r_prime - size] + ) return weight def _get_bernoulli_weights( diff --git a/shapiq/approximator/sampling.py b/shapiq/approximator/sampling.py index 588f8f9e..095ff540 100644 --- a/shapiq/approximator/sampling.py +++ b/shapiq/approximator/sampling.py @@ -18,7 +18,7 @@ class ShapleySamplingMixin(ABC): """ def _init_ksh_sampling_weights( - self: Union[Approximator, "ShapleySamplingMixin"] + self: Union[Approximator, "ShapleySamplingMixin"], ) -> np.ndarray[float]: """Initializes the weights for sampling subsets. @@ -54,7 +54,9 @@ def _get_ksh_subset_weights( ksh_weights = self._init_ksh_sampling_weights() # indexed by subset size subset_sizes = np.sum(subsets, axis=1) weights = ksh_weights[subset_sizes] # set the weights for each subset size - weights /= sp.special.binom(self.n, subset_sizes) # divide by the number of subsets of the same size + weights /= sp.special.binom( + self.n, subset_sizes + ) # divide by the number of subsets of the same size # set the weights for the empty and full sets to big M weights[np.logical_not(subsets).all(axis=1)] = float(1_000_000) diff --git a/shapiq/explainer/__init__.py b/shapiq/explainer/__init__.py index 0ec86f70..9c2a3dc5 100644 --- a/shapiq/explainer/__init__.py +++ b/shapiq/explainer/__init__.py @@ -1,6 +1,5 @@ """This module contains the explainer for the shapiq package.""" - from .tabular import TabularExplainer from .tree import TreeExplainer diff --git a/shapiq/explainer/tabular.py b/shapiq/explainer/tabular.py index f1c64d96..e1e03e66 100644 --- a/shapiq/explainer/tabular.py +++ b/shapiq/explainer/tabular.py @@ -13,9 +13,9 @@ ShapIQ, ) from shapiq.approximator._base import Approximator -from shapiq.interaction_values import InteractionValues from shapiq.explainer._base import Explainer from shapiq.explainer.imputer import MarginalImputer +from shapiq.interaction_values import InteractionValues __all__ = ["TabularExplainer"] diff --git a/shapiq/explainer/tree/__init__.py b/shapiq/explainer/tree/__init__.py index 43d65f18..124bef27 100644 --- a/shapiq/explainer/tree/__init__.py +++ b/shapiq/explainer/tree/__init__.py @@ -1,4 +1,5 @@ """This module contains the tree explainer implementation.""" + from .base import TreeModel from .explainer import TreeExplainer from .treeshapiq import TreeSHAPIQ diff --git a/shapiq/explainer/tree/base.py b/shapiq/explainer/tree/base.py index 7fac1ca8..d2773162 100644 --- a/shapiq/explainer/tree/base.py +++ b/shapiq/explainer/tree/base.py @@ -1,4 +1,5 @@ """This module contains the base class for tree model conversion.""" + from dataclasses import dataclass from typing import Any, Optional diff --git a/shapiq/explainer/tree/conversion/edges.py b/shapiq/explainer/tree/conversion/edges.py index c6bc953f..445be846 100644 --- a/shapiq/explainer/tree/conversion/edges.py +++ b/shapiq/explainer/tree/conversion/edges.py @@ -1,6 +1,7 @@ """This module contains the conversion functions to parse a tree model into the edge representation. The edge representation is used by the TreeSHAP-IQ algorithm to compute the interaction values of a tree-based model.""" + import numpy as np from scipy.special import binom diff --git a/shapiq/explainer/tree/conversion/sklearn.py b/shapiq/explainer/tree/conversion/sklearn.py index fd652f8a..8d0ece47 100644 --- a/shapiq/explainer/tree/conversion/sklearn.py +++ b/shapiq/explainer/tree/conversion/sklearn.py @@ -1,5 +1,5 @@ """This module contains functions for converting scikit-learn decision trees to the format used by - shapiq.""" +shapiq.""" from typing import Optional diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index 1abfea30..fb25d4fe 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -1,5 +1,6 @@ """This module contains the TreeExplainer class making use of the TreeSHAPIQ algorithm for computing any-order Shapley Interactions for tree ensembles.""" + import copy from typing import Any, Optional, Union diff --git a/shapiq/explainer/tree/treeshapiq.py b/shapiq/explainer/tree/treeshapiq.py index d0a6ea6c..7fd6d60a 100644 --- a/shapiq/explainer/tree/treeshapiq.py +++ b/shapiq/explainer/tree/treeshapiq.py @@ -1,4 +1,5 @@ """This module contains the tree explainer implementation.""" + import copy from math import factorial from typing import Any, Optional, Union @@ -489,7 +490,9 @@ def _precompute_subsets_with_feature( # prepare the interaction updates and positions for feature_i in range(n_features): - positions = np.zeros(int(sp.special.binom(n_features - 1, interaction_order - 1)), dtype=int) + positions = np.zeros( + int(sp.special.binom(n_features - 1, interaction_order - 1)), dtype=int + ) interaction_update_positions[feature_i] = positions.copy() interaction_updates[feature_i] = [] diff --git a/shapiq/explainer/tree/validation.py b/shapiq/explainer/tree/validation.py index e12da4f3..a27462ed 100644 --- a/shapiq/explainer/tree/validation.py +++ b/shapiq/explainer/tree/validation.py @@ -1,4 +1,5 @@ """This module contains conversion functions for the tree explainer implementation.""" + import warnings from typing import Any, Optional, Union diff --git a/shapiq/games/__init__.py b/shapiq/games/__init__.py index aeb75a53..c89f6a3d 100644 --- a/shapiq/games/__init__.py +++ b/shapiq/games/__init__.py @@ -1,6 +1,5 @@ """This module contains sample game functions for the shapiq package.""" - from .dummy import DummyGame __all__ = [ diff --git a/shapiq/interaction_values.py b/shapiq/interaction_values.py index ffcf5a48..848a89bf 100644 --- a/shapiq/interaction_values.py +++ b/shapiq/interaction_values.py @@ -1,11 +1,13 @@ """This module contains the InteractionValues Dataclass, which is used to store the interaction scores.""" + import copy import warnings from dataclasses import dataclass from typing import Optional, Union import numpy as np + from shapiq.utils import generate_interaction_lookup, powerset AVAILABLE_INDICES = {"k-SII", "SII", "STI", "FSI", "SV", "BZF"} diff --git a/shapiq/plot/network.py b/shapiq/plot/network.py index 56ceb6d7..2400f0c3 100644 --- a/shapiq/plot/network.py +++ b/shapiq/plot/network.py @@ -4,9 +4,9 @@ import math from typing import Any, Optional, Union -import numpy as np import matplotlib.pyplot as plt import networkx as nx +import numpy as np from PIL import Image from shapiq.interaction_values import InteractionValues diff --git a/shapiq/plot/stacked_bar.py b/shapiq/plot/stacked_bar.py index 86caf975..3b941412 100644 --- a/shapiq/plot/stacked_bar.py +++ b/shapiq/plot/stacked_bar.py @@ -3,8 +3,8 @@ from copy import deepcopy from typing import Optional, Union -import numpy as np import matplotlib.pyplot as plt +import numpy as np from matplotlib.patches import Patch from ._config import COLORS_N_SII diff --git a/shapiq/utils/types.py b/shapiq/utils/types.py index b7773d05..acb90a0a 100644 --- a/shapiq/utils/types.py +++ b/shapiq/utils/types.py @@ -1,4 +1,5 @@ """This module contains all custom types used in the shapiq package.""" + from typing import TypeVar # Model type for all machine learning models diff --git a/tests/conftest.py b/tests/conftest.py index 6e16e374..92ff9161 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,12 @@ If it becomes too large, it can be split into multiple files like here: https://gist.github.com/peterhurford/09f7dcda0ab04b95c026c60fa49c2a68 """ + import numpy as np import pytest -from sklearn.datasets import make_regression, make_classification -from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier -from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier +from sklearn.datasets import make_classification, make_regression +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from shapiq.explainer.tree import TreeModel diff --git a/tests/test_abstract_classes.py b/tests/test_abstract_classes.py index 6f496796..380e1699 100644 --- a/tests/test_abstract_classes.py +++ b/tests/test_abstract_classes.py @@ -1,11 +1,12 @@ """This test module contains all tests regarding the base approximator class.""" + import numpy as np import pytest -from shapiq.games.base import Game from shapiq.approximator._base import Approximator -from shapiq.explainer.imputer._base import Imputer from shapiq.explainer._base import Explainer +from shapiq.explainer.imputer._base import Imputer +from shapiq.games.base import Game def concreter(abclass): diff --git a/tests/test_base_interaction_values.py b/tests/test_base_interaction_values.py index 1db9a461..a52b9761 100644 --- a/tests/test_base_interaction_values.py +++ b/tests/test_base_interaction_values.py @@ -1,4 +1,5 @@ """This test module contains all tests regarding the InteractionValues dataclass.""" + from copy import copy, deepcopy import numpy as np diff --git a/tests/test_integration_import_all.py b/tests/test_integration_import_all.py index 43439074..400a983b 100644 --- a/tests/test_integration_import_all.py +++ b/tests/test_integration_import_all.py @@ -4,15 +4,11 @@ import importlib import pkgutil import sys + import pytest import shapiq -from shapiq import approximator -from shapiq import explainer -from shapiq import games -from shapiq import utils -from shapiq import plot -from shapiq import datasets +from shapiq import approximator, datasets, explainer, games, plot, utils @pytest.mark.parametrize( diff --git a/tests/tests_approximators/test_approximator_ksii_estimation.py b/tests/tests_approximators/test_approximator_ksii_estimation.py index 96b6e3a6..9cbe527d 100644 --- a/tests/tests_approximators/test_approximator_ksii_estimation.py +++ b/tests/tests_approximators/test_approximator_ksii_estimation.py @@ -1,12 +1,13 @@ """Tests the approximiation of nSII values with PermutationSamplingSII and ShapIQ.""" + import numpy as np import pytest from shapiq.approximator import ( - convert_ksii_into_one_dimension, - transforms_sii_to_ksii, PermutationSamplingSII, ShapIQ, + convert_ksii_into_one_dimension, + transforms_sii_to_ksii, ) from shapiq.games import DummyGame diff --git a/tests/tests_approximators/test_approximator_permutation_sii.py b/tests/tests_approximators/test_approximator_permutation_sii.py index 9b7387c1..adc7a2e4 100644 --- a/tests/tests_approximators/test_approximator_permutation_sii.py +++ b/tests/tests_approximators/test_approximator_permutation_sii.py @@ -1,12 +1,13 @@ """This test module contains all tests regarding the SII permutation sampling approximator.""" + from copy import copy, deepcopy import numpy as np import pytest -from shapiq.interaction_values import InteractionValues from shapiq.approximator.permutation import PermutationSamplingSII from shapiq.games import DummyGame +from shapiq.interaction_values import InteractionValues @pytest.mark.parametrize( diff --git a/tests/tests_approximators/test_approximator_permutation_sti.py b/tests/tests_approximators/test_approximator_permutation_sti.py index 31dcb3a1..bf6c0905 100644 --- a/tests/tests_approximators/test_approximator_permutation_sti.py +++ b/tests/tests_approximators/test_approximator_permutation_sti.py @@ -1,12 +1,13 @@ """This test module contains all tests regarding the STI permutation sampling approximator.""" + from copy import copy, deepcopy import numpy as np import pytest -from shapiq.interaction_values import InteractionValues from shapiq.approximator.permutation import PermutationSamplingSTI from shapiq.games import DummyGame +from shapiq.interaction_values import InteractionValues @pytest.mark.parametrize( diff --git a/tests/tests_approximators/test_approximator_regression_fsi.py b/tests/tests_approximators/test_approximator_regression_fsi.py index e0d4fd66..389db30b 100644 --- a/tests/tests_approximators/test_approximator_regression_fsi.py +++ b/tests/tests_approximators/test_approximator_regression_fsi.py @@ -1,12 +1,13 @@ """This test module contains all tests regarding the FSI regression approximator.""" -from copy import deepcopy, copy + +from copy import copy, deepcopy import numpy as np import pytest -from shapiq.interaction_values import InteractionValues from shapiq.approximator.regression import RegressionFSI from shapiq.games import DummyGame +from shapiq.interaction_values import InteractionValues @pytest.mark.parametrize( diff --git a/tests/tests_approximators/test_approximator_regression_sii.py b/tests/tests_approximators/test_approximator_regression_sii.py index 55684cfe..b36cee0a 100644 --- a/tests/tests_approximators/test_approximator_regression_sii.py +++ b/tests/tests_approximators/test_approximator_regression_sii.py @@ -1,13 +1,14 @@ """This test module contains all tests regarding the SII regression approximator.""" -from copy import deepcopy, copy + +from copy import copy, deepcopy import numpy as np import pytest -from shapiq.interaction_values import InteractionValues -from shapiq.approximator.regression._base import Regression from shapiq.approximator.regression import RegressionSII +from shapiq.approximator.regression._base import Regression from shapiq.games import DummyGame +from shapiq.interaction_values import InteractionValues @pytest.mark.parametrize( diff --git a/tests/tests_approximators/test_approximator_regression_sv.py b/tests/tests_approximators/test_approximator_regression_sv.py index 64f1ab07..c66daeb6 100644 --- a/tests/tests_approximators/test_approximator_regression_sv.py +++ b/tests/tests_approximators/test_approximator_regression_sv.py @@ -1,12 +1,13 @@ """This test module contains all tests regarding the SV KernelSHAP regression approximator.""" -from copy import deepcopy, copy + +from copy import copy, deepcopy import numpy as np import pytest -from shapiq.interaction_values import InteractionValues from shapiq.approximator.regression import KernelSHAP from shapiq.games import DummyGame +from shapiq.interaction_values import InteractionValues @pytest.mark.parametrize( diff --git a/tests/tests_approximators/test_approximator_shapiq.py b/tests/tests_approximators/test_approximator_shapiq.py index f4f082f6..7dccdd90 100644 --- a/tests/tests_approximators/test_approximator_shapiq.py +++ b/tests/tests_approximators/test_approximator_shapiq.py @@ -1,12 +1,13 @@ """This test module contains all tests regarding the shapiq approximator.""" + from copy import copy, deepcopy import numpy as np import pytest from shapiq.approximator.shapiq import ShapIQ -from shapiq.interaction_values import InteractionValues from shapiq.games import DummyGame +from shapiq.interaction_values import InteractionValues @pytest.mark.parametrize( diff --git a/tests/tests_datasets/test_bike.py b/tests/tests_datasets/test_bike.py index e973f3e6..c76a0a54 100644 --- a/tests/tests_datasets/test_bike.py +++ b/tests/tests_datasets/test_bike.py @@ -1,4 +1,5 @@ """This test module contains the tests for the bike dataset.""" + from shapiq import load_bike diff --git a/tests/tests_explainer/test_explainer_tabular.py b/tests/tests_explainer/test_explainer_tabular.py index e3bcc6bc..ad840ec4 100644 --- a/tests/tests_explainer/test_explainer_tabular.py +++ b/tests/tests_explainer/test_explainer_tabular.py @@ -1,14 +1,12 @@ -"""This test module contains all tests regarding the interaciton explainer for the shapiq package. -""" +"""This test module contains all tests regarding the interaciton explainer for the shapiq package.""" import pytest - -from sklearn.tree import DecisionTreeRegressor -from sklearn.ensemble import RandomForestRegressor from sklearn.datasets import make_regression +from sklearn.ensemble import RandomForestRegressor +from sklearn.tree import DecisionTreeRegressor -from shapiq.explainer import TabularExplainer from shapiq.approximator import RegressionFSI +from shapiq.explainer import TabularExplainer @pytest.fixture diff --git a/tests/tests_explainer/tests_imputer/test_marginal_imputer.py b/tests/tests_explainer/tests_imputer/test_marginal_imputer.py index 0593629d..65a268d0 100644 --- a/tests/tests_explainer/tests_imputer/test_marginal_imputer.py +++ b/tests/tests_explainer/tests_imputer/test_marginal_imputer.py @@ -1,4 +1,5 @@ """This test module contains all tests for the marginal imputer module of the shapiq package.""" + import numpy as np from shapiq.explainer.imputer import MarginalImputer diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py index 547710b6..faef7d20 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py @@ -1,9 +1,9 @@ """This test module contains all tests for the tree explainer module of the shapiq package.""" + import numpy as np import pytest -from shapiq.explainer.tree import TreeModel -from shapiq.explainer.tree import TreeExplainer +from shapiq.explainer.tree import TreeExplainer, TreeModel def test_decision_tree_classifier(dt_clf_model, background_clf_data): diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py index b1584cc8..a6b31d28 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_conversion.py @@ -1,12 +1,12 @@ """This test module collects all tests for the conversions of the supported tree models for the TreeExplainer class.""" -import numpy as np +import numpy as np -from shapiq.utils import safe_isinstance from shapiq.explainer.tree.base import TreeModel -from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_tree, convert_sklearn_forest from shapiq.explainer.tree.conversion.edges import create_edge_tree +from shapiq.explainer.tree.conversion.sklearn import convert_sklearn_forest, convert_sklearn_tree +from shapiq.utils import safe_isinstance def test_tree_model_init(): diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_utils.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_utils.py index c377099b..564478ce 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_utils.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_utils.py @@ -1,4 +1,5 @@ """This test module collects all tests for the utility functions of the tree explainer.""" + import numpy as np from shapiq.explainer.tree.utils import ( diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py index dae03967..78ddce24 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer_validate.py @@ -1,9 +1,10 @@ """This test module contains all tests for the validation functions of the tree explainer implementation.""" + import copy -import pytest import numpy as np +import pytest from shapiq import safe_isinstance from shapiq.explainer.tree.validation import validate_tree_model diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py index 49e971df..e7a65715 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_treeshapiq.py @@ -1,4 +1,5 @@ """This module contains all tests for the TreeExplainer class of the shapiq package.""" + import numpy as np import pytest diff --git a/tests/tests_games/test_base_game.py b/tests/tests_games/test_base_game.py index f526654a..3eaa9c9d 100644 --- a/tests/tests_games/test_base_game.py +++ b/tests/tests_games/test_base_game.py @@ -5,8 +5,6 @@ import numpy as np import pytest - -from shapiq.games.base import Game from shapiq.games.dummy import DummyGame # used to test the base class diff --git a/tests/tests_games/test_games_dummy.py b/tests/tests_games/test_games_dummy.py index b817d0f1..732c2986 100644 --- a/tests/tests_games/test_games_dummy.py +++ b/tests/tests_games/test_games_dummy.py @@ -1,4 +1,5 @@ """This test module contains the tests for the DummyGame class.""" + import numpy as np import pytest diff --git a/tests/tests_plots/test_network_plot.py b/tests/tests_plots/test_network_plot.py index 9d306167..4a9f67c6 100644 --- a/tests/tests_plots/test_network_plot.py +++ b/tests/tests_plots/test_network_plot.py @@ -1,12 +1,13 @@ """This module contains all tests for the network plots.""" -import numpy as np -import scipy as sp + import matplotlib.pyplot as plt +import numpy as np import pytest +import scipy as sp from PIL import Image -from shapiq.plot import network_plot from shapiq.interaction_values import InteractionValues +from shapiq.plot import network_plot def test_network_plot(): diff --git a/tests/tests_plots/test_stacked_bar.py b/tests/tests_plots/test_stacked_bar.py index f05fb6e1..b3815540 100644 --- a/tests/tests_plots/test_stacked_bar.py +++ b/tests/tests_plots/test_stacked_bar.py @@ -1,6 +1,7 @@ """This module contains all tests for the stacked bar plots.""" -import numpy as np + import matplotlib.pyplot as plt +import numpy as np from shapiq.plot import stacked_bar_plot diff --git a/tests/tests_utils/test_utils_modules.py b/tests/tests_utils/test_utils_modules.py index 5ed5adce..77032bab 100644 --- a/tests/tests_utils/test_utils_modules.py +++ b/tests/tests_utils/test_utils_modules.py @@ -1,6 +1,6 @@ """This test module contains tests for utils.modules.""" -import pytest +import pytest from sklearn.tree import DecisionTreeRegressor from shapiq.utils import safe_isinstance, try_import diff --git a/tests/tests_utils/test_utils_sets.py b/tests/tests_utils/test_utils_sets.py index 63cb5757..9a2d095c 100644 --- a/tests/tests_utils/test_utils_sets.py +++ b/tests/tests_utils/test_utils_sets.py @@ -1,13 +1,14 @@ """This test module contains the test cases for the utils sets module.""" + import numpy as np import pytest from shapiq.utils import ( - powerset, + generate_interaction_lookup, + get_explicit_subsets, pair_subset_sizes, + powerset, split_subsets_budget, - get_explicit_subsets, - generate_interaction_lookup, transform_coalitions_to_array, )