From 9cb6e8c518d096f8c596b6ffff647b490b407701 Mon Sep 17 00:00:00 2001 From: Victor Prins Date: Fri, 1 Dec 2023 16:19:49 +0100 Subject: [PATCH] Add `@override` for files in `src/lightning/fabric/loggers` (#19090) --- src/lightning/fabric/loggers/csv_logs.py | 9 +++++++++ src/lightning/fabric/loggers/tensorboard.py | 10 ++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index 383f820162582..009edf97cd853 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Set, Union from torch import Tensor +from typing_extensions import override from lightning.fabric.loggers.logger import Logger, rank_zero_experiment from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem @@ -73,6 +74,7 @@ def __init__( self._flush_logs_every_n_steps = flush_logs_every_n_steps @property + @override def name(self) -> str: """Gets the name of the experiment. @@ -83,6 +85,7 @@ def name(self) -> str: return self._name @property + @override def version(self) -> Union[int, str]: """Gets the version of the experiment. @@ -95,11 +98,13 @@ def version(self) -> Union[int, str]: return self._version @property + @override def root_dir(self) -> str: """Gets the save directory where the versioned CSV experiments are saved.""" return self._root_dir @property + @override def log_dir(self) -> str: """The log directory for this run. @@ -128,10 +133,12 @@ def experiment(self) -> "_ExperimentWriter": self._experiment = _ExperimentWriter(log_dir=self.log_dir) return self._experiment + @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.") + @override @rank_zero_only def log_metrics( # type: ignore[override] self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None @@ -143,11 +150,13 @@ def log_metrics( # type: ignore[override] if (step + 1) % self._flush_logs_every_n_steps == 0: self.save() + @override @rank_zero_only def save(self) -> None: super().save() self.experiment.save() + @override @rank_zero_only def finalize(self, status: str) -> None: if self._experiment is None: diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index a94a12b490282..3828b9ee0f4f3 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -20,6 +20,7 @@ from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module +from typing_extensions import override from lightning.fabric.loggers.logger import Logger, rank_zero_experiment from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem @@ -109,6 +110,7 @@ def __init__( self._kwargs = kwargs @property + @override def name(self) -> str: """Get the name of the experiment. @@ -119,6 +121,7 @@ def name(self) -> str: return self._name @property + @override def version(self) -> Union[int, str]: """Get the experiment version. @@ -131,6 +134,7 @@ def version(self) -> Union[int, str]: return self._version @property + @override def root_dir(self) -> str: """Gets the save directory where the TensorBoard experiments are saved. @@ -141,6 +145,7 @@ def root_dir(self) -> str: return self._root_dir @property + @override def log_dir(self) -> str: """The directory for this run's tensorboard checkpoint. @@ -191,6 +196,7 @@ def experiment(self) -> "SummaryWriter": self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) return self._experiment + @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" @@ -212,6 +218,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor." ) from ex + @override @rank_zero_only def log_hyperparams( # type: ignore[override] self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None @@ -251,6 +258,7 @@ def log_hyperparams( # type: ignore[override] writer.add_summary(ssi) writer.add_summary(sei) + @override @rank_zero_only def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: model_example_input = getattr(model, "example_input_array", None) @@ -278,10 +286,12 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None else: self.experiment.add_graph(model, input_array) + @override @rank_zero_only def save(self) -> None: self.experiment.flush() + @override @rank_zero_only def finalize(self, status: str) -> None: if self._experiment is not None: