Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional Tests #29

Merged
merged 3 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ nbsphinx==0.9.3
networkx==3.1
numpy==1.26.1
packaging==23.2
pandas==2.1.4
pandoc==2.3
pandocfilters==1.5.0
pathspec==0.11.2
Expand Down
2 changes: 2 additions & 0 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_parent_array,
powerset,
split_subsets_budget,
safe_isinstance,
)

from .datasets import load_bike
Expand All @@ -53,6 +54,7 @@
"split_subsets_budget",
"get_conditional_sample_weights",
"get_parent_array",
"safe_isinstance",
# datasets
"load_bike",
]
5 changes: 1 addition & 4 deletions shapiq/datasets/_all.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""This module contains functions to load datasets."""
import os
import pandas as pd

Expand Down Expand Up @@ -32,7 +33,3 @@ def load_bike() -> pd.DataFrame:
data.columns = list(map(str.title, data.columns))

return data


if __name__ == "__main__":
print(load_bike())
10 changes: 0 additions & 10 deletions shapiq/plot/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,3 @@
"BLUE",
"NEUTRAL",
]


if __name__ == "__main__":
red = [round(c * 255, 0) for c in RED.rgb]
blue = [round(c * 255, 0) for c in BLUE.rgb]
neutral = [round(c * 255, 0) for c in NEUTRAL.rgb]

print("RED", red)
print("BLUE", blue)
print("NEUTRAL", neutral)
2 changes: 1 addition & 1 deletion shapiq/plot/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def _add_legend_to_axis(axis: plt.Axes) -> None:

def network_plot(
*,
interaction_values: InteractionValues,
first_order_values: np.ndarray[float],
second_order_values: np.ndarray[float],
interaction_values: InteractionValues = None,
feature_names: Optional[list[Any]] = None,
feature_image_patches: Optional[dict[int, Image.Image]] = None,
feature_image_patches_size: Optional[Union[float, dict[int, float]]] = 0.2,
Expand Down
6 changes: 5 additions & 1 deletion shapiq/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

from .sets import get_explicit_subsets, pair_subset_sizes, powerset, split_subsets_budget
from .tree import get_conditional_sample_weights, get_parent_array
from .modules import safe_isinstance

__all__ = [
# sets
"powerset",
"pair_subset_sizes",
"split_subsets_budget",
"get_explicit_subsets",
# trees
# tree
"get_parent_array",
"get_conditional_sample_weights",
# modules
"safe_isinstance",
]
3 changes: 3 additions & 0 deletions tests/test_integration_import_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import games as games
import utils as utils
import plot as plot
import datasets as datasets


@pytest.mark.parametrize(
Expand All @@ -23,6 +24,7 @@
games,
utils,
plot,
datasets,
],
)
def test_import_package(package):
Expand All @@ -39,6 +41,7 @@ def test_import_package(package):
games,
utils,
plot,
datasets,
],
)
def test_import_submodules(package):
Expand Down
Empty file.
11 changes: 11 additions & 0 deletions tests/tests_datasets/test_bike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""This test module contains the tests for the bike dataset."""

import pytest

from shapiq import load_bike


def test_load_bike():
data = load_bike()
# test if data is a pandas dataframe
assert isinstance(data, type(data))
Empty file added tests/tests_plots/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions tests/tests_plots/test_network_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""This module contains all tests for the network plots."""
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from shapiq.plot import network_plot


def test_network_plot():
"""Tests whether the network plot can be created."""

first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6])
second_order_values = np.random.rand(6, 6) - 0.5

fig, axes = network_plot(
first_order_values=first_order_values,
second_order_values=second_order_values,
)
assert fig is not None
assert axes is not None
plt.close(fig)

fig, axes = network_plot(
first_order_values=first_order_values[0:4],
second_order_values=second_order_values[0:4, 0:4],
feature_names=["a", "b", "c", "d"],
)
assert fig is not None
assert axes is not None
plt.close(fig)


def test_network_plot_with_image():
first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6])
second_order_values = np.random.rand(6, 6) - 0.5
n_features = len(first_order_values)

# create dummyimage
image = np.random.rand(100, 100, 3)
image = (image * 255).astype(np.uint8)
image = Image.fromarray(image)

feature_image_patches: dict[int, Image.Image] = {}
feature_image_patches_size: dict[int, float] = {}
for feature_idx in range(n_features):
feature_image_patches[feature_idx] = image
feature_image_patches_size[feature_idx] = 0.1

fig, axes = network_plot(
first_order_values=first_order_values,
second_order_values=second_order_values,
center_image=image,
feature_image_patches=feature_image_patches,
)
assert fig is not None
assert axes is not None
plt.close(fig)

fig, axes = network_plot(
first_order_values=first_order_values,
second_order_values=second_order_values,
center_image=image,
feature_image_patches=feature_image_patches,
feature_image_patches_size=feature_image_patches_size,
)
assert fig is not None
assert axes is not None
plt.close(fig)
21 changes: 21 additions & 0 deletions tests/tests_utils/test_utils_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""This test module contains tests for utils.modules."""
import pytest

from shapiq.utils import safe_isinstance
from sklearn.tree import DecisionTreeRegressor


def test_safe_isinstance():
model = DecisionTreeRegressor()

assert safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor")
assert safe_isinstance(
model, ["sklearn.tree.DecisionTreeClassifier", "sklearn.tree.DecisionTreeRegressor"]
)
assert safe_isinstance(model, ("sklearn.tree.DecisionTreeRegressor",))
with pytest.raises(ValueError):
safe_isinstance(model, "DecisionTreeRegressor")
with pytest.raises(ValueError):
safe_isinstance(model, None)
assert not safe_isinstance(model, "my.made.up.module")
assert not safe_isinstance(model, ["sklearn.ensemble.DecisionTreeRegressor"])