Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix post-processing visualization issues #2534

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 46 additions & 26 deletions src/anomalib/post_processing/one_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,21 @@ def __init__(
self.pixel_sensitivity = pixel_sensitivity

# initialize threshold and normalization metrics
self._image_threshold = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False)
self._pixel_threshold = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False)
self._image_min_max = MinMax(fields=["pred_score"], strict=False)
self._pixel_min_max = MinMax(fields=["anomaly_map"], strict=False)
self._image_threshold_metric = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False)
self._pixel_threshold_metric = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False)
self._image_min_max_metric = MinMax(fields=["pred_score"], strict=False)
self._pixel_min_max_metric = MinMax(fields=["anomaly_map"], strict=False)

# register buffers to persist threshold and normalization values
self.register_buffer("image_threshold", torch.tensor(0))
self.register_buffer("pixel_threshold", torch.tensor(0))
self.register_buffer("_image_threshold", torch.tensor(0))
self.register_buffer("_pixel_threshold", torch.tensor(0))
self.register_buffer("image_min", torch.tensor(0))
self.register_buffer("image_max", torch.tensor(1))
self.register_buffer("pixel_min", torch.tensor(0))
self.register_buffer("pixel_max", torch.tensor(1))

self.image_threshold: torch.Tensor
self.pixel_threshold: torch.Tensor
self._image_threshold: torch.Tensor
self._pixel_threshold: torch.Tensor
self.image_min: torch.Tensor
self.image_max: torch.Tensor
self.pixel_min: torch.Tensor
Expand All @@ -102,10 +102,10 @@ def on_validation_batch_end(
**kwargs: Arbitrary keyword arguments.
"""
del trainer, pl_module, args, kwargs # Unused arguments.
self._image_threshold.update(outputs)
self._pixel_threshold.update(outputs)
self._image_min_max.update(outputs)
self._pixel_min_max.update(outputs)
self._image_threshold_metric.update(outputs)
self._pixel_threshold_metric.update(outputs)
self._image_min_max_metric.update(outputs)
self._pixel_min_max_metric.update(outputs)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Compute final threshold and normalization values.
Expand All @@ -115,14 +115,14 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule)
pl_module (LightningModule): PyTorch Lightning module instance.
"""
del trainer, pl_module
if self._image_threshold.update_called:
self.image_threshold = self._image_threshold.compute()
if self._pixel_threshold.update_called:
self.pixel_threshold = self._pixel_threshold.compute()
if self._image_min_max.update_called:
self.image_min, self.image_max = self._image_min_max.compute()
if self._pixel_min_max.update_called:
self.pixel_min, self.pixel_max = self._pixel_min_max.compute()
if self._image_threshold_metric.update_called:
self._image_threshold = self._image_threshold_metric.compute()
if self._pixel_threshold_metric.update_called:
self._pixel_threshold = self._pixel_threshold_metric.compute()
if self._image_min_max_metric.update_called:
self.image_min, self.image_max = self._image_min_max_metric.compute()
if self._pixel_min_max_metric.update_called:
self.pixel_min, self.pixel_max = self._pixel_min_max_metric.compute()

def on_test_batch_end(
self,
Expand Down Expand Up @@ -274,16 +274,36 @@ def _normalize(
preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5
return preds.clamp(min=0, max=1)

@property
def image_threshold(self) -> float:
"""Get the image-level threshold.

Returns:
float: Image-level threshold value.
"""
return self._image_threshold

@property
def pixel_threshold(self) -> float:
"""Get the pixel-level threshold.

If the pixel-level threshold is not set, the image-level threshold is used.

Returns:
float: Pixel-level threshold value.
"""
return self._pixel_threshold or self.image_threshold

@property
def normalized_image_threshold(self) -> float:
"""Get the normalized image-level threshold.

Returns:
float: Normalized image-level threshold value, adjusted by sensitivity.
"""
if self.image_sensitivity is not None:
return torch.tensor(1.0) - self.image_sensitivity
return torch.tensor(0.5)
if self.image_sensitivity is None:
return torch.tensor(0.5)
return torch.tensor(1.0) - self.image_sensitivity

@property
def normalized_pixel_threshold(self) -> float:
Expand All @@ -292,6 +312,6 @@ def normalized_pixel_threshold(self) -> float:
Returns:
float: Normalized pixel-level threshold value, adjusted by sensitivity.
"""
if self.pixel_sensitivity is not None:
return torch.tensor(1.0) - self.pixel_sensitivity
return torch.tensor(0.5)
if self.pixel_sensitivity is None:
return torch.tensor(0.5)
return torch.tensor(1.0) - self.pixel_sensitivity
89 changes: 89 additions & 0 deletions tests/unit/post_processing/test_post_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Test the PostProcessor class."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from anomalib.data import ImageBatch
from anomalib.post_processing import OneClassPostProcessor


class TestPostProcessor:
"""Test the PreProcessor class."""

@staticmethod
@pytest.mark.parametrize(
("preds", "min_val", "max_val", "thresh", "target"),
[
(torch.tensor([20, 40, 60, 80]), 0, 100, 50, torch.tensor([0.2, 0.4, 0.6, 0.8])),
(torch.tensor([20, 40, 60, 80]), 0, 100, 40, torch.tensor([0.3, 0.5, 0.7, 0.9])), # lower threshold
(torch.tensor([20, 40, 60, 80]), 0, 100, 60, torch.tensor([0.1, 0.3, 0.5, 0.7])), # higher threshold
(torch.tensor([0, 40, 80, 120]), 20, 100, 50, torch.tensor([0.0, 0.375, 0.875, 1.0])), # out of bounds
(torch.tensor([-80, -60, -40, -20]), -100, 0, -50, torch.tensor([0.2, 0.4, 0.6, 0.8])), # negative values
(torch.tensor([20, 40, 60, 80]), 0, 100, -50, torch.tensor([1.0, 1.0, 1.0, 1.0])), # threshold below range
(torch.tensor([20, 40, 60, 80]), 0, 100, 150, torch.tensor([0.0, 0.0, 0.0, 0.0])), # threshold above range
(torch.tensor([20, 40, 60, 80]), 50, 50, 50, torch.tensor([0.0, 0.0, 1.0, 1.0])), # all same
(torch.tensor(60), 0, 100, 50, torch.tensor(0.6)), # scalar tensor
(torch.tensor([[20, 40], [60, 80]]), 0, 100, 50, torch.tensor([[0.2, 0.4], [0.6, 0.8]])), # 2D tensor
],
)
def test_normalize(
preds: torch.Tensor,
min_val: float,
max_val: float,
thresh: float,
target: torch.Tensor,
) -> None:
"""Test the normalize method."""
pre_processor = OneClassPostProcessor()
normalized = pre_processor._normalize(preds, min_val, max_val, thresh) # noqa: SLF001
assert torch.allclose(normalized, target)

@staticmethod
@pytest.mark.parametrize(
("preds", "thresh", "target"),
[
(torch.tensor(20), 50, torch.tensor(0).bool()), # test scalar
(torch.tensor([20, 40, 60, 80]), 50, torch.tensor([0, 0, 1, 1]).bool()), # test 1d tensor
(torch.tensor([[20, 40], [60, 80]]), 50, torch.tensor([[0, 0], [1, 1]]).bool()), # test 2d tensor
(torch.tensor(50), 50, torch.tensor(0).bool()), # test on threshold labeled as normal
(torch.tensor([-80, -60, -40, -20]), -50, torch.tensor([0, 0, 1, 1]).bool()), # test negative
],
)
def test_apply_threshold(preds: torch.Tensor, thresh: float, target: torch.Tensor) -> None:
"""Test the apply_threshold method."""
pre_processor = OneClassPostProcessor()
binary_preds = pre_processor._apply_threshold(preds, thresh) # noqa: SLF001
assert torch.allclose(binary_preds, target)

@staticmethod
def test_thresholds_computed() -> None:
"""Test that both image and pixel threshold are computed correctly."""
batch = ImageBatch(
image=torch.rand(4, 3, 3, 3),
anomaly_map=torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]),
gt_mask=torch.tensor([[0, 0, 0], [0, 0, 0], [0, 1, 1]]),
pred_score=torch.tensor([20, 40, 60, 80]),
gt_label=torch.tensor([0, 0, 1, 1]),
)
pre_processor = OneClassPostProcessor()
pre_processor.on_validation_batch_end(None, None, batch)
pre_processor.on_validation_epoch_end(None, None)
assert pre_processor.image_threshold == 60
assert pre_processor.pixel_threshold == 80

@staticmethod
def test_pixel_threshold_used_as_image_threshold() -> None:
"""Test that pixel_threshold is used as image threshold when no gt masks are available."""
batch = ImageBatch(
image=torch.rand(4, 3, 10, 10),
anomaly_map=torch.rand(4, 10, 10),
pred_score=torch.tensor([20, 40, 60, 80]),
gt_label=torch.tensor([0, 0, 1, 1]),
)
pre_processor = OneClassPostProcessor()
pre_processor.on_validation_batch_end(None, None, batch)
pre_processor.on_validation_epoch_end(None, None)
assert pre_processor.image_threshold == pre_processor.pixel_threshold
Loading