diff --git a/tests/tests_explainer/test_explainer_tabular.py b/tests/tests_explainer/test_explainer_tabular.py index 829eb586..0d5f33bf 100644 --- a/tests/tests_explainer/test_explainer_tabular.py +++ b/tests/tests_explainer/test_explainer_tabular.py @@ -176,25 +176,34 @@ def test_explain(dt_model, data, index, budget, max_order, imputer): def test_against_shap_linear(): """Tests weather TabularExplainer yields similar results as SHAP with a basic linear model.""" - import shap n_samples = 3 dim = 5 + rng = np.random.default_rng(42) def make_linear_model(): - w = np.random.default_rng().normal(size=dim) + w = rng.normal(size=dim) def model(X: np.ndarray): return np.dot(X, w) return model - X = np.random.default_rng().normal(size=(n_samples, dim)) + X = rng.normal(size=(n_samples, dim)) model = make_linear_model() + # import shap # compute with shap - explainer_shap = shap.explainers.Exact(model, X) - shap_values = explainer_shap(X).values + # explainer_shap = shap.explainers.Exact(model, X) + # shap_values = explainer_shap(X).values + # print(shap_values) + shap_values = np.array( + [ + [-0.29565839, -0.36698085, -0.55970434, 0.22567077, 0.05852208], + [1.08513574, 0.06365536, 0.46312977, -0.61532757, 0.00370387], + [-0.78947735, 0.30332549, 0.09657457, 0.38965679, -0.06222595], + ] + ) # compute with shapiq explainer_shapiq = TabularExplainer(