Skip to content

Commit

Permalink
Fix Binary Classification with different configurations. (#1463)
Browse files Browse the repository at this point in the history
* Add default behaviour for Classification configuration.
Check for probas for RocAuc metrics and graphs.

* Fix linter
  • Loading branch information
Liraim authored Jan 30, 2025
1 parent e6dd8b3 commit aa7629f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 11 deletions.
73 changes: 66 additions & 7 deletions src/evidently/future/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,81 @@ class ColumnInfo:

@dataclasses.dataclass
class BinaryClassification:
name: str = "default"
target: str = "target"
prediction_labels: Optional[str] = None
prediction_probas: Optional[str] = "prediction"
pos_label: Label = 1
labels: Optional[Dict[Label, str]] = None
name: str
target: str
prediction_labels: Optional[str]
prediction_probas: Optional[str]
pos_label: Label
labels: Optional[Dict[Label, str]]

def __init__(
self,
*,
name: str = "default",
target: Optional[str] = None,
prediction_labels: Optional[str] = None,
prediction_probas: Optional[str] = None,
pos_label: Optional[str] = None,
labels: Optional[Dict[Label, str]] = None,
):
self.name = name
if (
target is None
and prediction_labels is None
and prediction_probas is None
and pos_label is None
and labels is None
):
self.target = "target"
self.prediction_labels = None
self.prediction_probas = "prediction"
self.pos_label = 1
self.labels = None
return
if target is None or (prediction_labels is None and prediction_probas is None):
raise ValueError(
"Invalid BinaryClassification configuration:" " target and one of (labels or probas) should be set"
)
self.target = target
self.prediction_labels = prediction_labels
self.prediction_probas = prediction_probas
self.pos_label = pos_label if pos_label is not None else 1
self.labels = labels


@dataclasses.dataclass
class MulticlassClassification:
name: str = "default"
target: str = "target"
prediction_labels: str = "prediction"
prediction_labels: Optional[str] = "prediction"
prediction_probas: Optional[List[str]] = None
labels: Optional[Dict[Label, str]] = None

def __init__(
self,
*,
name: str = "default",
target: Optional[str] = None,
prediction_labels: Optional[str] = None,
prediction_probas: Optional[List[str]] = None,
labels: Optional[Dict[Label, str]] = None,
):
self.name = name
if target is None and prediction_labels is None and prediction_probas is None and labels is None:
self.target = "target"
self.prediction_labels = "prediction"
self.prediction_probas = None
self.labels = None
return
if target is None or (prediction_labels is None and prediction_probas is None):
raise ValueError(
"Invalid MulticlassClassification configuration:" " target and one of (labels or probas) should be set"
)
self.target = target
self.prediction_labels = prediction_labels
self.prediction_probas = prediction_probas
self.labels = labels


Classification = Union[BinaryClassification, MulticlassClassification]

Expand Down
19 changes: 15 additions & 4 deletions src/evidently/future/presets/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,15 @@ def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> L
ClassificationConfusionMatrix(probas_threshold=self._probas_threshold),
_gen_classification_input_data,
)[1]
if self._pr_curve:
classification = context.data_definition.get_classification("default")
if classification is None:
raise ValueError("Cannot use ClassificationQuality without a classification data")
if self._pr_curve and classification.prediction_probas is not None:
render += context.get_legacy_metric(
ClassificationPRCurve(probas_threshold=self._probas_threshold),
_gen_classification_input_data,
)[1]
if self._pr_table:
if self._pr_table and classification.prediction_probas is not None:
render += context.get_legacy_metric(
ClassificationPRTable(probas_threshold=self._probas_threshold),
_gen_classification_input_data,
Expand All @@ -122,12 +125,20 @@ def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] =
self._k = k

def generate_metrics(self, context: "Context") -> List[Metric]:
classification = context.data_definition.get_classification("default")
if classification is None:
raise ValueError("Cannot use ClassificationPreset without a classification configration")
return [
F1ByLabel(probas_threshold=self._probas_threshold, k=self._k),
PrecisionByLabel(probas_threshold=self._probas_threshold, k=self._k),
RecallByLabel(probas_threshold=self._probas_threshold, k=self._k),
RocAucByLabel(probas_threshold=self._probas_threshold, k=self._k),
]
] + (
[]
if classification.prediction_probas is None
else [
RocAucByLabel(probas_threshold=self._probas_threshold, k=self._k),
]
)

def render(self, context: "Context", results: Dict[MetricId, MetricResult]):
render = context.get_legacy_metric(
Expand Down

0 comments on commit aa7629f

Please sign in to comment.