Add ClassificationQuality render and ClassificationPreset. (#1438)
Liraim authored Jan 21, 2025
1 parent b70044c commit 8fd49c2
Showing 5 changed files with 174 additions and 11,670 deletions.
34 changes: 12 additions & 22 deletions examples/future_examples/future_dashboads.ipynb
@@ -2,19 +2,12 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-09T23:58:58.734191Z",
"start_time": "2025-01-09T23:58:58.566108Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"from random import randint\n",
"\n",
"from evidently.ui.dashboards.reports import DistributionPanel\n",
"from evidently.ui.dashboards.reports import DashboardPanelHistogram\n",
"from evidently.future.datasets import BinaryClassification\n",
"from evidently.future.metrics import PrecisionByLabel\n",
"from evidently.future.metrics.column_statistics import CategoryCount\n",
@@ -77,7 +70,7 @@
" value=PanelValue(field_path=\"value\", metric_args={\"metric.metric_id\": \"2e5caa9690281e02cf243c736d687782\"}),\n",
" filter=ReportFilter(metadata_values={}, tag_values=[], include_test_suites=True),\n",
" ))\n",
" project.dashboard.add_panel(DistributionPanel(\n",
" project.dashboard.add_panel(DashboardPanelHistogram(\n",
" title=\"Distr\",\n",
" value=PanelValue(field_path=\"values\", metric_args={\"metric.type\": \"evidently:metric_v2:UniqueValueCount\"}),\n",
" filter=ReportFilter(metadata_values={}, tag_values=[], include_test_suites=True),\n",
@@ -87,30 +80,27 @@
"\n",
"for i in range(10):\n",
" project.add_snapshot(create_snapshot(i)) "
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": 2,
"id": "35a20fd14d577dc8",
"metadata": {
"ExecuteTime": {
"end_time": "2025-01-09T23:56:41.749035100Z",
"start_time": "2025-01-09T23:53:00.368240Z"
}
},
"outputs": [],
"metadata": {},
"source": [
"# use `evidently ui` to run UI service to see project dashboard."
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": null,
"id": "babdaf65-d7a2-454d-9429-e270dd7fffc7",
"metadata": {},
"source": [],
"outputs": [],
"source": []
"execution_count": null
}
],
"metadata": {
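The notebook change above replaces DistributionPanel with DashboardPanelHistogram. A minimal sketch of the updated panel registration, assuming a local workspace; the Workspace setup and the PanelValue/ReportFilter import paths are assumptions, not part of this diff:

from evidently.ui.dashboards import PanelValue  # import paths assumed
from evidently.ui.dashboards import ReportFilter
from evidently.ui.dashboards.reports import DashboardPanelHistogram
from evidently.ui.workspace import Workspace  # local workspace; setup assumed

# Create (or reopen) a local workspace and a project to attach the panel to.
ws = Workspace.create("workspace")
project = ws.create_project("future dashboards demo")

# Same panel as in the notebook: a histogram over the "values" field of the
# UniqueValueCount metric, with no report filtering.
project.dashboard.add_panel(
    DashboardPanelHistogram(
        title="Distr",
        value=PanelValue(
            field_path="values",
            metric_args={"metric.type": "evidently:metric_v2:UniqueValueCount"},
        ),
        filter=ReportFilter(metadata_values={}, tag_values=[], include_test_suites=True),
    )
)
project.save()  # persist the dashboard configuration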
10,072 changes: 62 additions & 10,010 deletions examples/future_examples/list_metrics.ipynb

Large diffs are not rendered by default.

1,647 changes: 26 additions & 1,621 deletions examples/future_examples/metric_workbench.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/evidently/future/presets/__init__.py
@@ -1,4 +1,5 @@
from .classification import ClassificationDummyQuality
from .classification import ClassificationPreset
from .classification import ClassificationQuality
from .classification import ClassificationQualityByLabel
from .dataset_stats import DatasetStats
@@ -12,6 +13,7 @@

__all__ = [
"ClassificationDummyQuality",
"ClassificationPreset",
"ClassificationQuality",
"ClassificationQualityByLabel",
"ValueStats",
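With these exports in place, the new container and preset are importable directly from the presets package, alongside the existing classification presets:

from evidently.future.presets import ClassificationPreset
from evidently.future.presets import ClassificationQuality
from evidently.future.presets import ClassificationQualityByLabel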
89 changes: 72 additions & 17 deletions src/evidently/future/presets/classification.py
@@ -25,13 +25,28 @@
from evidently.future.metrics.classification import DummyPrecision
from evidently.future.metrics.classification import DummyRecall
from evidently.future.report import Context
from evidently.metrics import ClassificationConfusionMatrix
from evidently.metrics import ClassificationDummyMetric
from evidently.metrics import ClassificationPRCurve
from evidently.metrics import ClassificationPRTable
from evidently.metrics import ClassificationQualityByClass
from evidently.metrics import ClassificationQualityMetric
from evidently.model.widget import BaseWidgetInfo


class ClassificationQuality(MetricContainer):
def __init__(
self,
probas_threshold: Optional[float] = None,
conf_matrix: bool = False,
pr_curve: bool = False,
pr_table: bool = False,
):
self._probas_threshold = probas_threshold
self._conf_matrix = conf_matrix
self._pr_curve = pr_curve
self._pr_table = pr_table

def generate_metrics(self, context: "Context") -> List[Metric]:
classification = context.data_definition.get_classification("default")
if classification is None:
@@ -41,44 +56,52 @@ def generate_metrics(self, context: "Context") -> List[Metric]:

if isinstance(classification, BinaryClassification):
metrics = [
Accuracy(),
Precision(),
Recall(),
F1Score(),
Accuracy(probas_threshold=self._probas_threshold),
Precision(probas_threshold=self._probas_threshold),
Recall(probas_threshold=self._probas_threshold),
F1Score(probas_threshold=self._probas_threshold),
]
if classification.prediction_probas is not None:
metrics.extend(
[
RocAuc(),
LogLoss(),
RocAuc(probas_threshold=self._probas_threshold),
LogLoss(probas_threshold=self._probas_threshold),
]
)
metrics.extend(
[
TPR(),
TNR(),
FPR(),
FNR(),
TPR(probas_threshold=self._probas_threshold),
TNR(probas_threshold=self._probas_threshold),
FPR(probas_threshold=self._probas_threshold),
FNR(probas_threshold=self._probas_threshold),
]
)
else:
metrics = [
Accuracy(),
Precision(),
Recall(),
F1Score(),
Accuracy(probas_threshold=self._probas_threshold),
Precision(probas_threshold=self._probas_threshold),
Recall(probas_threshold=self._probas_threshold),
F1Score(probas_threshold=self._probas_threshold),
]
if classification.prediction_probas is not None:
metrics.extend(
[
RocAuc(),
LogLoss(),
RocAuc(probas_threshold=self._probas_threshold),
LogLoss(probas_threshold=self._probas_threshold),
]
)
return metrics

def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
_, render = context.get_legacy_metric(ClassificationQualityMetric())
_, render = context.get_legacy_metric(ClassificationQualityMetric(probas_threshold=self._probas_threshold))
if self._conf_matrix:
render += context.get_legacy_metric(ClassificationConfusionMatrix(probas_threshold=self._probas_threshold))[
1
]
if self._pr_curve:
render += context.get_legacy_metric(ClassificationPRCurve(probas_threshold=self._probas_threshold))[1]
if self._pr_table:
render += context.get_legacy_metric(ClassificationPRTable(probas_threshold=self._probas_threshold))[1]
return render


@@ -117,3 +140,35 @@
def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
_, widgets = context.get_legacy_metric(ClassificationDummyMetric(self._probas_threshold, self._k))
return widgets


class ClassificationPreset(MetricContainer):
def __init__(self, probas_threshold: Optional[float] = None):
self._probas_threshold = probas_threshold
self._quality = ClassificationQuality(
probas_threshold=probas_threshold,
conf_matrix=True,
pr_curve=True,
pr_table=True,
)
self._quality_by_label = ClassificationQualityByLabel(probas_threshold=probas_threshold)
self._roc_auc: Optional[RocAuc] = None

def generate_metrics(self, context: "Context") -> List[Metric]:
classification = context.data_definition.get_classification("default")
if classification is None:
raise ValueError("Cannot use ClassificationPreset without a classification configration")
if classification.prediction_probas is not None:
self._roc_auc = RocAuc(probas_threshold=self._probas_threshold)
return (
self._quality.metrics(context)
+ self._quality_by_label.metrics(context)
+ ([] if self._roc_auc is None else [self._roc_auc])
)

def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
return (
self._quality.render(context, results)
+ self._quality_by_label.render(context, results)
+ ([] if self._roc_auc is None else context.get_metric_result(self._roc_auc).widget)
)
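
ClassificationQuality now threads probas_threshold into every metric it generates and can optionally attach the legacy confusion-matrix, PR-curve, and PR-table renders. A minimal configuration sketch based on the constructor above; the threshold value is illustrative:

from evidently.future.presets import ClassificationQuality

# Apply a custom decision threshold to Accuracy, Precision, Recall and F1Score
# (plus RocAuc and LogLoss when prediction probabilities are present), and
# switch on all three optional legacy renders.
quality = ClassificationQuality(
    probas_threshold=0.6,
    conf_matrix=True,
    pr_curve=True,
    pr_table=True,
)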

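ClassificationPreset bundles ClassificationQuality (with all renders enabled), ClassificationQualityByLabel, and a RocAuc metric when prediction probabilities are available. A sketch of running it end to end; the Report, Dataset, and DataDefinition usage below follows the future API used elsewhere in the examples, and the exact signatures (from_pandas, run, the BinaryClassification field names) are assumptions rather than part of this diff:

import pandas as pd

from evidently.future.datasets import BinaryClassification
from evidently.future.datasets import DataDefinition
from evidently.future.datasets import Dataset
from evidently.future.presets import ClassificationPreset
from evidently.future.report import Report

# Toy binary-classification data; column names are illustrative.
df = pd.DataFrame(
    {
        "target": [0, 1, 1, 0, 1, 0],
        "prediction": [0, 1, 0, 0, 1, 1],
    }
)

# Classification wiring assumed from the future datasets API; argument names may differ.
definition = DataDefinition(
    classification=[BinaryClassification(target="target", prediction_labels="prediction")],
)
dataset = Dataset.from_pandas(df, data_definition=definition)

report = Report([ClassificationPreset()])
snapshot = report.run(dataset, None)  # second argument: optional reference dataset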