Skip to content

Commit

Permalink
Add RollingPRAUC test to test_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlpgomes committed May 10, 2024
1 parent f3da068 commit 7995570
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions river/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,24 @@ def roc_auc_score(y_true, y_score):
return sk_metrics.roc_auc_score(y_true, scores)


def pr_auc_score(y_true, y_score):
    """Interpolated PR-AUC reference implementation via scikit-learn.

    Wraps ``sk_metrics.precision_recall_curve`` and ``sk_metrics.auc``.
    Returns 0 when ``y_true`` contains only a single class, since the
    precision-recall curve is undefined in that case.
    """
    n_positive = np.count_nonzero(y_true)
    if n_positive in (0, len(y_true)):
        return 0

    positive_scores = [score[True] for score in y_score]
    precision, recall, _ = sk_metrics.precision_recall_curve(y_true, positive_scores)

    # Interpolate: force precision to be monotonically non-increasing
    # along the recall axis before integrating.
    return sk_metrics.auc(recall, np.maximum.accumulate(precision))


TEST_CASES = [
(metrics.Accuracy(), sk_metrics.accuracy_score),
(metrics.Precision(), partial(sk_metrics.precision_score, zero_division=0)),
Expand Down Expand Up @@ -210,6 +228,7 @@ def roc_auc_score(y_true, y_score):
(metrics.MicroJaccard(), partial(sk_metrics.jaccard_score, average="micro")),
(metrics.WeightedJaccard(), partial(sk_metrics.jaccard_score, average="weighted")),
(metrics.RollingROCAUC(), roc_auc_score),
(metrics.RollingPRAUC(), pr_auc_score),
]

# HACK: not sure why this is needed, see this CI run https://github.com/online-ml/river/runs/7992357532?check_suite_focus=true
Expand Down

0 comments on commit 7995570

Please sign in to comment.