-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from mmschlk/development
Add additional Tests
- Loading branch information
Showing
12 changed files
with
113 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""This test module contains the tests for the bike dataset.""" | ||
|
||
import pytest | ||
|
||
from shapiq import load_bike | ||
|
||
|
||
def test_load_bike(): | ||
data = load_bike() | ||
# test if data is a pandas dataframe | ||
assert isinstance(data, type(data)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""This module contains all tests for the network plots.""" | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
|
||
from shapiq.plot import network_plot | ||
|
||
|
||
def test_network_plot(): | ||
"""Tests whether the network plot can be created.""" | ||
|
||
first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6]) | ||
second_order_values = np.random.rand(6, 6) - 0.5 | ||
|
||
fig, axes = network_plot( | ||
first_order_values=first_order_values, | ||
second_order_values=second_order_values, | ||
) | ||
assert fig is not None | ||
assert axes is not None | ||
plt.close(fig) | ||
|
||
fig, axes = network_plot( | ||
first_order_values=first_order_values[0:4], | ||
second_order_values=second_order_values[0:4, 0:4], | ||
feature_names=["a", "b", "c", "d"], | ||
) | ||
assert fig is not None | ||
assert axes is not None | ||
plt.close(fig) | ||
|
||
|
||
def test_network_plot_with_image(): | ||
first_order_values = np.asarray([0.1, -0.2, 0.3, 0.4, 0.5, 0.6]) | ||
second_order_values = np.random.rand(6, 6) - 0.5 | ||
n_features = len(first_order_values) | ||
|
||
# create dummyimage | ||
image = np.random.rand(100, 100, 3) | ||
image = (image * 255).astype(np.uint8) | ||
image = Image.fromarray(image) | ||
|
||
feature_image_patches: dict[int, Image.Image] = {} | ||
feature_image_patches_size: dict[int, float] = {} | ||
for feature_idx in range(n_features): | ||
feature_image_patches[feature_idx] = image | ||
feature_image_patches_size[feature_idx] = 0.1 | ||
|
||
fig, axes = network_plot( | ||
first_order_values=first_order_values, | ||
second_order_values=second_order_values, | ||
center_image=image, | ||
feature_image_patches=feature_image_patches, | ||
) | ||
assert fig is not None | ||
assert axes is not None | ||
plt.close(fig) | ||
|
||
fig, axes = network_plot( | ||
first_order_values=first_order_values, | ||
second_order_values=second_order_values, | ||
center_image=image, | ||
feature_image_patches=feature_image_patches, | ||
feature_image_patches_size=feature_image_patches_size, | ||
) | ||
assert fig is not None | ||
assert axes is not None | ||
plt.close(fig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""This test module contains tests for utils.modules.""" | ||
import pytest | ||
|
||
from shapiq.utils import safe_isinstance | ||
from sklearn.tree import DecisionTreeRegressor | ||
|
||
|
||
def test_safe_isinstance(): | ||
model = DecisionTreeRegressor() | ||
|
||
assert safe_isinstance(model, "sklearn.tree.DecisionTreeRegressor") | ||
assert safe_isinstance( | ||
model, ["sklearn.tree.DecisionTreeClassifier", "sklearn.tree.DecisionTreeRegressor"] | ||
) | ||
assert safe_isinstance(model, ("sklearn.tree.DecisionTreeRegressor",)) | ||
with pytest.raises(ValueError): | ||
safe_isinstance(model, "DecisionTreeRegressor") | ||
with pytest.raises(ValueError): | ||
safe_isinstance(model, None) | ||
assert not safe_isinstance(model, "my.made.up.module") | ||
assert not safe_isinstance(model, ["sklearn.ensemble.DecisionTreeRegressor"]) |