diff --git a/metrics/f1/README.md b/metrics/f1/README.md index 8f38be6e6..ec9935f49 100644 --- a/metrics/f1/README.md +++ b/metrics/f1/README.md @@ -48,6 +48,7 @@ At minimum, this metric requires predictions and references as input - 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall. - 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification). - **sample_weight** (`list` of `float`): Sample weights Defaults to None. +- **zero_division** ('warn' or 0.0 or 1.0 or np.nan): Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. Defaults to 'warn'. ### Output Values @@ -134,4 +135,4 @@ Example 4-A multiclass example, with different values for the `average` input. ``` -## Further References \ No newline at end of file +## Further References diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index fe7683489..267eafcc9 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -39,6 +39,7 @@ - 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall. - 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification). sample_weight (`list` of `float`): Sample weights Defaults to None. + zero_division ('warn' or 0.0 or 1.0 or np.nan): Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. Defaults to 'warn'. Returns: f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better. @@ -123,8 +124,8 @@ def _info(self): reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"], ) - def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None): + def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", zero_division=None, sample_weight=None): score = f1_score( - references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight + references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight, zero_division=zero_division, ) return {"f1": float(score) if score.size == 1 else score}