From 8405b6828126246ac8a1c5d051c7dd09c8698118 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 22 Jan 2025 17:06:50 +0100 Subject: [PATCH 1/3] use image threshold when pixel threshold not available --- src/anomalib/post_processing/one_class.py | 76 ++++++++++++++--------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index ffa906a176..bc7b504bac 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -64,21 +64,21 @@ def __init__( self.pixel_sensitivity = pixel_sensitivity # initialize threshold and normalization metrics - self._image_threshold = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False) - self._pixel_threshold = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False) - self._image_min_max = MinMax(fields=["pred_score"], strict=False) - self._pixel_min_max = MinMax(fields=["anomaly_map"], strict=False) + self._image_threshold_metric = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False) + self._pixel_threshold_metric = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False) + self._image_min_max_metric = MinMax(fields=["pred_score"], strict=False) + self._pixel_min_max_metric = MinMax(fields=["anomaly_map"], strict=False) # register buffers to persist threshold and normalization values - self.register_buffer("image_threshold", torch.tensor(0)) - self.register_buffer("pixel_threshold", torch.tensor(0)) + self.register_buffer("_image_threshold", torch.tensor(0)) + self.register_buffer("_pixel_threshold", torch.tensor(0)) self.register_buffer("image_min", torch.tensor(0)) self.register_buffer("image_max", torch.tensor(1)) self.register_buffer("pixel_min", torch.tensor(0)) self.register_buffer("pixel_max", torch.tensor(1)) - self.image_threshold: torch.Tensor - self.pixel_threshold: torch.Tensor + self._image_threshold: torch.Tensor + self._pixel_threshold: torch.Tensor self.image_min: torch.Tensor self.image_max: torch.Tensor self.pixel_min: torch.Tensor @@ -102,10 +102,10 @@ def on_validation_batch_end( **kwargs: Arbitrary keyword arguments. """ del trainer, pl_module, args, kwargs # Unused arguments. - self._image_threshold.update(outputs) - self._pixel_threshold.update(outputs) - self._image_min_max.update(outputs) - self._pixel_min_max.update(outputs) + self._image_threshold_metric.update(outputs) + self._pixel_threshold_metric.update(outputs) + self._image_min_max_metric.update(outputs) + self._pixel_min_max_metric.update(outputs) def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Compute final threshold and normalization values. @@ -115,14 +115,14 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) pl_module (LightningModule): PyTorch Lightning module instance. """ del trainer, pl_module - if self._image_threshold.update_called: - self.image_threshold = self._image_threshold.compute() - if self._pixel_threshold.update_called: - self.pixel_threshold = self._pixel_threshold.compute() - if self._image_min_max.update_called: - self.image_min, self.image_max = self._image_min_max.compute() - if self._pixel_min_max.update_called: - self.pixel_min, self.pixel_max = self._pixel_min_max.compute() + if self._image_threshold_metric.update_called: + self._image_threshold = self._image_threshold_metric.compute() + if self._pixel_threshold_metric.update_called: + self._pixel_threshold = self._pixel_threshold_metric.compute() + if self._image_min_max_metric.update_called: + self.image_min, self.image_max = self._image_min_max_metric.compute() + if self._pixel_min_max_metric.update_called: + self.pixel_min, self.pixel_max = self._pixel_min_max_metric.compute() def on_test_batch_end( self, @@ -181,8 +181,8 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: msg = "At least one of pred_score or anomaly_map must be provided." raise ValueError(msg) pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1)) - pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold) - anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) + pred_score = self._normalize(pred_score, self.image_min, self.image_max, self._image_threshold) + anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self._pixel_threshold) pred_label = self._apply_threshold(pred_score, self.normalized_image_threshold) pred_mask = self._apply_threshold(anomaly_map, self.normalized_pixel_threshold) return InferenceBatch( @@ -274,6 +274,26 @@ def _normalize( preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5 return preds.clamp(min=0, max=1) + @property + def image_threshold(self) -> float: + """Get the image-level threshold. + + Returns: + float: Image-level threshold value. + """ + return self._image_threshold + + @property + def pixel_threshold(self) -> float: + """Get the pixel-level threshold. + + If the pixel-level threshold is not set, the image-level threshold is used. + + Returns: + float: Pixel-level threshold value. + """ + return self._pixel_threshold or self.image_threshold + @property def normalized_image_threshold(self) -> float: """Get the normalized image-level threshold. @@ -281,9 +301,9 @@ def normalized_image_threshold(self) -> float: Returns: float: Normalized image-level threshold value, adjusted by sensitivity. """ - if self.image_sensitivity is not None: - return torch.tensor(1.0) - self.image_sensitivity - return torch.tensor(0.5) + if self.image_sensitivity is None: + return torch.tensor(0.5) + return torch.tensor(1.0) - self.image_sensitivity @property def normalized_pixel_threshold(self) -> float: @@ -292,6 +312,6 @@ def normalized_pixel_threshold(self) -> float: Returns: float: Normalized pixel-level threshold value, adjusted by sensitivity. """ - if self.pixel_sensitivity is not None: - return torch.tensor(1.0) - self.pixel_sensitivity - return torch.tensor(0.5) + if self.pixel_sensitivity is None: + return torch.tensor(0.5) + return torch.tensor(1.0) - self.pixel_sensitivity From 5179fb43443709135c09247d4739d4a8f60d630f Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 22 Jan 2025 18:26:57 +0100 Subject: [PATCH 2/3] add unit tests for post-processor --- .../post_processing/test_post_processor.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 tests/unit/post_processing/test_post_processor.py diff --git a/tests/unit/post_processing/test_post_processor.py b/tests/unit/post_processing/test_post_processor.py new file mode 100644 index 0000000000..75db9e64ce --- /dev/null +++ b/tests/unit/post_processing/test_post_processor.py @@ -0,0 +1,89 @@ +"""Test the PostProcessor class.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from anomalib.data import ImageBatch +from anomalib.post_processing import OneClassPostProcessor + + +class TestPostProcessor: + """Test the PreProcessor class.""" + + @staticmethod + @pytest.mark.parametrize( + ("preds", "min_val", "max_val", "thresh", "target"), + [ + (torch.tensor([20, 40, 60, 80]), 0, 100, 50, torch.tensor([0.2, 0.4, 0.6, 0.8])), + (torch.tensor([20, 40, 60, 80]), 0, 100, 40, torch.tensor([0.3, 0.5, 0.7, 0.9])), # lower threshold + (torch.tensor([20, 40, 60, 80]), 0, 100, 60, torch.tensor([0.1, 0.3, 0.5, 0.7])), # higher threshold + (torch.tensor([0, 40, 80, 120]), 20, 100, 50, torch.tensor([0.0, 0.375, 0.875, 1.0])), # out of bounds + (torch.tensor([-80, -60, -40, -20]), -100, 0, -50, torch.tensor([0.2, 0.4, 0.6, 0.8])), # negative values + (torch.tensor([20, 40, 60, 80]), 0, 100, -50, torch.tensor([1.0, 1.0, 1.0, 1.0])), # threshold below range + (torch.tensor([20, 40, 60, 80]), 0, 100, 150, torch.tensor([0.0, 0.0, 0.0, 0.0])), # threshold above range + (torch.tensor([20, 40, 60, 80]), 50, 50, 50, torch.tensor([0.0, 0.0, 1.0, 1.0])), # all same + (torch.tensor(60), 0, 100, 50, torch.tensor(0.6)), # scalar tensor + (torch.tensor([[20, 40], [60, 80]]), 0, 100, 50, torch.tensor([[0.2, 0.4], [0.6, 0.8]])), # 2D tensor + ], + ) + def test_normalize( + preds: torch.Tensor, + min_val: float, + max_val: float, + thresh: float, + target: torch.Tensor, + ) -> None: + """Test the normalize method.""" + pre_processor = OneClassPostProcessor() + normalized = pre_processor._normalize(preds, min_val, max_val, thresh) # noqa: SLF001 + assert torch.allclose(normalized, target) + + @staticmethod + @pytest.mark.parametrize( + ("preds", "thresh", "target"), + [ + (torch.tensor(20), 50, torch.tensor(0).bool()), # test scalar + (torch.tensor([20, 40, 60, 80]), 50, torch.tensor([0, 0, 1, 1]).bool()), # test 1d tensor + (torch.tensor([[20, 40], [60, 80]]), 50, torch.tensor([[0, 0], [1, 1]]).bool()), # test 2d tensor + (torch.tensor(50), 50, torch.tensor(0).bool()), # test on threshold labeled as normal + (torch.tensor([-80, -60, -40, -20]), -50, torch.tensor([0, 0, 1, 1]).bool()), # test negative + ], + ) + def test_apply_threshold(preds: torch.Tensor, thresh: float, target: torch.Tensor) -> None: + """Test the apply_threshold method.""" + pre_processor = OneClassPostProcessor() + binary_preds = pre_processor._apply_threshold(preds, thresh) # noqa: SLF001 + assert torch.allclose(binary_preds, target) + + @staticmethod + def test_thresholds_computed() -> None: + """Test that both image and pixel threshold are computed correctly.""" + batch = ImageBatch( + image=torch.rand(4, 3, 3, 3), + anomaly_map=torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]), + gt_mask=torch.tensor([[0, 0, 0], [0, 0, 0], [0, 1, 1]]), + pred_score=torch.tensor([20, 40, 60, 80]), + gt_label=torch.tensor([0, 0, 1, 1]), + ) + pre_processor = OneClassPostProcessor() + pre_processor.on_validation_batch_end(None, None, batch) + pre_processor.on_validation_epoch_end(None, None) + assert pre_processor.image_threshold == 60 + assert pre_processor.pixel_threshold == 80 + + @staticmethod + def test_pixel_threshold_used_as_image_threshold() -> None: + """Test that pixel_threshold is used as image threshold when no gt masks are available.""" + batch = ImageBatch( + image=torch.rand(4, 3, 10, 10), + anomaly_map=torch.rand(4, 10, 10), + pred_score=torch.tensor([20, 40, 60, 80]), + gt_label=torch.tensor([0, 0, 1, 1]), + ) + pre_processor = OneClassPostProcessor() + pre_processor.on_validation_batch_end(None, None, batch) + pre_processor.on_validation_epoch_end(None, None) + assert pre_processor.image_threshold == pre_processor.pixel_threshold From 2aede4a4c699a1bf3acbb6f7e00c8a1f15e31c56 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 22 Jan 2025 18:32:26 +0100 Subject: [PATCH 3/3] use right threshold in forward pass --- src/anomalib/post_processing/one_class.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index bc7b504bac..b227d5abb6 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -181,8 +181,8 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: msg = "At least one of pred_score or anomaly_map must be provided." raise ValueError(msg) pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1)) - pred_score = self._normalize(pred_score, self.image_min, self.image_max, self._image_threshold) - anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self._pixel_threshold) + pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold) + anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) pred_label = self._apply_threshold(pred_score, self.normalized_image_threshold) pred_mask = self._apply_threshold(anomaly_map, self.normalized_pixel_threshold) return InferenceBatch(