From 5a807391c89fbed967db3a57b2c492c3b70d31fe Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 4 Jan 2024 15:35:18 +0100 Subject: [PATCH 1/2] add additional tests for plots and utils --- shapiq/__init__.py | 2 + shapiq/datasets/_all.py | 5 +- shapiq/plot/_config.py | 10 ---- shapiq/plot/network.py | 2 +- shapiq/utils/__init__.py | 6 ++- tests/test_integration_import_all.py | 3 ++ tests/tests_datasets/__init__.py | 0 tests/tests_datasets/test_bike.py | 11 ++++ tests/tests_plots/__init__.py | 0 tests/tests_plots/test_network_plot.py | 68 +++++++++++++++++++++++++ tests/tests_utils/test_utils_modules.py | 21 ++++++++ 11 files changed, 112 insertions(+), 16 deletions(-) create mode 100644 tests/tests_datasets/__init__.py create mode 100644 tests/tests_datasets/test_bike.py create mode 100644 tests/tests_plots/__init__.py create mode 100644 tests/tests_plots/test_network_plot.py create mode 100644 tests/tests_utils/test_utils_modules.py diff --git a/shapiq/__init__.py b/shapiq/__init__.py index fa112c7d..72ea5e39 100644 --- a/shapiq/__init__.py +++ b/shapiq/__init__.py @@ -28,6 +28,7 @@ get_parent_array, powerset, split_subsets_budget, + safe_isinstance, ) from .datasets import load_bike @@ -53,6 +54,7 @@ "split_subsets_budget", "get_conditional_sample_weights", "get_parent_array", + "safe_isinstance", # datasets "load_bike", ] diff --git a/shapiq/datasets/_all.py b/shapiq/datasets/_all.py index f144566e..9053a107 100644 --- a/shapiq/datasets/_all.py +++ b/shapiq/datasets/_all.py @@ -1,3 +1,4 @@ +"""This module contains functions to load datasets.""" import os import pandas as pd @@ -32,7 +33,3 @@ def load_bike() -> pd.DataFrame: data.columns = list(map(str.title, data.columns)) return data - - -if __name__ == "__main__": - print(load_bike()) diff --git a/shapiq/plot/_config.py b/shapiq/plot/_config.py index ca728702..9bf34ca5 100644 --- a/shapiq/plot/_config.py +++ b/shapiq/plot/_config.py @@ -10,13 +10,3 @@ "BLUE", "NEUTRAL", ] - - -if __name__ == "__main__": - red = [round(c * 255, 0) for c in RED.rgb] - blue = [round(c * 255, 0) for c in BLUE.rgb] - neutral = [round(c * 255, 0) for c in NEUTRAL.rgb] - - print("RED", red) - print("BLUE", blue) - print("NEUTRAL", neutral) diff --git a/shapiq/plot/network.py b/shapiq/plot/network.py index 6640e98d..0f660533 100644 --- a/shapiq/plot/network.py +++ b/shapiq/plot/network.py @@ -131,9 +131,9 @@ def _add_legend_to_axis(axis: plt.Axes) -> None: def network_plot( *, - interaction_values: InteractionValues, first_order_values: np.ndarray[float], second_order_values: np.ndarray[float], + interaction_values: InteractionValues = None, feature_names: Optional[list[Any]] = None, feature_image_patches: Optional[dict[int, Image.Image]] = None, feature_image_patches_size: Optional[Union[float, dict[int, float]]] = 0.2, diff --git a/shapiq/utils/__init__.py b/shapiq/utils/__init__.py index 1fabe1ca..01f1ec3d 100644 --- a/shapiq/utils/__init__.py +++ b/shapiq/utils/__init__.py @@ -2,13 +2,17 @@ from .sets import get_explicit_subsets, pair_subset_sizes, powerset, split_subsets_budget from .tree import get_conditional_sample_weights, get_parent_array +from .modules import safe_isinstance __all__ = [ + # sets "powerset", "pair_subset_sizes", "split_subsets_budget", "get_explicit_subsets", - # trees + # tree "get_parent_array", "get_conditional_sample_weights", + # modules + "safe_isinstance", ] diff --git a/tests/test_integration_import_all.py b/tests/test_integration_import_all.py index b8c84933..448ba68c 100644 --- a/tests/test_integration_import_all.py +++ b/tests/test_integration_import_all.py @@ -12,6 +12,7 @@ import games as games import utils as utils import plot as plot +import datasets as datasets @pytest.mark.parametrize( @@ -23,6 +24,7 @@ games, utils, plot, + datasets, ], ) def test_import_package(package): @@ -39,6 +41,7 @@ def test_import_package(package): games, utils, plot, + datasets, ], ) def test_import_submodules(package): diff --git a/tests/tests_datasets/__init__.py b/tests/tests_datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_datasets/test_bike.py b/tests/tests_datasets/test_bike.py new file mode 100644 index 00000000..6b51d729 --- /dev/null +++ b/tests/tests_datasets/test_bike.py @@ -0,0 +1,11 @@ +"""This test module contains the tests for the bike dataset.""" + +import pytest + +from shapiq import load_bike + + +def test_load_bike(): + data = load_bike() + # test if data is a pandas dataframe + assert isinstance(data, type(data)) diff --git a/tests/tests_plots/__init__.py b/tests/tests_plots/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tests_plots/test_network_plot.py b/tests/tests_plots/test_network_plot.py new file mode 100644 index 00000000..47404c39 --- /dev/null +++ b/tests/tests_plots/test_network_plot.py @@ -0,0 +1,68 @@ +"""This module contains all tests for the network plots.""" +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image + +from shapiq.plot import network_plot + + +def test_network_plot(): + """Tests whether the network plot can be created.""" + + first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6]) + second_order_values = np.random.rand(6, 6) - 0.5 + + fig, axes = network_plot( + first_order_values=first_order_values, + second_order_values=second_order_values, + ) + assert fig is not None + assert axes is not None + plt.close(fig) + + fig, axes = network_plot( + first_order_values=first_order_values[0:4], + second_order_values=second_order_values[0:4, 0:4], + feature_names=["a", "b", "c", "d"], + ) + assert fig is not None + assert axes is not None + plt.close(fig) + + +def test_network_plot_with_image(): + first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6]) + second_order_values = np.random.rand(6, 6) - 0.5 + n_features = len(first_order_values) + + # create dummyimage + image = np.random.rand(100, 100, 3) + image = (image * 255).astype(np.uint8) + image = Image.fromarray(image) + + feature_image_patches: dict[int, Image.Image] = {} + feature_image_patches_size: dict[int, float] = {} + for feature_idx in range(n_features): + feature_image_patches[feature_idx] = image + feature_image_patches_size[feature_idx] = 0.1 + + fig, axes = network_plot( + first_order_values=first_order_values, + second_order_values=second_order_values, + center_image=image, + feature_image_patches=feature_image_patches, + ) + assert fig is not None + assert axes is not None + plt.close(fig) + + fig, axes = network_plot( + first_order_values=first_order_values, + second_order_values=second_order_values, + center_image=image, + feature_image_patches=feature_image_patches, + feature_image_patches_size=feature_image_patches_size, + ) + assert fig is not None + assert axes is not None + plt.close(fig) diff --git a/tests/tests_utils/test_utils_modules.py b/tests/tests_utils/test_utils_modules.py new file mode 100644 index 00000000..d36d9b52 --- /dev/null +++ b/tests/tests_utils/test_utils_modules.py @@ -0,0 +1,21 @@ +"""This test module contains tests for utils.modules.""" +import pytest + +from shapiq.utils import safe_isinstance +from sklearn.tree import DecisionTreeRegressor + + +def test_safe_isinstance(): + model = DecisionTreeRegressor() + + assert safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") + assert safe_isinstance( + model, ["sklearn.tree.DecisionTreeClassifier", "sklearn.tree.DecisionTreeRegressor"] + ) + assert safe_isinstance(model, ("sklearn.tree.DecisionTreeRegressor",)) + with pytest.raises(ValueError): + safe_isinstance(model, "DecisionTreeRegressor") + with pytest.raises(ValueError): + safe_isinstance(model, None) + assert not safe_isinstance(model, "my.made.up.module") + assert not safe_isinstance(model, ["sklearn.ensemble.DecisionTreeRegressor"]) From e78874886915f0ae524888de995f306a3a57b67a Mon Sep 17 00:00:00 2001 From: Maximilian Date: Thu, 4 Jan 2024 15:44:16 +0100 Subject: [PATCH 2/2] bugfix pandas as new dependency --- docs/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 81e5660f..02474db9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -57,6 +57,7 @@ nbsphinx==0.9.3 networkx==3.1 numpy==1.26.1 packaging==23.2 +pandas==2.1.4 pandoc==2.3 pandocfilters==1.5.0 pathspec==0.11.2