diff --git a/vis4d/config/default/pl_trainer.py b/vis4d/config/default/pl_trainer.py index 6e43fd974..3fbbce793 100644 --- a/vis4d/config/default/pl_trainer.py +++ b/vis4d/config/default/pl_trainer.py @@ -1,7 +1,7 @@ """Default runtime configuration for PyTorch Lightning.""" import inspect -import pytorch_lightning as pl +from lightning import Trainer from vis4d.config import FieldConfigDict from vis4d.config.typing import ExperimentConfig @@ -12,7 +12,7 @@ def get_default_pl_trainer_cfg(config: ExperimentConfig) -> ExperimentConfig: pl_trainer = FieldConfigDict() # PL Trainer arguments - for k, v in inspect.signature(pl.Trainer).parameters.items(): + for k, v in inspect.signature(Trainer).parameters.items(): if not k in {"callbacks", "devices", "logger", "strategy"}: pl_trainer[k] = v.default diff --git a/vis4d/engine/callbacks/logging.py b/vis4d/engine/callbacks/logging.py index c62414e7e..8a35337eb 100644 --- a/vis4d/engine/callbacks/logging.py +++ b/vis4d/engine/callbacks/logging.py @@ -51,6 +51,7 @@ def on_train_epoch_start( """Hook to run at the start of a training epoch.""" if self.epoch_based: self.train_timer.reset() + self.last_step = 0 self._metrics.clear() elif trainer_state["global_step"] == 0: self.train_timer.reset() diff --git a/vis4d/engine/optim/scheduler.py b/vis4d/engine/optim/scheduler.py index d539cdd76..bf0e1a30e 100644 --- a/vis4d/engine/optim/scheduler.py +++ b/vis4d/engine/optim/scheduler.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler from vis4d.common.typing import DictStrAny -from vis4d.config import instantiate_classes +from vis4d.config import copy_and_resolve_references, instantiate_classes from vis4d.config.typing import LrSchedulerConfig @@ -30,7 +30,9 @@ def __init__( steps_per_epoch: int = -1, ) -> None: """Initialize LRSchedulerWrapper.""" - self.lr_schedulers_cfg = lr_schedulers_cfg + self.lr_schedulers_cfg: list[ + LrSchedulerConfig + ] = copy_and_resolve_references(lr_schedulers_cfg) self.lr_schedulers: dict[int, LRSchedulerDict] = {} super().__init__(optimizer) self.steps_per_epoch = steps_per_epoch