Skip to content

Commit

Permalink
Refactor callback handling in Engine class by removing post-process…
Browse files Browse the repository at this point in the history
…or and metrics callbacks. Simplify callback configuration to streamline the process and enhance maintainability.
  • Loading branch information
samet-akcay committed Dec 9, 2024
1 parent fab2bb6 commit b3c656d
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from anomalib.deploy import CompressionType, ExportType
from anomalib.models import AnomalibModule
from anomalib.utils.path import create_versioned_dir
from anomalib.visualization import ImageVisualizer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -258,7 +257,7 @@ def _setup_trainer(self, model: AnomalibModule) -> None:
self._cache.update(model)

# Setup anomalib callbacks to be used with the trainer
self._setup_anomalib_callbacks(model)
self._setup_anomalib_callbacks()

# Temporarily set devices to 1 to avoid issues with multiple processes
self._cache.args["devices"] = 1
Expand All @@ -267,7 +266,7 @@ def _setup_trainer(self, model: AnomalibModule) -> None:
if self._trainer is None:
self._trainer = Trainer(**self._cache.args)

def _setup_anomalib_callbacks(self, model: AnomalibModule) -> None:
def _setup_anomalib_callbacks(self) -> None:
"""Set up callbacks for the trainer."""
_callbacks: list[Callback] = []

Expand All @@ -282,18 +281,6 @@ def _setup_anomalib_callbacks(self, model: AnomalibModule) -> None:
),
)

# Add the post-processor callback.
if isinstance(model.post_processor, Callback):
_callbacks.append(model.post_processor)

# Add the metrics callback.
if isinstance(model.evaluator, Callback):
_callbacks.append(model.evaluator)

# Add the image visualizer callback if it is passed by the user.
if not any(isinstance(callback, ImageVisualizer) for callback in self._cache.args["callbacks"]):
_callbacks.append(ImageVisualizer())

_callbacks.append(TimerCallback())

# Combine the callbacks, and update the trainer callbacks.
Expand Down

0 comments on commit b3c656d

Please sign in to comment.