From e11c663f05c24711b1dc3f88c7f11e52d80b6697 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 9 Dec 2024 13:38:36 +0000 Subject: [PATCH] Add visualizer to models Signed-off-by: Samet Akcay --- .../models/components/base/anomaly_module.py | 98 +++++++++++++++---- .../models/image/cfa/lightning_model.py | 10 +- .../models/image/cflow/lightning_model.py | 9 +- .../models/image/csflow/lightning_model.py | 9 +- .../models/image/dfkde/lightning_model.py | 9 +- .../models/image/dfm/lightning_model.py | 9 +- .../models/image/draem/lightning_model.py | 9 +- .../models/image/dsr/lightning_model.py | 9 +- .../image/efficient_ad/lightning_model.py | 9 +- .../models/image/fastflow/lightning_model.py | 9 +- .../models/image/fre/lightning_model.py | 9 +- .../models/image/ganomaly/lightning_model.py | 9 +- .../models/image/padim/lightning_model.py | 9 +- .../models/image/patchcore/lightning_model.py | 9 +- .../reverse_distillation/lightning_model.py | 9 +- .../models/image/stfpm/lightning_model.py | 9 +- .../models/image/uflow/lightning_model.py | 12 ++- .../models/image/winclip/lightning_model.py | 9 +- src/anomalib/visualization/__init__.py | 3 + src/anomalib/visualization/base.py | 14 +++ .../visualization/image/visualizer.py | 6 +- 21 files changed, 240 insertions(+), 38 deletions(-) create mode 100644 src/anomalib/visualization/base.py diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index c7b0d0d420..dace0a7082 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -25,6 +25,7 @@ from anomalib.metrics.threshold import Threshold from anomalib.post_processing import OneClassPostProcessor, PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import ImageVisualizer, Visualizer from .export_mixin import ExportMixin @@ -42,6 +43,7 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: super().__init__() logger.info("Initializing %s model.", self.__class__.__name__) @@ -54,6 +56,7 @@ def __init__( self.pre_processor = self._resolve_pre_processor(pre_processor) self.post_processor = self._resolve_post_processor(post_processor) self.evaluator = self._resolve_evaluator(evaluator) + self.visualizer = self._resolve_visualizer(visualizer) self._input_size: tuple[int, int] | None = None self._is_setup = False @@ -79,25 +82,6 @@ def _setup(self) -> None: initialization. """ - def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None: - """Resolve and validate which pre-processor to use.. - - Args: - pre_processor: Pre-processor configuration - - True -> use default pre-processor - - False -> no pre-processor - - PreProcessor -> use the provided pre-processor - - Returns: - Configured pre-processor - """ - if isinstance(pre_processor, PreProcessor): - return pre_processor - if isinstance(pre_processor, bool): - return self.configure_pre_processor() if pre_processor else None - msg = f"Invalid pre-processor type: {type(pre_processor)}" - raise TypeError(msg) - def configure_callbacks(self) -> Sequence[Callback] | Callback: """Configure default callbacks for AnomalibModule.""" return [self.pre_processor] if self.pre_processor else [] @@ -170,6 +154,25 @@ def learning_type(self) -> LearningType: """Learning type of the model.""" raise NotImplementedError + def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None: + """Resolve and validate which pre-processor to use.. + + Args: + pre_processor: Pre-processor configuration + - True -> use default pre-processor + - False -> no pre-processor + - PreProcessor -> use the provided pre-processor + + Returns: + Configured pre-processor + """ + if isinstance(pre_processor, PreProcessor): + return pre_processor + if isinstance(pre_processor, bool): + return self.configure_pre_processor() if pre_processor else None + msg = f"Invalid pre-processor type: {type(pre_processor)}" + raise TypeError(msg) + @classmethod def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor: """Configure the pre-processor. @@ -289,6 +292,63 @@ def configure_evaluator() -> Evaluator: test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score] return Evaluator(test_metrics=test_metrics) + def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None: + """Resolve and validate which visualizer to use. + + Args: + visualizer: Visualizer configuration + - True -> use default visualizer + - False -> no visualizer + - Visualizer -> use the provided visualizer + + Returns: + Configured visualizer + """ + if isinstance(visualizer, Visualizer): + return visualizer + if isinstance(visualizer, bool): + return self.configure_visualizer() if visualizer else None + msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}" + raise TypeError(msg) + + @classmethod + def configure_visualizer(cls) -> ImageVisualizer: + """Configure the default visualizer. + + By default, this method returns an ImageVisualizer instance, which is suitable for + visualizing image-based anomaly detection results. However, the visualizer can be + customized based on your needs - for example, using VideoVisualizer for video data + or implementing a custom visualizer for specific visualization requirements. + + Returns: + Visualizer: Configured visualizer instance (ImageVisualizer by default). + + Examples: + Get default ImageVisualizer: + + >>> visualizer = AnomalibModule.configure_visualizer() + + Create model with VideoVisualizer: + + >>> from custom_module import VideoVisualizer + >>> video_visualizer = VideoVisualizer() + >>> model = PatchCore(visualizer=video_visualizer) + + Create model with custom visualizer: + + >>> class CustomVisualizer(Visualizer): + ... def __init__(self): + ... super().__init__() + ... # Custom visualization logic + >>> custom_visualizer = CustomVisualizer() + >>> model = PatchCore(visualizer=custom_visualizer) + + Disable visualization: + + >>> model = PatchCore(visualizer=False) + """ + return ImageVisualizer() + @property def input_size(self) -> tuple[int, int] | None: """Return the effective input size of the model. diff --git a/src/anomalib/models/image/cfa/lightning_model.py b/src/anomalib/models/image/cfa/lightning_model.py index 4c2a9cffe2..9eed15b6a7 100644 --- a/src/anomalib/models/image/cfa/lightning_model.py +++ b/src/anomalib/models/image/cfa/lightning_model.py @@ -20,6 +20,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import CfaLoss from .torch_model import CfaModel @@ -58,11 +59,18 @@ def __init__( num_nearest_neighbors: int = 3, num_hard_negative_features: int = 3, radius: float = 1e-5, + # Anomalib's Auxiliary Components pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model: CfaModel = CfaModel( backbone=backbone, gamma_c=gamma_c, diff --git a/src/anomalib/models/image/cflow/lightning_model.py b/src/anomalib/models/image/cflow/lightning_model.py index f82872cb84..4dd9c25850 100644 --- a/src/anomalib/models/image/cflow/lightning_model.py +++ b/src/anomalib/models/image/cflow/lightning_model.py @@ -27,6 +27,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import CflowModel from .utils import get_logp, positional_encoding_2d @@ -73,8 +74,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model: CflowModel = CflowModel( backbone=backbone, diff --git a/src/anomalib/models/image/csflow/lightning_model.py b/src/anomalib/models/image/csflow/lightning_model.py index 3324972422..8e9994631a 100644 --- a/src/anomalib/models/image/csflow/lightning_model.py +++ b/src/anomalib/models/image/csflow/lightning_model.py @@ -18,6 +18,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import CsFlowLoss from .torch_model import CsFlowModel @@ -50,8 +51,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) if self.input_size is None: msg = "CsFlow needs input size to build torch model." raise ValueError(msg) diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py index be141c45ef..16ccac6403 100644 --- a/src/anomalib/models/image/dfkde/lightning_model.py +++ b/src/anomalib/models/image/dfkde/lightning_model.py @@ -17,6 +17,7 @@ from anomalib.models.components.classification import FeatureScalingMethod from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import DfkdeModel @@ -52,8 +53,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model = DfkdeModel( layers=layers, diff --git a/src/anomalib/models/image/dfm/lightning_model.py b/src/anomalib/models/image/dfm/lightning_model.py index 61595434c9..96a4388835 100644 --- a/src/anomalib/models/image/dfm/lightning_model.py +++ b/src/anomalib/models/image/dfm/lightning_model.py @@ -18,6 +18,7 @@ from anomalib.models.components import AnomalibModule, MemoryBankMixin from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import DFMModel @@ -56,8 +57,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model: DFMModel = DFMModel( backbone=backbone, diff --git a/src/anomalib/models/image/draem/lightning_model.py b/src/anomalib/models/image/draem/lightning_model.py index 13fa1346e7..84b143f3f5 100644 --- a/src/anomalib/models/image/draem/lightning_model.py +++ b/src/anomalib/models/image/draem/lightning_model.py @@ -21,6 +21,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import DraemLoss from .torch_model import DraemModel @@ -53,8 +54,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.augmenter = PerlinAnomalyGenerator(anomaly_source_path=anomaly_source_path, blend_factor=beta) self.model = DraemModel(sspcab=enable_sspcab) diff --git a/src/anomalib/models/image/dsr/lightning_model.py b/src/anomalib/models/image/dsr/lightning_model.py index 5f9c1c625d..dd80e88ba7 100644 --- a/src/anomalib/models/image/dsr/lightning_model.py +++ b/src/anomalib/models/image/dsr/lightning_model.py @@ -25,6 +25,7 @@ from anomalib.models.image.dsr.torch_model import DsrModel from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer __all__ = ["Dsr"] @@ -55,8 +56,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.automatic_optimization = False self.upsampling_train_ratio = upsampling_train_ratio diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py index af90173c52..aa99d6a439 100644 --- a/src/anomalib/models/image/efficient_ad/lightning_model.py +++ b/src/anomalib/models/image/efficient_ad/lightning_model.py @@ -24,6 +24,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems @@ -78,8 +79,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.imagenet_dir = Path(imagenet_dir) if not isinstance(model_size, EfficientAdModelSize): diff --git a/src/anomalib/models/image/fastflow/lightning_model.py b/src/anomalib/models/image/fastflow/lightning_model.py index 9d51f99489..8a98ea9e7a 100644 --- a/src/anomalib/models/image/fastflow/lightning_model.py +++ b/src/anomalib/models/image/fastflow/lightning_model.py @@ -18,6 +18,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import FastflowLoss from .torch_model import FastflowModel @@ -52,8 +53,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) if self.input_size is None: msg = "Fastflow needs input size to build torch model." raise ValueError(msg) diff --git a/src/anomalib/models/image/fre/lightning_model.py b/src/anomalib/models/image/fre/lightning_model.py index c748705c3a..953fcd4322 100755 --- a/src/anomalib/models/image/fre/lightning_model.py +++ b/src/anomalib/models/image/fre/lightning_model.py @@ -19,6 +19,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import FREModel @@ -58,8 +59,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model: FREModel = FREModel( backbone=backbone, diff --git a/src/anomalib/models/image/ganomaly/lightning_model.py b/src/anomalib/models/image/ganomaly/lightning_model.py index cf7d7525d0..4b48b0b633 100644 --- a/src/anomalib/models/image/ganomaly/lightning_model.py +++ b/src/anomalib/models/image/ganomaly/lightning_model.py @@ -19,6 +19,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import DiscriminatorLoss, GeneratorLoss from .torch_model import GanomalyModel @@ -73,8 +74,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) if self.input_size is None: msg = "GANomaly needs input size to build torch model." raise ValueError(msg) diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py index d658b1cfa1..4a223c9e62 100644 --- a/src/anomalib/models/image/padim/lightning_model.py +++ b/src/anomalib/models/image/padim/lightning_model.py @@ -17,6 +17,7 @@ from anomalib.models.components import AnomalibModule, MemoryBankMixin from anomalib.post_processing import OneClassPostProcessor, PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import PadimModel @@ -52,8 +53,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model: PadimModel = PadimModel( backbone=backbone, diff --git a/src/anomalib/models/image/patchcore/lightning_model.py b/src/anomalib/models/image/patchcore/lightning_model.py index d2fb922da3..689b7ac81f 100644 --- a/src/anomalib/models/image/patchcore/lightning_model.py +++ b/src/anomalib/models/image/patchcore/lightning_model.py @@ -20,6 +20,7 @@ from anomalib.models.components import AnomalibModule, MemoryBankMixin from anomalib.post_processing import OneClassPostProcessor, PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import PatchcoreModel @@ -55,8 +56,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model: PatchcoreModel = PatchcoreModel( backbone=backbone, diff --git a/src/anomalib/models/image/reverse_distillation/lightning_model.py b/src/anomalib/models/image/reverse_distillation/lightning_model.py index 4fb6e06b2f..3eb3bf903c 100644 --- a/src/anomalib/models/image/reverse_distillation/lightning_model.py +++ b/src/anomalib/models/image/reverse_distillation/lightning_model.py @@ -18,6 +18,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .anomaly_map import AnomalyMapGenerationMode from .loss import ReverseDistillationLoss @@ -50,8 +51,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) if self.input_size is None: msg = "Input size is required for Reverse Distillation model." raise ValueError(msg) diff --git a/src/anomalib/models/image/stfpm/lightning_model.py b/src/anomalib/models/image/stfpm/lightning_model.py index b94d5b6639..f3daafe407 100644 --- a/src/anomalib/models/image/stfpm/lightning_model.py +++ b/src/anomalib/models/image/stfpm/lightning_model.py @@ -19,6 +19,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import STFPMLoss from .torch_model import STFPMModel @@ -46,8 +47,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model = STFPMModel(backbone=backbone, layers=layers) self.loss = STFPMLoss() diff --git a/src/anomalib/models/image/uflow/lightning_model.py b/src/anomalib/models/image/uflow/lightning_model.py index 445cfcd6ee..bfd51195ca 100644 --- a/src/anomalib/models/image/uflow/lightning_model.py +++ b/src/anomalib/models/image/uflow/lightning_model.py @@ -21,6 +21,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .loss import UFlowLoss from .torch_model import UflowModel @@ -51,6 +52,7 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: """Uflow model. @@ -69,8 +71,16 @@ def __init__( evaluator (Evaluator, optional): Evaluator for the model. This is used to evaluate the model. Defaults to ``True``. + visualizer (Visualizer, optional): Visualizer for the model. + This is used to visualize the model. + Defaults to ``True``. """ - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) if self.input_size is None: msg = "Input size is required for UFlow model." raise ValueError(msg) diff --git a/src/anomalib/models/image/winclip/lightning_model.py b/src/anomalib/models/image/winclip/lightning_model.py index 9f16558619..23a7cf23a1 100644 --- a/src/anomalib/models/image/winclip/lightning_model.py +++ b/src/anomalib/models/image/winclip/lightning_model.py @@ -22,6 +22,7 @@ from anomalib.models.components import AnomalibModule from anomalib.post_processing import OneClassPostProcessor, PostProcessor from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer from .torch_model import WinClipModel @@ -58,8 +59,14 @@ def __init__( pre_processor: PreProcessor | bool = True, post_processor: PostProcessor | bool = True, evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator) + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) self.model = WinClipModel(scales=scales, apply_transform=False) self.class_name = class_name diff --git a/src/anomalib/visualization/__init__.py b/src/anomalib/visualization/__init__.py index ca0b7bc138..989f4cc34c 100644 --- a/src/anomalib/visualization/__init__.py +++ b/src/anomalib/visualization/__init__.py @@ -3,10 +3,13 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from .base import Visualizer from .image import ImageVisualizer, visualize_anomaly_map, visualize_mask from .image.item_visualizer import visualize_image_item __all__ = [ + # Base visualizer class + "Visualizer", # Image visualizer class "ImageVisualizer", # Image visualization functions diff --git a/src/anomalib/visualization/base.py b/src/anomalib/visualization/base.py new file mode 100644 index 0000000000..dc49a85401 --- /dev/null +++ b/src/anomalib/visualization/base.py @@ -0,0 +1,14 @@ +"""Base Visualizer.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from lightning.pytorch import Callback + + +class Visualizer(Callback): + """Base class for all visualizers. + + In Anomalib, the visualizer is used to visualize the results of the model + during the testing and prediction phases. + """ diff --git a/src/anomalib/visualization/image/visualizer.py b/src/anomalib/visualization/image/visualizer.py index c8eafc8ae8..97d11546ea 100644 --- a/src/anomalib/visualization/image/visualizer.py +++ b/src/anomalib/visualization/image/visualizer.py @@ -6,11 +6,12 @@ from pathlib import Path from typing import Any -from lightning.pytorch import Callback, Trainer +from lightning.pytorch import Trainer from anomalib.data import ImageBatch from anomalib.models import AnomalibModule from anomalib.utils.path import generate_output_filename +from anomalib.visualization.base import Visualizer from .item_visualizer import ( DEFAULT_FIELDS_CONFIG, @@ -20,7 +21,7 @@ ) -class ImageVisualizer(Callback): +class ImageVisualizer(Visualizer): """Image Visualizer. This class is responsible for visualizing images and their corresponding anomaly maps @@ -127,6 +128,7 @@ def __init__( text_config: dict[str, Any] | None = None, output_dir: str | Path | None = None, ) -> None: + super().__init__() self.fields = fields or ["image", "gt_mask"] self.overlay_fields = overlay_fields or [("image", ["anomaly_map"]), ("image", ["pred_mask"])] self.field_size = field_size