Skip to content

Commit

Permalink
removed call to shap in test
Browse files Browse the repository at this point in the history
  • Loading branch information
mmschlk committed Jan 10, 2025
1 parent ddcf5c7 commit 7fcafa2
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tests/tests_explainer/test_explainer_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,25 +176,34 @@ def test_explain(dt_model, data, index, budget, max_order, imputer):

def test_against_shap_linear():
"""Tests weather TabularExplainer yields similar results as SHAP with a basic linear model."""
import shap

n_samples = 3
dim = 5
rng = np.random.default_rng(42)

def make_linear_model():
w = np.random.default_rng().normal(size=dim)
w = rng.normal(size=dim)

def model(X: np.ndarray):
return np.dot(X, w)

return model

X = np.random.default_rng().normal(size=(n_samples, dim))
X = rng.normal(size=(n_samples, dim))
model = make_linear_model()

# import shap
# compute with shap
explainer_shap = shap.explainers.Exact(model, X)
shap_values = explainer_shap(X).values
# explainer_shap = shap.explainers.Exact(model, X)
# shap_values = explainer_shap(X).values
# print(shap_values)
shap_values = np.array(
[
[-0.29565839, -0.36698085, -0.55970434, 0.22567077, 0.05852208],
[1.08513574, 0.06365536, 0.46312977, -0.61532757, 0.00370387],
[-0.78947735, 0.30332549, 0.09657457, 0.38965679, -0.06222595],
]
)

# compute with shapiq
explainer_shapiq = TabularExplainer(
Expand Down

0 comments on commit 7fcafa2

Please sign in to comment.