Skip to content

Commit

Permalink
Add visualizer to models
Browse files Browse the repository at this point in the history
Signed-off-by: Samet Akcay <[email protected]>
  • Loading branch information
samet-akcay committed Dec 9, 2024
1 parent ef2deb0 commit e11c663
Show file tree
Hide file tree
Showing 21 changed files with 240 additions and 38 deletions.
98 changes: 79 additions & 19 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from anomalib.metrics.threshold import Threshold
from anomalib.post_processing import OneClassPostProcessor, PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import ImageVisualizer, Visualizer

from .export_mixin import ExportMixin

Expand All @@ -42,6 +43,7 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__()
logger.info("Initializing %s model.", self.__class__.__name__)
Expand All @@ -54,6 +56,7 @@ def __init__(
self.pre_processor = self._resolve_pre_processor(pre_processor)
self.post_processor = self._resolve_post_processor(post_processor)
self.evaluator = self._resolve_evaluator(evaluator)
self.visualizer = self._resolve_visualizer(visualizer)

self._input_size: tuple[int, int] | None = None
self._is_setup = False
Expand All @@ -79,25 +82,6 @@ def _setup(self) -> None:
initialization.
"""

def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
"""Resolve and validate which pre-processor to use..
Args:
pre_processor: Pre-processor configuration
- True -> use default pre-processor
- False -> no pre-processor
- PreProcessor -> use the provided pre-processor
Returns:
Configured pre-processor
"""
if isinstance(pre_processor, PreProcessor):
return pre_processor
if isinstance(pre_processor, bool):
return self.configure_pre_processor() if pre_processor else None
msg = f"Invalid pre-processor type: {type(pre_processor)}"
raise TypeError(msg)

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure default callbacks for AnomalibModule."""
return [self.pre_processor] if self.pre_processor else []
Expand Down Expand Up @@ -170,6 +154,25 @@ def learning_type(self) -> LearningType:
"""Learning type of the model."""
raise NotImplementedError

def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
"""Resolve and validate which pre-processor to use..
Args:
pre_processor: Pre-processor configuration
- True -> use default pre-processor
- False -> no pre-processor
- PreProcessor -> use the provided pre-processor
Returns:
Configured pre-processor
"""
if isinstance(pre_processor, PreProcessor):
return pre_processor
if isinstance(pre_processor, bool):
return self.configure_pre_processor() if pre_processor else None
msg = f"Invalid pre-processor type: {type(pre_processor)}"
raise TypeError(msg)

@classmethod
def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
"""Configure the pre-processor.
Expand Down Expand Up @@ -289,6 +292,63 @@ def configure_evaluator() -> Evaluator:
test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score]
return Evaluator(test_metrics=test_metrics)

def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None:
"""Resolve and validate which visualizer to use.
Args:
visualizer: Visualizer configuration
- True -> use default visualizer
- False -> no visualizer
- Visualizer -> use the provided visualizer
Returns:
Configured visualizer
"""
if isinstance(visualizer, Visualizer):
return visualizer
if isinstance(visualizer, bool):
return self.configure_visualizer() if visualizer else None
msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}"
raise TypeError(msg)

@classmethod
def configure_visualizer(cls) -> ImageVisualizer:
"""Configure the default visualizer.
By default, this method returns an ImageVisualizer instance, which is suitable for
visualizing image-based anomaly detection results. However, the visualizer can be
customized based on your needs - for example, using VideoVisualizer for video data
or implementing a custom visualizer for specific visualization requirements.
Returns:
Visualizer: Configured visualizer instance (ImageVisualizer by default).
Examples:
Get default ImageVisualizer:
>>> visualizer = AnomalibModule.configure_visualizer()
Create model with VideoVisualizer:
>>> from custom_module import VideoVisualizer
>>> video_visualizer = VideoVisualizer()
>>> model = PatchCore(visualizer=video_visualizer)
Create model with custom visualizer:
>>> class CustomVisualizer(Visualizer):
... def __init__(self):
... super().__init__()
... # Custom visualization logic
>>> custom_visualizer = CustomVisualizer()
>>> model = PatchCore(visualizer=custom_visualizer)
Disable visualization:
>>> model = PatchCore(visualizer=False)
"""
return ImageVisualizer()

@property
def input_size(self) -> tuple[int, int] | None:
"""Return the effective input size of the model.
Expand Down
10 changes: 9 additions & 1 deletion src/anomalib/models/image/cfa/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import CfaLoss
from .torch_model import CfaModel
Expand Down Expand Up @@ -58,11 +59,18 @@ def __init__(
num_nearest_neighbors: int = 3,
num_hard_negative_features: int = 3,
radius: float = 1e-5,
# Anomalib's Auxiliary Components
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)
self.model: CfaModel = CfaModel(
backbone=backbone,
gamma_c=gamma_c,
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/cflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import CflowModel
from .utils import get_logp, positional_encoding_2d
Expand Down Expand Up @@ -73,8 +74,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.model: CflowModel = CflowModel(
backbone=backbone,
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/csflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import CsFlowLoss
from .torch_model import CsFlowModel
Expand Down Expand Up @@ -50,8 +51,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)
if self.input_size is None:
msg = "CsFlow needs input size to build torch model."
raise ValueError(msg)
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/dfkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from anomalib.models.components.classification import FeatureScalingMethod
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import DfkdeModel

Expand Down Expand Up @@ -52,8 +53,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.model = DfkdeModel(
layers=layers,
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/dfm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from anomalib.models.components import AnomalibModule, MemoryBankMixin
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import DFMModel

Expand Down Expand Up @@ -56,8 +57,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.model: DFMModel = DFMModel(
backbone=backbone,
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/draem/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import DraemLoss
from .torch_model import DraemModel
Expand Down Expand Up @@ -53,8 +54,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.augmenter = PerlinAnomalyGenerator(anomaly_source_path=anomaly_source_path, blend_factor=beta)
self.model = DraemModel(sspcab=enable_sspcab)
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/dsr/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from anomalib.models.image.dsr.torch_model import DsrModel
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

__all__ = ["Dsr"]

Expand Down Expand Up @@ -55,8 +56,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.automatic_optimization = False
self.upsampling_train_ratio = upsampling_train_ratio
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems

Expand Down Expand Up @@ -78,8 +79,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.imagenet_dir = Path(imagenet_dir)
if not isinstance(model_size, EfficientAdModelSize):
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/fastflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import FastflowLoss
from .torch_model import FastflowModel
Expand Down Expand Up @@ -52,8 +53,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)
if self.input_size is None:
msg = "Fastflow needs input size to build torch model."
raise ValueError(msg)
Expand Down
9 changes: 8 additions & 1 deletion src/anomalib/models/image/fre/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import FREModel

Expand Down Expand Up @@ -58,8 +59,14 @@ def __init__(
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.model: FREModel = FREModel(
backbone=backbone,
Expand Down
Loading

0 comments on commit e11c663

Please sign in to comment.