-
Notifications
You must be signed in to change notification settings - Fork 851
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bug fixes for the FACTS method (#533)
* bugfix: drop_above arg used hard-coded names, removed
* FACTS hotfix: drop_infeasible set to False. This part of the code is problematic due to the use of hard-coded feature names. The required functionality should be achieved in some other way.
* FACTS bugfix: feat weights were not passed properly
* removed obsolete use of drop_above argument
* FACTS: added test for user interface API
* FACTS_bias_scan: test improvement. Previously, the test case only had inf costs. Consequently, the exact values of the feature weights were not actually tested properly.
- Loading branch information
1 parent
7c4f172
commit e011686
Showing
4 changed files
with
179 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import numpy as np | ||
import pandas as pd | ||
|
||
import pytest | ||
|
||
from aif360.sklearn.detectors import FACTS, FACTS_bias_scan | ||
|
||
from aif360.sklearn.detectors.facts.predicate import Predicate | ||
from aif360.sklearn.detectors.facts.parameters import ParameterProxy, feature_change_builder | ||
|
||
def test_FACTS():
    """Fit ``FACTS`` on a small hand-made dataset and check the mined rules.

    The mock classifier labels a row positive when ``a > 20`` or ``c < 15``.
    Every feature gets weight 10, and ``detector.rules_by_if`` is compared
    against a hand-written table of expected rules for both values of the
    protected attribute ``sex``.
    """

    class MockModel:
        """Deterministic stand-in classifier driven by columns ``a`` and ``c``."""

        @staticmethod
        def _label(row) -> int:
            # Positive outcome when either threshold is satisfied.
            return int(row["a"] > 20 or row["c"] < 15)

        def predict(self, X: pd.DataFrame) -> np.ndarray:
            return np.array([self._label(row) for _, row in X.iterrows()])

    column_names = ["a", "b", "c", "d", "sex", "age"]
    rows = [
        [21, 2, 3, 4, "Female", pd.Interval(60, 70)],
        [21, 13, 3, 19, "Male", pd.Interval(60, 70)],
        [25, 2, 7, 4, "Female", pd.Interval(60, 70)],
        [21, 2, 3, 4, "Male", pd.Interval(60, 70)],
        [1, 2, 3, 4, "Male", pd.Interval(20, 30)],
        [1, 20, 30, 40, "Male", pd.Interval(40, 50)],
        [19, 2, 30, 43, "Male", pd.Interval(30, 40)],
        [19, 13, 30, 4, "Male", pd.Interval(10, 20)],
        [1, 2, 30, 4, "Female", pd.Interval(20, 30)],
        [19, 20, 30, 40, "Female", pd.Interval(40, 50)],
        [19, 2, 30, 4, "Female", pd.Interval(30, 40)],
    ]
    X = pd.DataFrame(rows, columns=column_names)

    detector = FACTS(
        clf=MockModel(),
        prot_attr="sex",
        categorical_features=["sex", "age"],
        freq_itemset_min_supp=0.5,
        feature_weights=dict.fromkeys(X.columns, 10),
        feats_not_allowed_to_change=[],
    )
    detector.fit(X, verbose=False)

    as_pred = Predicate.from_dict
    groups = ("Male", "Female")

    # Layout: {if-clause: {group: (first value, [(then-clause, value, value)])}}.
    # NOTE(review): the numbers look like (coverage, [(action, effectiveness,
    # cost)]) with cost = weight * |change| (e.g. 10 * |21 - 19| = 20) --
    # confirm against the FACTS documentation.
    expected_rules = {
        as_pred({"a": 19}): {
            sex: (2/3, [(as_pred({"a": 21}), 1., 20.)]) for sex in groups
        },
        as_pred({"c": 30}): {
            sex: (1., [(as_pred({"c": 3}), 1., 270.)]) for sex in groups
        },
        as_pred({"a": 19, "c": 30}): {
            sex: (2/3, [(as_pred({"a": 21, "c": 3}), 1., 290.)])
            for sex in groups
        },
    }

    # Exactly the expected if-clauses must be present, each with the expected
    # per-group recourse information.
    assert set(expected_rules) == set(detector.rules_by_if)
    for if_clause, thens_by_group in expected_rules.items():
        assert detector.rules_by_if[if_clause] == thens_by_group
|
||
def test_FACTS_bias_scan():
    """Run ``FACTS_bias_scan`` end to end and check the reported subgroups.

    The mock classifier treats ``Female`` rows with ``d < 15`` differently
    (they are positive only when ``c < 5``), so the two groups get different
    recourse outcomes. The scan's top-3 result is compared, ignoring order,
    with a hand-written list of ``(subgroup, score)`` pairs.
    """

    class MockModel:
        """Deterministic stand-in classifier that depends on ``sex``."""

        @staticmethod
        def _label(row) -> int:
            # Stricter rule for this slice of the Female group.
            if row["sex"] == "Female" and row["d"] < 15:
                return int(row["c"] < 5)
            # Everyone else: positive when either threshold is satisfied.
            return int(row["a"] > 20 or row["c"] < 15)

        def predict(self, X: pd.DataFrame) -> np.ndarray:
            return np.array([self._label(row) for _, row in X.iterrows()])

    column_names = ["a", "b", "c", "d", "sex", "age"]
    rows = [
        [21, 2, 3, 20, "Female", pd.Interval(60, 70)],
        [21, 13, 3, 19, "Male", pd.Interval(60, 70)],
        [25, 2, 7, 21, "Female", pd.Interval(60, 70)],
        [21, 2, 3, 4, "Male", pd.Interval(60, 70)],
        [1, 2, 7, 4, "Male", pd.Interval(20, 30)],
        [1, 2, 7, 40, "Female", pd.Interval(20, 30)],
        [1, 20, 30, 40, "Male", pd.Interval(40, 50)],
        [19, 2, 30, 43, "Male", pd.Interval(30, 40)],
        [19, 13, 30, 4, "Male", pd.Interval(10, 20)],
        [1, 2, 30, 4, "Female", pd.Interval(20, 30)],
        [19, 20, 30, 7, "Female", pd.Interval(40, 50)],
        [19, 2, 30, 4, "Female", pd.Interval(30, 40)],
    ]
    X = pd.DataFrame(rows, columns=column_names)

    most_biased_subgroups = FACTS_bias_scan(
        X=X,
        clf=MockModel(),
        prot_attr="sex",
        metric="equal-cost-of-effectiveness",
        categorical_features=["sex", "age"],
        freq_itemset_min_supp=0.5,
        feature_weights=dict.fromkeys(X.columns, 10),
        feats_not_allowed_to_change=[],
        viewpoint="macro",
        sort_strategy="max-cost-diff-decr",
        top_count=3,
        phi=0.5,
        verbose=False,
        print_recourse_report=False,
    )

    as_pred = Predicate.from_dict

    # Reference only: this table is not asserted on. It is kept so that the
    # expected scores below can be read next to the values they come from.
    expected_ifthens = {
        as_pred({"a": 19}): {
            "Male": (2/3, [
                (as_pred({"a": 21}), 1., 20.),
            ]),
            "Female": (2/3, [
                (as_pred({"a": 21}), 0., 20.),
            ]),
        },
        as_pred({"c": 30}): {
            "Male": (1., [
                (as_pred({"c": 7}), 1., 230.),
                (as_pred({"c": 3}), 1., 270.),
            ]),
            "Female": (1., [
                (as_pred({"c": 7}), 0., 230.),
                (as_pred({"c": 3}), 1., 270.),
            ]),
        },
        as_pred({"a": 19, "c": 30}): {
            "Male": (2/3, [
                (as_pred({"a": 21, "c": 3}), 1., 290.),
            ]),
            "Female": (2/3, [
                (as_pred({"a": 21, "c": 3}), 1., 290.),
            ]),
        },
    }

    # NOTE(review): the scores look like the between-group cost gap
    # (e.g. 270 - 230 = 40, and inf where one group has no effective
    # action) -- confirm against the metric's definition.
    expected_most_biased_subgroups = [
        ({"a": 19}, float("inf")),
        ({"c": 30}, 40.),
        ({"a": 19, "c": 30}, 0.),
    ]

    # Order-insensitive comparison: same length and mutual containment.
    assert len(most_biased_subgroups) == len(expected_most_biased_subgroups)
    assert all(
        wanted in most_biased_subgroups
        for wanted in expected_most_biased_subgroups
    )
    assert all(
        found in expected_most_biased_subgroups
        for found in most_biased_subgroups
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters