Skip to content

Commit

Permalink
Merge pull request #29 from mmschlk/development
Browse files Browse the repository at this point in the history
Add additional Tests
  • Loading branch information
mmschlk authored Jan 4, 2024
2 parents 829f3f4 + e788748 commit 0c0bab7
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ nbsphinx==0.9.3
networkx==3.1
numpy==1.26.1
packaging==23.2
pandas==2.1.4
pandoc==2.3
pandocfilters==1.5.0
pathspec==0.11.2
Expand Down
2 changes: 2 additions & 0 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_parent_array,
powerset,
split_subsets_budget,
safe_isinstance,
)

from .datasets import load_bike
Expand All @@ -53,6 +54,7 @@
"split_subsets_budget",
"get_conditional_sample_weights",
"get_parent_array",
"safe_isinstance",
# datasets
"load_bike",
]
5 changes: 1 addition & 4 deletions shapiq/datasets/_all.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""This module contains functions to load datasets."""
import os
import pandas as pd

Expand Down Expand Up @@ -32,7 +33,3 @@ def load_bike() -> pd.DataFrame:
data.columns = list(map(str.title, data.columns))

return data


if __name__ == "__main__":
print(load_bike())
10 changes: 0 additions & 10 deletions shapiq/plot/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,3 @@
"BLUE",
"NEUTRAL",
]


if __name__ == "__main__":
red = [round(c * 255, 0) for c in RED.rgb]
blue = [round(c * 255, 0) for c in BLUE.rgb]
neutral = [round(c * 255, 0) for c in NEUTRAL.rgb]

print("RED", red)
print("BLUE", blue)
print("NEUTRAL", neutral)
2 changes: 1 addition & 1 deletion shapiq/plot/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def _add_legend_to_axis(axis: plt.Axes) -> None:

def network_plot(
*,
interaction_values: InteractionValues,
first_order_values: np.ndarray[float],
second_order_values: np.ndarray[float],
interaction_values: InteractionValues = None,
feature_names: Optional[list[Any]] = None,
feature_image_patches: Optional[dict[int, Image.Image]] = None,
feature_image_patches_size: Optional[Union[float, dict[int, float]]] = 0.2,
Expand Down
6 changes: 5 additions & 1 deletion shapiq/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

from .sets import get_explicit_subsets, pair_subset_sizes, powerset, split_subsets_budget
from .tree import get_conditional_sample_weights, get_parent_array
from .modules import safe_isinstance

__all__ = [
# sets
"powerset",
"pair_subset_sizes",
"split_subsets_budget",
"get_explicit_subsets",
# trees
# tree
"get_parent_array",
"get_conditional_sample_weights",
# modules
"safe_isinstance",
]
3 changes: 3 additions & 0 deletions tests/test_integration_import_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import games as games
import utils as utils
import plot as plot
import datasets as datasets


@pytest.mark.parametrize(
Expand All @@ -23,6 +24,7 @@
games,
utils,
plot,
datasets,
],
)
def test_import_package(package):
Expand All @@ -39,6 +41,7 @@ def test_import_package(package):
games,
utils,
plot,
datasets,
],
)
def test_import_submodules(package):
Expand Down
Empty file.
11 changes: 11 additions & 0 deletions tests/tests_datasets/test_bike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""This test module contains the tests for the bike dataset."""

import pytest

from shapiq import load_bike


def test_load_bike():
data = load_bike()
# test if data is a pandas dataframe
assert isinstance(data, type(data))
Empty file added tests/tests_plots/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions tests/tests_plots/test_network_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""This module contains all tests for the network plots."""
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from shapiq.plot import network_plot


def test_network_plot():
"""Tests whether the network plot can be created."""

first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6])
second_order_values = np.random.rand(6, 6) - 0.5

fig, axes = network_plot(
first_order_values=first_order_values,
second_order_values=second_order_values,
)
assert fig is not None
assert axes is not None
plt.close(fig)

fig, axes = network_plot(
first_order_values=first_order_values[0:4],
second_order_values=second_order_values[0:4, 0:4],
feature_names=["a", "b", "c", "d"],
)
assert fig is not None
assert axes is not None
plt.close(fig)


def test_network_plot_with_image():
first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6])
second_order_values = np.random.rand(6, 6) - 0.5
n_features = len(first_order_values)

# create dummyimage
image = np.random.rand(100, 100, 3)
image = (image * 255).astype(np.uint8)
image = Image.fromarray(image)

feature_image_patches: dict[int, Image.Image] = {}
feature_image_patches_size: dict[int, float] = {}
for feature_idx in range(n_features):
feature_image_patches[feature_idx] = image
feature_image_patches_size[feature_idx] = 0.1

fig, axes = network_plot(
first_order_values=first_order_values,
second_order_values=second_order_values,
center_image=image,
feature_image_patches=feature_image_patches,
)
assert fig is not None
assert axes is not None
plt.close(fig)

fig, axes = network_plot(
first_order_values=first_order_values,
second_order_values=second_order_values,
center_image=image,
feature_image_patches=feature_image_patches,
feature_image_patches_size=feature_image_patches_size,
)
assert fig is not None
assert axes is not None
plt.close(fig)
21 changes: 21 additions & 0 deletions tests/tests_utils/test_utils_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""This test module contains tests for utils.modules."""
import pytest

from shapiq.utils import safe_isinstance
from sklearn.tree import DecisionTreeRegressor


def test_safe_isinstance():
model = DecisionTreeRegressor()

assert safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor")
assert safe_isinstance(
model, ["sklearn.tree.DecisionTreeClassifier", "sklearn.tree.DecisionTreeRegressor"]
)
assert safe_isinstance(model, ("sklearn.tree.DecisionTreeRegressor",))
with pytest.raises(ValueError):
safe_isinstance(model, "DecisionTreeRegressor")
with pytest.raises(ValueError):
safe_isinstance(model, None)
assert not safe_isinstance(model, "my.made.up.module")
assert not safe_isinstance(model, ["sklearn.ensemble.DecisionTreeRegressor"])

0 comments on commit 0c0bab7

Please sign in to comment.